diff --git a/benchmarks/README.md b/benchmarks/README.md index af72d16d2ad4b..c5b8f5b9d2321 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -152,11 +152,10 @@ exit Grab the async profiler from https://github.com/jvm-profiling-tools/async-profiler and run `prof async` like so: ``` -gradlew -p benchmarks/ run --args 'LongKeyedBucketOrdsBenchmark.multiBucket -prof "async:libPath=/home/nik9000/Downloads/async-profiler-3.0-29ee888-linux-x64/lib/libasyncProfiler.so;dir=/tmp/prof;output=flamegraph"' +gradlew -p benchmarks/ run --args 'LongKeyedBucketOrdsBenchmark.multiBucket -prof "async:libPath=/home/nik9000/Downloads/async-profiler-4.0-linux-x64/lib/libasyncProfiler.so;dir=/tmp/prof;output=flamegraph"' ``` -Note: As of January 2025 the latest release of async profiler doesn't work - with our JDK but the nightly is fine. +Note: As of July 2025 the 4.0 release of the async profiler works well. If you are on Mac, this'll warn you that you downloaded the shared library from the internet. You'll need to go to settings and allow it to run. diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java similarity index 91% rename from benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java rename to benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java index d592366835c08..94483a136a5d2 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.benchmark.compute.operator; +package org.elasticsearch.benchmark._nightly.esql; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.NumericDocValuesField; @@ -24,8 +24,10 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NumericUtils; import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -85,10 +87,23 @@ @State(Scope.Thread) @Fork(1) public class ValuesSourceReaderBenchmark { + static { + LogConfigurator.configureESLogging(); + } + + private static final String[] SUPPORTED_LAYOUTS = new String[] { "in_order", "shuffled", "shuffled_singles" }; + private static final String[] SUPPORTED_NAMES = new String[] { + "long", + "int", + "double", + "keyword", + "stored_keyword", + "3_stored_keywords", + "keyword_mv" }; + private static final int BLOCK_LENGTH = 16 * 1024; private static final int INDEX_SIZE = 10 * BLOCK_LENGTH; private static final int COMMIT_INTERVAL = 500; - private static final BigArrays BIG_ARRAYS = BigArrays.NON_RECYCLING_INSTANCE; private static final BlockFactory blockFactory = BlockFactory.getInstance( new NoopCircuitBreaker("noop"), BigArrays.NON_RECYCLING_INSTANCE @@ -104,8 +119,8 @@ static void selfTest() { ValuesSourceReaderBenchmark benchmark = new ValuesSourceReaderBenchmark(); benchmark.setupIndex(); try { - for (String layout : ValuesSourceReaderBenchmark.class.getField("layout").getAnnotationsByType(Param.class)[0].value()) { - for (String name : ValuesSourceReaderBenchmark.class.getField("name").getAnnotationsByType(Param.class)[0].value()) { + for (String layout : 
ValuesSourceReaderBenchmark.SUPPORTED_LAYOUTS) { + for (String name : ValuesSourceReaderBenchmark.SUPPORTED_NAMES) { benchmark.layout = layout; benchmark.name = name; try { @@ -119,7 +134,7 @@ static void selfTest() { } finally { benchmark.teardownIndex(); } - } catch (IOException | NoSuchFieldException e) { + } catch (IOException e) { throw new AssertionError(e); } } @@ -321,10 +336,10 @@ public FieldNamesFieldMapper.FieldNamesFieldType fieldNames() { * each page has a single document rather than {@code BLOCK_SIZE} docs. * */ - @Param({ "in_order", "shuffled", "shuffled_singles" }) + @Param({ "in_order", "shuffled" }) public String layout; - @Param({ "long", "int", "double", "keyword", "stored_keyword", "3_stored_keywords" }) + @Param({ "long", "keyword", "stored_keyword", "keyword_mv" }) public String name; private Directory directory; @@ -336,6 +351,7 @@ public FieldNamesFieldMapper.FieldNamesFieldType fieldNames() { public void benchmark() { ValuesSourceReaderOperator op = new ValuesSourceReaderOperator( blockFactory, + ByteSizeValue.ofMb(1).getBytes(), fields(name), List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> { throw new UnsupportedOperationException("can't load _source here"); @@ -390,6 +406,22 @@ public void benchmark() { } } } + case "keyword_mv" -> { + BytesRef scratch = new BytesRef(); + BytesRefBlock values = op.getOutput().getBlock(1); + for (int p = 0; p < values.getPositionCount(); p++) { + int count = values.getValueCount(p); + if (count > 0) { + int first = values.getFirstValueIndex(p); + for (int i = 0; i < count; i++) { + BytesRef r = values.getBytesRef(first + i, scratch); + r.offset++; + r.length--; + sum += Integer.parseInt(r.utf8ToString()); + } + } + } + } } } long expected = 0; @@ -399,6 +431,16 @@ public void benchmark() { expected += i % 1000; } break; + case "keyword_mv": + for (int i = 0; i < INDEX_SIZE; i++) { + int v1 = i % 1000; + expected += v1; + int v2 = i % 500; + if (v1 != v2) { + expected += v2; + } + } + 
break; case "3_stored_keywords": for (int i = 0; i < INDEX_SIZE; i++) { expected += 3 * (i % 1000); @@ -453,7 +495,9 @@ private void setupIndex() throws IOException { new StoredField("double", (double) i), new KeywordFieldMapper.KeywordField("keyword_1", new BytesRef(c + i % 1000), keywordFieldType), new KeywordFieldMapper.KeywordField("keyword_2", new BytesRef(c + i % 1000), keywordFieldType), - new KeywordFieldMapper.KeywordField("keyword_3", new BytesRef(c + i % 1000), keywordFieldType) + new KeywordFieldMapper.KeywordField("keyword_3", new BytesRef(c + i % 1000), keywordFieldType), + new KeywordFieldMapper.KeywordField("keyword_mv", new BytesRef(c + i % 1000), keywordFieldType), + new KeywordFieldMapper.KeywordField("keyword_mv", new BytesRef(c + i % 500), keywordFieldType) ) ); if (i % COMMIT_INTERVAL == 0) { diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index d144d7601349d..b447fd29bffa5 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -191,11 +191,12 @@ private static Operator operator(DriverContext driverContext, String grouping, S new BlockHash.GroupSpec(2, ElementType.BYTES_REF) ); case TOP_N_LONGS -> List.of( - new BlockHash.GroupSpec(0, ElementType.LONG, false, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT)) + new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT), null) ); default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]"); }; return new HashAggregationOperator( + groups, List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))), () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 
1024, false), driverContext diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 879418e7f954c..4d13389fed8dc 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -95,8 +95,7 @@ static void selfTest() { try { for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) { for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) { - run(Integer.parseInt(groups), dataType, 10, 0); - run(Integer.parseInt(groups), dataType, 10, 1); + run(Integer.parseInt(groups), dataType, 10); } } } catch (NoSuchFieldException e) { @@ -114,10 +113,7 @@ static void selfTest() { @Param({ BYTES_REF, INT, LONG }) public String dataType; - @Param({ "0", "1" }) - public int numOrdinalMerges; - - private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) { + private static Operator operator(DriverContext driverContext, int groups, String dataType) { if (groups == 1) { return new AggregationOperator( List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), @@ -126,26 +122,15 @@ private static Operator operator(DriverContext driverContext, int groups, String } List groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); return new HashAggregationOperator( + groupSpec, List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), driverContext ) { @Override public Page getOutput() { - mergeOrdinal(); return 
super.getOutput(); } - - // simulate OrdinalsGroupingOperator - void mergeOrdinal() { - var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext); - for (int i = 0; i < numOrdinalMerges; i++) { - for (int p = 0; p < groups; p++) { - merged.addIntermediateRow(p, aggregators.getFirst(), p); - } - } - aggregators.set(0, merged); - } }; } @@ -352,12 +337,12 @@ private static Block groupingBlock(int groups) { @Benchmark public void run() { - run(groups, dataType, OP_COUNT, numOrdinalMerges); + run(groups, dataType, OP_COUNT); } - private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) { + private static void run(int groups, String dataType, int opCount) { DriverContext driverContext = driverContext(); - try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) { + try (Operator operator = operator(driverContext, groups, dataType)) { Page page = page(groups, dataType); for (int i = 0; i < opCount; i++) { operator.addInput(page.shallowCopy()); diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java index e104aa85cccb8..6d4cd1116a02a 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java @@ -8,12 +8,14 @@ */ package org.elasticsearch.benchmark.vector; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.common.logging.LogConfigurator; import 
org.elasticsearch.core.IOUtils; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; @@ -52,20 +54,26 @@ public class Int4ScorerBenchmark { LogConfigurator.configureESLogging(); // native access requires logging to be initialized } - @Param({ "384", "702", "1024" }) + @Param({ "384", "782", "1024" }) int dims; - int numVectors = 200; - int numQueries = 10; + int numVectors = 20 * ES91Int4VectorsScorer.BULK_SIZE; + int numQueries = 5; byte[] scratch; byte[][] binaryVectors; byte[][] binaryQueries; + float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE]; + + float[] scratchFloats = new float[3]; ES91Int4VectorsScorer scorer; Directory dir; IndexInput in; + OptimizedScalarQuantizer.QuantizationResult queryCorrections; + float centroidDp; + @Setup public void setup() throws IOException { binaryVectors = new byte[numVectors][dims]; @@ -77,9 +85,19 @@ public void setup() throws IOException { binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16); } out.writeBytes(binaryVector, 0, binaryVector.length); + ThreadLocalRandom.current().nextBytes(binaryVector); + out.writeBytes(binaryVector, 0, 14); // corrections } } + queryCorrections = new OptimizedScalarQuantizer.QuantizationResult( + ThreadLocalRandom.current().nextFloat(), + ThreadLocalRandom.current().nextFloat(), + ThreadLocalRandom.current().nextFloat(), + Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt()) + ); + centroidDp = ThreadLocalRandom.current().nextFloat(); + in = dir.openInput("vectors", IOContext.DEFAULT); binaryQueries = new byte[numVectors][dims]; for (byte[] binaryVector : binaryVectors) { @@ -105,18 +123,66 @@ public void scoreFromArray(Blackhole bh) throws IOException { in.seek(0); for (int i = 0; i < numVectors; i++) { in.readBytes(scratch, 0, dims); - bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch)); + int dp = VectorUtil.int4DotProduct(binaryQueries[j], scratch); + in.readFloats(scratchFloats, 0, 3); + float score = scorer.applyCorrections( + 
queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, // assuming no centroid dot product for this benchmark + scratchFloats[0], + scratchFloats[1], + Short.toUnsignedInt(in.readShort()), + scratchFloats[2], + dp + ); + bh.consume(score); } } } @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) - public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { + public void scoreFromMemorySegment(Blackhole bh) throws IOException { for (int j = 0; j < numQueries; j++) { in.seek(0); for (int i = 0; i < numVectors; i++) { - bh.consume(scorer.int4DotProduct(binaryQueries[j])); + bh.consume( + scorer.score( + binaryQueries[j], + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp + ) + ); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) { + scorer.scoreBulk( + binaryQueries[j], + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + scores + ); + for (float score : scores) { + bh.consume(score); + } } } } diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmarkTests.java similarity index 92% rename from 
benchmarks/src/test/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmarkTests.java rename to benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmarkTests.java index 7d72455f9fb22..e1d2b8f43100c 100644 --- a/benchmarks/src/test/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmarkTests.java +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmarkTests.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.benchmark.compute.operator; +package org.elasticsearch.benchmark._nightly.esql; import org.elasticsearch.test.ESTestCase; diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java index da72315521423..88ba8607b9281 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java @@ -34,6 +34,7 @@ import java.io.File; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Stream; import javax.inject.Inject; @@ -50,6 +51,8 @@ public abstract class ElasticsearchTestBasePlugin implements Plugin { public static final String DUMP_OUTPUT_ON_FAILURE_PROP_NAME = "dumpOutputOnFailure"; + public static final Set TEST_TASKS_WITH_ENTITLEMENTS = Set.of("test", "internalClusterTest"); + @Inject protected abstract ProviderFactory getProviderFactory(); @@ -174,14 +177,23 @@ public void execute(Task t) { nonInputProperties.systemProperty("workspace.dir", Util.locateElasticsearchWorkspace(project.getGradle())); // we use 'temp' relative to CWD since this is per JVM and tests are forbidden from writing to CWD 
nonInputProperties.systemProperty("java.io.tmpdir", test.getWorkingDir().toPath().resolve("temp")); + if (test.getName().equals("internalClusterTest")) { + // configure a node home directory independent of the Java temp dir so that entitlements can be properly enforced + nonInputProperties.systemProperty("tempDir", test.getWorkingDir().toPath().resolve("nodesTemp")); + } SourceSetContainer sourceSets = project.getExtensions().getByType(SourceSetContainer.class); SourceSet mainSourceSet = sourceSets.findByName(SourceSet.MAIN_SOURCE_SET_NAME); SourceSet testSourceSet = sourceSets.findByName(SourceSet.TEST_SOURCE_SET_NAME); - if ("test".equals(test.getName()) && mainSourceSet != null && testSourceSet != null) { + SourceSet internalClusterTestSourceSet = sourceSets.findByName("internalClusterTest"); + + if (TEST_TASKS_WITH_ENTITLEMENTS.contains(test.getName()) && mainSourceSet != null && testSourceSet != null) { FileCollection mainRuntime = mainSourceSet.getRuntimeClasspath(); FileCollection testRuntime = testSourceSet.getRuntimeClasspath(); - FileCollection testOnlyFiles = testRuntime.minus(mainRuntime); + FileCollection internalClusterTestRuntime = ("internalClusterTest".equals(test.getName()) + && internalClusterTestSourceSet != null) ? internalClusterTestSourceSet.getRuntimeClasspath() : project.files(); + FileCollection testOnlyFiles = testRuntime.plus(internalClusterTestRuntime).minus(mainRuntime); + test.doFirst(task -> test.environment("es.entitlement.testOnlyPath", testOnlyFiles.getAsPath())); } @@ -241,14 +253,15 @@ public void execute(Task t) { * Computes and sets the {@code --patch-module=java.base} and {@code --add-opens=java.base} JVM command line options. 
*/ private void configureJavaBaseModuleOptions(Project project) { - project.getTasks().withType(Test.class).matching(task -> task.getName().equals("test")).configureEach(test -> { - FileCollection patchedImmutableCollections = patchedImmutableCollections(project); + project.getTasks().withType(Test.class).configureEach(test -> { + // patch immutable collections only for "test" task + FileCollection patchedImmutableCollections = test.getName().equals("test") ? patchedImmutableCollections(project) : null; if (patchedImmutableCollections != null) { test.getInputs().files(patchedImmutableCollections); test.systemProperty("tests.hackImmutableCollections", "true"); } - FileCollection entitlementBridge = entitlementBridge(project); + FileCollection entitlementBridge = TEST_TASKS_WITH_ENTITLEMENTS.contains(test.getName()) ? entitlementBridge(project) : null; if (entitlementBridge != null) { test.getInputs().files(entitlementBridge); } @@ -312,27 +325,30 @@ private static void configureEntitlements(Project project) { } FileCollection bridgeFiles = bridgeConfig; - project.getTasks().withType(Test.class).configureEach(test -> { - // See also SystemJvmOptions.maybeAttachEntitlementAgent. - - // Agent - if (agentFiles.isEmpty() == false) { - test.getInputs().files(agentFiles); - test.systemProperty("es.entitlement.agentJar", agentFiles.getAsPath()); - test.systemProperty("jdk.attach.allowAttachSelf", true); - } + project.getTasks() + .withType(Test.class) + .matching(test -> TEST_TASKS_WITH_ENTITLEMENTS.contains(test.getName())) + .configureEach(test -> { + // See also SystemJvmOptions.maybeAttachEntitlementAgent. 
+ + // Agent + if (agentFiles.isEmpty() == false) { + test.getInputs().files(agentFiles); + test.systemProperty("es.entitlement.agentJar", agentFiles.getAsPath()); + test.systemProperty("jdk.attach.allowAttachSelf", true); + } - // Bridge - if (bridgeFiles.isEmpty() == false) { - String modulesContainingEntitlementInstrumentation = "java.logging,java.net.http,java.naming,jdk.net"; - test.getInputs().files(bridgeFiles); - // Tests may not be modular, but the JDK still is - test.jvmArgs( - "--add-exports=java.base/org.elasticsearch.entitlement.bridge=ALL-UNNAMED," - + modulesContainingEntitlementInstrumentation - ); - } - }); + // Bridge + if (bridgeFiles.isEmpty() == false) { + String modulesContainingEntitlementInstrumentation = "java.logging,java.net.http,java.naming,jdk.net"; + test.getInputs().files(bridgeFiles); + // Tests may not be modular, but the JDK still is + test.jvmArgs( + "--add-exports=java.base/org.elasticsearch.entitlement.bridge=ALL-UNNAMED," + + modulesContainingEntitlementInstrumentation + ); + } + }); } } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/RestrictedBuildApiService.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/RestrictedBuildApiService.java index 4f3c4b3d94f68..205930133156c 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/RestrictedBuildApiService.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/RestrictedBuildApiService.java @@ -56,8 +56,6 @@ private static ListMultimap, String> createLegacyRestTestBasePluginUsag map.put(LegacyRestTestBasePlugin.class, ":x-pack:qa:third-party:jira"); map.put(LegacyRestTestBasePlugin.class, ":x-pack:qa:third-party:pagerduty"); map.put(LegacyRestTestBasePlugin.class, ":x-pack:qa:third-party:slack"); - map.put(LegacyRestTestBasePlugin.class, ":x-pack:plugin:async-search:qa:rest"); - map.put(LegacyRestTestBasePlugin.class, ":x-pack:plugin:autoscaling:qa:rest"); 
map.put(LegacyRestTestBasePlugin.class, ":x-pack:plugin:deprecation:qa:early-deprecation-rest"); map.put(LegacyRestTestBasePlugin.class, ":x-pack:plugin:deprecation:qa:rest"); map.put(LegacyRestTestBasePlugin.class, ":x-pack:plugin:downsample:qa:with-security"); diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/Awsv2ClassPatcher.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/Awsv2ClassPatcher.java deleted file mode 100644 index 1e515afd8404b..0000000000000 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/Awsv2ClassPatcher.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.gradle.internal.dependencies.patches.awsv2sdk; - -import org.elasticsearch.gradle.internal.dependencies.patches.PatcherInfo; -import org.elasticsearch.gradle.internal.dependencies.patches.Utils; -import org.gradle.api.artifacts.transform.CacheableTransform; -import org.gradle.api.artifacts.transform.InputArtifact; -import org.gradle.api.artifacts.transform.TransformAction; -import org.gradle.api.artifacts.transform.TransformOutputs; -import org.gradle.api.artifacts.transform.TransformParameters; -import org.gradle.api.file.FileSystemLocation; -import org.gradle.api.provider.Provider; -import org.gradle.api.tasks.Classpath; -import org.jetbrains.annotations.NotNull; - -import java.io.File; -import java.util.List; - -import static org.elasticsearch.gradle.internal.dependencies.patches.PatcherInfo.classPatcher; - -@CacheableTransform -public abstract class Awsv2ClassPatcher implements TransformAction { - - private static final String JAR_FILE_TO_PATCH = "aws-query-protocol"; - - private static final List CLASS_PATCHERS = List.of( - // This patcher is needed because of this AWS bug: https://github.com/aws/aws-sdk-java-v2/issues/5968 - // As soon as the bug is resolved and we upgrade our AWS SDK v2 libraries, we can remove this. 
- classPatcher( - "software/amazon/awssdk/protocols/query/internal/marshall/ListQueryMarshaller.class", - "213e84d9a745bdae4b844334d17aecdd6499b36df32aa73f82dc114b35043009", - StringFormatInPathResolverPatcher::new - ) - ); - - @Classpath - @InputArtifact - public abstract Provider getInputArtifact(); - - @Override - public void transform(@NotNull TransformOutputs outputs) { - File inputFile = getInputArtifact().get().getAsFile(); - - if (inputFile.getName().startsWith(JAR_FILE_TO_PATCH)) { - System.out.println("Patching " + inputFile.getName()); - File outputFile = outputs.file(inputFile.getName().replace(".jar", "-patched.jar")); - Utils.patchJar(inputFile, outputFile, CLASS_PATCHERS); - } else { - System.out.println("Skipping " + inputFile.getName()); - outputs.file(getInputArtifact()); - } - } -} diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/StringFormatInPathResolverPatcher.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/StringFormatInPathResolverPatcher.java deleted file mode 100644 index 506dab001dbe7..0000000000000 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/dependencies/patches/awsv2sdk/StringFormatInPathResolverPatcher.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.gradle.internal.dependencies.patches.awsv2sdk; - -import org.objectweb.asm.ClassVisitor; -import org.objectweb.asm.ClassWriter; -import org.objectweb.asm.MethodVisitor; -import org.objectweb.asm.Type; - -import java.util.Locale; - -import static org.objectweb.asm.Opcodes.ASM9; -import static org.objectweb.asm.Opcodes.GETSTATIC; -import static org.objectweb.asm.Opcodes.INVOKESTATIC; - -class StringFormatInPathResolverPatcher extends ClassVisitor { - - StringFormatInPathResolverPatcher(ClassWriter classWriter) { - super(ASM9, classWriter); - } - - @Override - public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { - return new ReplaceCallMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions)); - } - - /** - * Replaces calls to String.format(format, args); with calls to String.format(Locale.ROOT, format, args); - */ - private static class ReplaceCallMethodVisitor extends MethodVisitor { - private static final String CLASS_INTERNAL_NAME = Type.getInternalName(String.class); - private static final String METHOD_NAME = "format"; - private static final String OLD_METHOD_DESCRIPTOR = Type.getMethodDescriptor( - Type.getType(String.class), - Type.getType(String.class), - Type.getType(Object[].class) - ); - private static final String NEW_METHOD_DESCRIPTOR = Type.getMethodDescriptor( - Type.getType(String.class), - Type.getType(Locale.class), - Type.getType(String.class), - Type.getType(Object[].class) - ); - - private boolean foundFormatPattern = false; - - ReplaceCallMethodVisitor(MethodVisitor methodVisitor) { - super(ASM9, methodVisitor); - } - - @Override - public void visitLdcInsn(Object value) { - if (value instanceof String s && s.startsWith("%s")) { - if (foundFormatPattern) { - throw new IllegalStateException( - "A previous string format constant was not paired with a String.format() call. 
" - + "Patching would generate an unbalances stack" - ); - } - // Push the extra arg on the stack - mv.visitFieldInsn(GETSTATIC, Type.getInternalName(Locale.class), "ROOT", Type.getDescriptor(Locale.class)); - foundFormatPattern = true; - } - super.visitLdcInsn(value); - } - - @Override - public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) { - if (opcode == INVOKESTATIC - && foundFormatPattern - && CLASS_INTERNAL_NAME.equals(owner) - && METHOD_NAME.equals(name) - && OLD_METHOD_DESCRIPTOR.equals(descriptor)) { - // Replace the call with String.format(Locale.ROOT, format, args) - mv.visitMethodInsn(INVOKESTATIC, CLASS_INTERNAL_NAME, METHOD_NAME, NEW_METHOD_DESCRIPTOR, false); - foundFormatPattern = false; - } else { - super.visitMethodInsn(opcode, owner, name, descriptor, isInterface); - } - } - } -} diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties index c443b280e3dd2..c552581a81738 100644 --- a/build-tools-internal/version.properties +++ b/build-tools-internal/version.properties @@ -17,7 +17,7 @@ jna = 5.12.1 netty = 4.1.118.Final commons_lang3 = 3.9 google_oauth_client = 1.34.1 -awsv2sdk = 2.30.38 +awsv2sdk = 2.31.78 reactive_streams = 1.0.4 antlr4 = 4.13.1 diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/test/TestBuildInfoPlugin.java b/build-tools/src/main/java/org/elasticsearch/gradle/test/TestBuildInfoPlugin.java index ed20d40582f57..c0aabfe17e56f 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/test/TestBuildInfoPlugin.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/test/TestBuildInfoPlugin.java @@ -58,9 +58,12 @@ public void apply(Project project) { }); if (project.getRootProject().getName().equals("elasticsearch")) { - project.getTasks().withType(Test.class).matching(test -> List.of("test").contains(test.getName())).configureEach(test -> { - test.systemProperty("es.entitlement.enableForTests", "true"); - }); + 
project.getTasks() + .withType(Test.class) + .matching(test -> List.of("test", "internalClusterTest").contains(test.getName())) + .configureEach(test -> { + test.systemProperty("es.entitlement.enableForTests", "true"); + }); } } } diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/RunTask.java b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/RunTask.java index 725de1ac9448c..76700e9092eb6 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/RunTask.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/RunTask.java @@ -303,7 +303,6 @@ else if (node.getSettingKeys().contains("telemetry.metrics.enabled") == false) { if (cliDebug) { enableCliDebug(); } - enableEntitlements(); } @TaskAction diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/TestClustersAware.java b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/TestClustersAware.java index 2e313fa73c4ee..4a45e8e4a03c4 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/TestClustersAware.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/testclusters/TestClustersAware.java @@ -88,12 +88,4 @@ default void enableCliDebug() { } } } - - default void enableEntitlements() { - for (ElasticsearchCluster cluster : getClusters()) { - for (ElasticsearchNode node : cluster.getNodes()) { - node.cliJvmArgs("-Des.entitlements.enabled=true"); - } - } - } } diff --git a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/MachineDependentHeap.java b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/MachineDependentHeap.java index b68e374bbdb94..1a397ab8aa005 100644 --- a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/MachineDependentHeap.java +++ b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/MachineDependentHeap.java @@ -40,6 +40,8 @@ public class MachineDependentHeap { private 
static final FeatureFlag NEW_ML_MEMORY_COMPUTATION_FEATURE_FLAG = new FeatureFlag("new_ml_memory_computation"); + private boolean useNewMlMemoryComputation = false; + public MachineDependentHeap() {} /** @@ -55,6 +57,11 @@ public final List determineHeapSettings( SystemMemoryInfo systemMemoryInfo, List userDefinedJvmOptions ) throws IOException, InterruptedException { + if (userDefinedJvmOptions.contains("-Des.new_ml_memory_computation_feature_flag_enabled=true") + || NEW_ML_MEMORY_COMPUTATION_FEATURE_FLAG.isEnabled()) { + useNewMlMemoryComputation = true; + } + // TODO: this could be more efficient, to only parse final options once final Map finalJvmOptions = JvmOption.findFinalOptions(userDefinedJvmOptions); if (isMaxHeapSpecified(finalJvmOptions) || isMinHeapSpecified(finalJvmOptions) || isInitialHeapSpecified(finalJvmOptions)) { @@ -107,7 +114,7 @@ protected int getHeapSizeMb(Settings nodeSettings, MachineNodeRole role, long av case ML_ONLY -> { double heapFractionBelow16GB = 0.4; double heapFractionAbove16GB = 0.1; - if (NEW_ML_MEMORY_COMPUTATION_FEATURE_FLAG.isEnabled()) { + if (useNewMlMemoryComputation) { heapFractionBelow16GB = 0.4 / (1.0 + JvmErgonomics.DIRECT_MEMORY_TO_HEAP_FACTOR); heapFractionAbove16GB = 0.1 / (1.0 + JvmErgonomics.DIRECT_MEMORY_TO_HEAP_FACTOR); } diff --git a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java index 5191b60f1f8c9..3c0d3072b0e57 100644 --- a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java +++ b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.jdk.RuntimeVersionFeature; import java.io.IOException; import java.nio.file.Files; @@ -85,7 +84,6 
@@ static List systemJvmOptions(Settings nodeSettings, final Map s).toList(); } @@ -160,14 +158,6 @@ private static Stream maybeWorkaroundG1Bug() { return Stream.of(); } - private static Stream maybeAllowSecurityManager(boolean useEntitlements) { - if (RuntimeVersionFeature.isSecurityManagerAvailable()) { - // Will become conditional on useEntitlements once entitlements can run without SM - return Stream.of("-Djava.security.manager=allow"); - } - return Stream.of(); - } - private static Stream maybeAttachEntitlementAgent(Path esHome, boolean useEntitlements) { if (useEntitlements == false) { return Stream.empty(); } @@ -191,7 +181,6 @@ private static Stream maybeAttachEntitlementAgent(Path esHome, boolean u // into java.base, we must export the bridge from java.base to these modules, as a comma-separated list String modulesContainingEntitlementInstrumentation = "java.logging,java.net.http,java.naming,jdk.net"; return Stream.of( - "-Des.entitlements.enabled=true", "-XX:+EnableDynamicAgentLoading", "-Djdk.attach.allowAttachSelf=true", "--patch-module=java.base=" + bridgeJar, diff --git a/docs/changelog/128639.yaml b/docs/changelog/128639.yaml new file mode 100644 index 0000000000000..8fbfa9927a938 --- /dev/null +++ b/docs/changelog/128639.yaml @@ -0,0 +1,6 @@ +pr: 128639 +summary: Substitute `date_trunc` with `round_to` when the pre-calculated rounding points + are available +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/129090.yaml b/docs/changelog/129090.yaml new file mode 100644 index 0000000000000..a394a795b09ad --- /dev/null +++ b/docs/changelog/129090.yaml @@ -0,0 +1,6 @@ +pr: 129090 +summary: Enable force inference endpoint deleting for invalid models and after stopping + model deployment fails +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/129282.yaml b/docs/changelog/129282.yaml new file mode 100644 index 0000000000000..75e56899ee23e --- /dev/null +++ b/docs/changelog/129282.yaml @@ -0,0 +1,6 @@ +pr:
129282 +summary: "Fix query rewrite logic to preserve `boosts` and `queryName` for `match`,\ + \ `knn`, and `sparse_vector` queries on semantic_text fields" +area: Search +type: bug +issues: [] diff --git a/docs/changelog/129745.yaml b/docs/changelog/129745.yaml new file mode 100644 index 0000000000000..35cfd0671bd64 --- /dev/null +++ b/docs/changelog/129745.yaml @@ -0,0 +1,6 @@ +pr: 129745 +summary: "ESQL: Fix `mv_expand` inconsistent column order" +area: ES|QL +type: bug +issues: + - 129000 diff --git a/docs/changelog/129848.yaml b/docs/changelog/129848.yaml new file mode 100644 index 0000000000000..8a22e00fb6115 --- /dev/null +++ b/docs/changelog/129848.yaml @@ -0,0 +1,5 @@ +pr: 129848 +summary: "[ML] Add Azure AI Rerank support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/129929.yaml b/docs/changelog/129929.yaml new file mode 100644 index 0000000000000..c2296a64ab434 --- /dev/null +++ b/docs/changelog/129929.yaml @@ -0,0 +1,5 @@ +pr: 129929 +summary: Add support for RLIKE (LIST) with pushdown +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/130092.yaml b/docs/changelog/130092.yaml new file mode 100644 index 0000000000000..0e54e5f013d23 --- /dev/null +++ b/docs/changelog/130092.yaml @@ -0,0 +1,5 @@ +pr: 130092 +summary: "Added Llama provider support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/130427.yaml b/docs/changelog/130427.yaml index 4666cf0987cda..1727a8d2c4f27 100644 --- a/docs/changelog/130427.yaml +++ b/docs/changelog/130427.yaml @@ -1,5 +1,17 @@ pr: 130427 -summary: Disallow brackets in unquoted index pattersn +summary: Disallow brackets in unquoted index patterns area: ES|QL -type: bug -issues: [] +type: breaking +issues: + - 130378 +breaking: + title: Unquoted index patterns do not allow `(` and `)` characters + area: ES|QL + details: >- + Previously, ES|QL accepted unquoted index patterns containing 
brackets, such as `FROM index(1) | ENRICH policy(2)`. + + This query syntax is no longer valid because it could conflict with subquery syntax, where brackets are used as delimiters. + + Brackets are now only allowed in quoted index patterns. For example: `FROM "index(1)" | ENRICH "policy(2)"`. + impact: "This affects existing queries containing brackets in index or policy names, i.e. in FROM, ENRICH, and LOOKUP JOIN commands." + notable: false diff --git a/docs/changelog/130544.yaml b/docs/changelog/130544.yaml new file mode 100644 index 0000000000000..415357d929f8d --- /dev/null +++ b/docs/changelog/130544.yaml @@ -0,0 +1,6 @@ +pr: 130544 +summary: Sync Inference with Trained Model stats +area: Machine Learning +type: bug +issues: + - 130339 diff --git a/docs/changelog/130855.yaml b/docs/changelog/130855.yaml new file mode 100644 index 0000000000000..ee95181f033de --- /dev/null +++ b/docs/changelog/130855.yaml @@ -0,0 +1,6 @@ +pr: 130855 +summary: Add checks that optimizers do not modify the layout +area: ES|QL +type: enhancement +issues: + - 125576 diff --git a/docs/changelog/130909.yaml b/docs/changelog/130909.yaml new file mode 100644 index 0000000000000..a00d6cbdca570 --- /dev/null +++ b/docs/changelog/130909.yaml @@ -0,0 +1,5 @@ +pr: 130909 +summary: Allow adjustment of transport TLS handshake timeout +area: Network +type: enhancement +issues: [] diff --git a/docs/changelog/130914.yaml b/docs/changelog/130914.yaml new file mode 100644 index 0000000000000..da38b52e5f879 --- /dev/null +++ b/docs/changelog/130914.yaml @@ -0,0 +1,6 @@ +pr: 130914 +summary: Fix LIMIT NPE with null value +area: ES|QL +type: bug +issues: + - 130908 diff --git a/docs/changelog/130924.yaml b/docs/changelog/130924.yaml new file mode 100644 index 0000000000000..09b0f3b90533c --- /dev/null +++ b/docs/changelog/130924.yaml @@ -0,0 +1,6 @@ +pr: 130924 +summary: Check field data type before casting when applying geo distance sort +area: Search +type: bug +issues: + - 129500 diff --git 
a/docs/changelog/130939.yaml b/docs/changelog/130939.yaml new file mode 100644 index 0000000000000..86058b797e405 --- /dev/null +++ b/docs/changelog/130939.yaml @@ -0,0 +1,5 @@ +pr: 130939 +summary: Expose HTTP connection metrics to telemetry +area: Network +type: enhancement +issues: [] diff --git a/docs/changelog/130940.yaml b/docs/changelog/130940.yaml new file mode 100644 index 0000000000000..1adab0cc81926 --- /dev/null +++ b/docs/changelog/130940.yaml @@ -0,0 +1,6 @@ +pr: 130940 +summary: Block trained model updates from inference +area: Machine Learning +type: enhancement +issues: + - 129999 diff --git a/docs/changelog/130947.yaml b/docs/changelog/130947.yaml new file mode 100644 index 0000000000000..bcca93dbc681a --- /dev/null +++ b/docs/changelog/130947.yaml @@ -0,0 +1,5 @@ +pr: 130947 +summary: "[main]Prepare Index Like fix for backport to 9.1 and 8.19" +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/131015.yaml b/docs/changelog/131015.yaml new file mode 100644 index 0000000000000..91e0519e88ec6 --- /dev/null +++ b/docs/changelog/131015.yaml @@ -0,0 +1,5 @@ +pr: 131015 +summary: Move streams status actions to cluster:monitor group +area: Data streams +type: bug +issues: [] diff --git a/docs/changelog/131032.yaml b/docs/changelog/131032.yaml new file mode 100644 index 0000000000000..c7cbc3af0f9c7 --- /dev/null +++ b/docs/changelog/131032.yaml @@ -0,0 +1,5 @@ +pr: 131032 +summary: "Fix: `GET _synonyms` returns synonyms with empty rules" +area: Relevance +type: bug +issues: [] diff --git a/docs/changelog/131050.yaml b/docs/changelog/131050.yaml new file mode 100644 index 0000000000000..f8c932b464dba --- /dev/null +++ b/docs/changelog/131050.yaml @@ -0,0 +1,6 @@ +pr: 131050 +summary: Upgrade AWS Java SDK to 2.31.78 +area: "Snapshot/Restore" +type: upgrade +issues: [] + diff --git a/docs/changelog/131053.yaml b/docs/changelog/131053.yaml new file mode 100644 index 0000000000000..b30a7c8ee8cc5 --- /dev/null +++ b/docs/changelog/131053.yaml @@ 
-0,0 +1,5 @@ +pr: 131053 +summary: Split large pages on load sometimes +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/131056.yaml b/docs/changelog/131056.yaml new file mode 100644 index 0000000000000..3058c5da2f523 --- /dev/null +++ b/docs/changelog/131056.yaml @@ -0,0 +1,5 @@ +pr: 131056 +summary: Add existing shards allocator settings to failure store allowed list +area: Data streams +type: bug +issues: [] diff --git a/docs/changelog/131061.yaml b/docs/changelog/131061.yaml new file mode 100644 index 0000000000000..be12ae8d3f137 --- /dev/null +++ b/docs/changelog/131061.yaml @@ -0,0 +1,5 @@ +pr: 131061 +summary: Speed up reading multivalued keywords +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/131081.yaml b/docs/changelog/131081.yaml new file mode 100644 index 0000000000000..e0557f17968a4 --- /dev/null +++ b/docs/changelog/131081.yaml @@ -0,0 +1,6 @@ +pr: 131081 +summary: Fix knn search error when dimensions are not set +area: Vector Search +type: bug +issues: + - 129550 diff --git a/docs/changelog/131111.yaml b/docs/changelog/131111.yaml new file mode 100644 index 0000000000000..ac5d950a3d31d --- /dev/null +++ b/docs/changelog/131111.yaml @@ -0,0 +1,5 @@ +pr: 131111 +summary: Don't allow field caps to use semantic queries as index filters +area: Search +type: bug +issues: [] diff --git a/docs/changelog/131113.yaml b/docs/changelog/131113.yaml new file mode 100644 index 0000000000000..cca54f0a302f0 --- /dev/null +++ b/docs/changelog/131113.yaml @@ -0,0 +1,5 @@ +pr: 131113 +summary: Including `max_tokens` through the Service API for Anthropic +area: Machine Learning +type: bug +issues: [] diff --git a/docs/changelog/131173.yaml b/docs/changelog/131173.yaml new file mode 100644 index 0000000000000..74dcedefe26e9 --- /dev/null +++ b/docs/changelog/131173.yaml @@ -0,0 +1,5 @@ +pr: 131173 +summary: Add attribute count to `SamlAttribute` `toString` +area: Authentication +type: enhancement +issues: [] diff --git 
a/docs/changelog/131200.yaml b/docs/changelog/131200.yaml new file mode 100644 index 0000000000000..49a88fa79f90c --- /dev/null +++ b/docs/changelog/131200.yaml @@ -0,0 +1,5 @@ +pr: 131200 +summary: Improve lost-increment message in repo analysis +area: Snapshot/Restore +type: enhancement +issues: [] diff --git a/docs/changelog/131296.yaml b/docs/changelog/131296.yaml new file mode 100644 index 0000000000000..a3cf791bea0b3 --- /dev/null +++ b/docs/changelog/131296.yaml @@ -0,0 +1,5 @@ +pr: 131296 +summary: Enable failure store for newly created APM datastreams +area: Ingest Node +type: enhancement +issues: [] diff --git a/docs/changelog/131391.yaml b/docs/changelog/131391.yaml new file mode 100644 index 0000000000000..acac2a5f4da96 --- /dev/null +++ b/docs/changelog/131391.yaml @@ -0,0 +1,6 @@ +pr: 131391 +summary: Fix bug in point in time response +area: Search +type: bug +issues: + - 131026 diff --git a/docs/changelog/131395.yaml b/docs/changelog/131395.yaml new file mode 100644 index 0000000000000..500b761be1472 --- /dev/null +++ b/docs/changelog/131395.yaml @@ -0,0 +1,5 @@ +pr: 131395 +summary: Enable failure store for newly created OTel data streams +area: Data streams +type: enhancement +issues: [] diff --git a/docs/changelog/131426.yaml b/docs/changelog/131426.yaml new file mode 100644 index 0000000000000..4f79415ba069d --- /dev/null +++ b/docs/changelog/131426.yaml @@ -0,0 +1,6 @@ +pr: 131426 +summary: Disallow remote enrich after lu join +area: ES|QL +type: bug +issues: + - 129372 diff --git a/docs/changelog/131442.yaml b/docs/changelog/131442.yaml new file mode 100644 index 0000000000000..23d00cd7d028d --- /dev/null +++ b/docs/changelog/131442.yaml @@ -0,0 +1,5 @@ +pr: 131442 +summary: Track inference deployments +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/131510.yaml b/docs/changelog/131510.yaml new file mode 100644 index 0000000000000..ccdd727fdc818 --- /dev/null +++ b/docs/changelog/131510.yaml @@ -0,0 +1,5 @@ 
+pr: 131510 +summary: Upgrade apm-agent to 1.55.0 +area: Infra/Metrics +type: upgrade +issues: [] diff --git a/docs/changelog/131525.yaml b/docs/changelog/131525.yaml new file mode 100644 index 0000000000000..233c4ff643643 --- /dev/null +++ b/docs/changelog/131525.yaml @@ -0,0 +1,6 @@ +pr: 131525 +summary: Fix semantic highlighting bug on flat quantized fields +area: Highlighting +type: bug +issues: + - 131443 diff --git a/docs/docset.yml b/docs/docset.yml index 15bd674a5fb5e..831b83809a381 100644 --- a/docs/docset.yml +++ b/docs/docset.yml @@ -111,3 +111,6 @@ subs: feat-imp: "feature importance" feat-imp-cap: "Feature importance" nlp: "natural language processing" + index-manage-app: "Index Management" + connectors-app: "Connectors" + ingest-pipelines-app: "Ingest Pipelines" \ No newline at end of file diff --git a/docs/internal/GeneralArchitectureGuide.md b/docs/internal/GeneralArchitectureGuide.md index d74490e62e9df..9cacb946ca138 100644 --- a/docs/internal/GeneralArchitectureGuide.md +++ b/docs/internal/GeneralArchitectureGuide.md @@ -182,12 +182,12 @@ capabilities. ## Serializations -## Settings +# Settings Elasticsearch supports [cluster-level settings][] and [index-level settings][], configurable via [node-level file settings][] (e.g. `elasticsearch.yml` file), command line arguments and REST APIs. -### Declaring a Setting +## Declaring a Setting [cluster-level settings]: https://www.elastic.co/guide/en/elasticsearch/reference/current/cluster-update-settings.html [index-level settings]: https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-update-settings.html @@ -222,7 +222,7 @@ settings. 
[SettingsModule constructor]: https://github.com/elastic/elasticsearch/blob/v8.13.2/server/src/main/java/org/elasticsearch/node/NodeConstruction.java#L491-L495 [getSettings()]: https://github.com/elastic/elasticsearch/blob/v8.13.2/server/src/main/java/org/elasticsearch/plugins/Plugin.java#L203-L208 -### Dynamically updating a Setting +## Dynamically updating a Setting Externally, [TransportClusterUpdateSettingsAction][] and [TransportUpdateSettingsAction][] (and the corresponding REST endpoints) allow users to dynamically change cluster and index settings, respectively. Internally, `AbstractScopedSettings` (parent class @@ -244,9 +244,9 @@ state must ever be reloaded from persisted state. [Metadata]: https://github.com/elastic/elasticsearch/blob/v8.13.2/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java#L212-L213 [applied here]: https://github.com/elastic/elasticsearch/blob/v8.13.2/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java#L2437 -## Deprecations +# Deprecations -## Backwards Compatibility +# Backwards Compatibility major releases are mostly about breaking compatibility and dropping deprecated functionality. @@ -292,18 +292,32 @@ See the [public upgrade docs][] for the upgrade process. [public upgrade docs]: https://www.elastic.co/guide/en/elasticsearch/reference/current/setup-upgrade.html -## Plugins +# Plugins (what warrants a plugin?) (what plugins do we have?) -## Testing +# Observability + +Elasticsearch emits logs as described in the [public logging docs][], and exposes a good deal of information about its inner workings using +all its management and stats APIs. Elasticsearch also integrates with the [Elastic APM Java agent][] to perform distributed tracing (as +described in [TRACING.md][]) and metrics collection (as described in [METERING.md][]). This agent exposes the data it collects to an +[OpenTelemetry][] service such as [Elastic APM Server][]. 
+ +[public logging docs]: https://www.elastic.co/docs/deploy-manage/monitor/logging-configuration +[Elastic APM Java agent]: https://www.elastic.co/docs/reference/apm/agents/java +[OpenTelemetry]: https://opentelemetry.io/ +[Elastic APM Server]: https://www.elastic.co/docs/solutions/observability/apm +[TRACING.md]: https://github.com/elastic/elasticsearch/blob/v8.18.3/TRACING.md +[METERING.md]: https://github.com/elastic/elasticsearch/blob/v8.18.3/modules/apm/METERING.md + +# Testing (Overview of our testing frameworks. Discuss base test classes.) -### Unit Testing +## Unit Testing -### REST Testing +## REST Testing -### Integration Testing +## Integration Testing diff --git a/docs/redirects.yml b/docs/redirects.yml index 7fb1997268c92..27a4ae08ad6ca 100644 --- a/docs/redirects.yml +++ b/docs/redirects.yml @@ -1,3 +1,25 @@ redirects: -# Related to https://github.com/elastic/elasticsearch/pull/130716/ - 'reference/query-languages/eql/eql-ex-threat-detection.md': 'docs-content://explore-analyze/query-filter/languages/example-detect-threats-with-eql.md' \ No newline at end of file + # Related to https://github.com/elastic/elasticsearch/pull/130716/ + 'reference/query-languages/eql/eql-ex-threat-detection.md': 'docs-content://explore-analyze/query-filter/languages/example-detect-threats-with-eql.md' + + # https://github.com/elastic/elasticsearch/pull/131385 + 'reference/elasticsearch/rest-apis/retrievers.md': + to: 'reference/elasticsearch/rest-apis/retrievers.md' + anchors: {} # pass-through unlisted anchors in the `many` ruleset + many: + - to: 'reference/elasticsearch/rest-apis/retrievers/standard-retriever.md' + anchors: {'standard-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/knn-retriever.md' + anchors: {'knn-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/linear-retriever.md' + anchors: {'linear-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md' + anchors: {'rrf-retriever'} + - to: 
'reference/elasticsearch/rest-apis/retrievers/rescorer-retriever.md' + anchors: {'rescorer-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/text-similarity-reranker-retriever.md' + anchors: {'text-similarity-reranker-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/rule-retriever.md' + anchors: {'rule-retriever'} + - to: 'reference/elasticsearch/rest-apis/retrievers/pinned-retriever.md' + anchors: {'pinned-retriever'} \ No newline at end of file diff --git a/docs/reference/elasticsearch/configuration-reference/security-settings.md b/docs/reference/elasticsearch/configuration-reference/security-settings.md index 1ec6600038841..7ca0dbdb97993 100644 --- a/docs/reference/elasticsearch/configuration-reference/security-settings.md +++ b/docs/reference/elasticsearch/configuration-reference/security-settings.md @@ -1933,6 +1933,8 @@ You can configure the following TLS/SSL settings. `xpack.security.transport.ssl.trust_restrictions.x509_fields` ![logo cloud](https://doc-icons.s3.us-east-2.amazonaws.com/logo_cloud.svg "Supported on Elastic Cloud Hosted") : Specifies which field(s) from the TLS certificate is used to match for the restricted trust management that is used for remote clusters connections. This should only be set when a self managed cluster can not create certificates that follow the Elastic Cloud pattern. The default value is ["subjectAltName.otherName.commonName"], the Elastic Cloud pattern. "subjectAltName.dnsName" is also supported and can be configured in addition to or in replacement of the default. +`xpack.security.transport.ssl.handshake_timeout` +: Specifies the timeout for a TLS handshake when opening a transport connection. Defaults to `10s`. ### Transport TLS/SSL key and trusted certificate settings [security-transport-tls-ssl-key-trusted-certificate-settings] @@ -2131,6 +2133,9 @@ You can configure the following TLS/SSL settings. 
For more information, see Oracle’s [Java Cryptography Architecture documentation](https://docs.oracle.com/en/java/javase/11/security/oracle-providers.md#GUID-7093246A-31A3-4304-AC5F-5FB6400405E2). +`xpack.security.remote_cluster_server.ssl.handshake_timeout` +: Specifies the timeout for a TLS handshake when handling an inbound remote-cluster connection. Defaults to `10s`. + ### Remote cluster server (API key based model) TLS/SSL key and trusted certificate settings [security-remote-cluster-server-tls-ssl-key-trusted-certificate-settings] @@ -2260,6 +2265,9 @@ You can configure the following TLS/SSL settings. For more information, see Oracle’s [Java Cryptography Architecture documentation](https://docs.oracle.com/en/java/javase/11/security/oracle-providers.md#GUID-7093246A-31A3-4304-AC5F-5FB6400405E2). +`xpack.security.remote_cluster_client.ssl.handshake_timeout` +: Specifies the timeout for a TLS handshake when opening a remote-cluster connection. Defaults to `10s`. + ### Remote cluster client (API key based model) TLS/SSL key and trusted certificate settings [security-remote-cluster-client-tls-ssl-key-trusted-certificate-settings] diff --git a/docs/reference/elasticsearch/index-settings/slow-log.md b/docs/reference/elasticsearch/index-settings/slow-log.md index 20b416360a1ca..cb8f6d05e8fbf 100644 --- a/docs/reference/elasticsearch/index-settings/slow-log.md +++ b/docs/reference/elasticsearch/index-settings/slow-log.md @@ -20,6 +20,7 @@ Events that meet the specified threshold are emitted into [{{es}} logging](docs- * If [{{es}} monitoring](docs-content://deploy-manage/monitor/stack-monitoring.md) is enabled, from [Stack Monitoring](docs-content://deploy-manage/monitor/monitoring-data/visualizing-monitoring-data.md). Slow log events have a `logger` value of `index.search.slowlog` or `index.indexing.slowlog`. * From local {{es}} service logs directory. Slow log files have a suffix of `_index_search_slowlog.json` or `_index_indexing_slowlog.json`. 
+See [this video](https://www.youtube.com/watch?v=ulUPJshB5bU) for a walkthrough of setting and reviewing slow logs. ## Slow log format [slow-log-format] diff --git a/docs/reference/elasticsearch/index.md b/docs/reference/elasticsearch/index.md index a2060ee9c384c..8ffec42e3a1a0 100644 --- a/docs/reference/elasticsearch/index.md +++ b/docs/reference/elasticsearch/index.md @@ -1,4 +1,4 @@ -# Elasticsearch and index management +# Elasticsearch This section contains reference information for {{es}} and index management features. @@ -7,7 +7,7 @@ To learn more about {{es}} features and how to get started, refer to the [{{es}} For more details about query and scripting languages, check these sections: * [Query languages](../query-languages/index.md) * [Scripting languages](../scripting-languages/index.md) - + {{es}} also provides the following REST APIs: * [{{es}} API](https://www.elastic.co/docs/api/doc/elasticsearch) diff --git a/docs/reference/elasticsearch/mapping-reference/keyword.md b/docs/reference/elasticsearch/mapping-reference/keyword.md index 7b0c3a0537676..a642261179dcb 100644 --- a/docs/reference/elasticsearch/mapping-reference/keyword.md +++ b/docs/reference/elasticsearch/mapping-reference/keyword.md @@ -70,7 +70,19 @@ The following parameters are accepted by `keyword` fields: : Multi-fields allow the same string value to be indexed in multiple ways for different purposes, such as one field for search and a multi-field for sorting and aggregations. [`ignore_above`](/reference/elasticsearch/mapping-reference/ignore-above.md) -: Do not index any string longer than this value. Defaults to `2147483647` in standard indices so that all values would be accepted, and `8191` in logsdb indices to protect against Lucene's term byte-length limit of `32766`. Please however note that default dynamic mapping rules create a sub `keyword` field that overrides this default by setting `ignore_above: 256`.
+: Do not index any field containing a string with more characters than this value. This is important because {{es}} + will reject entire documents if they contain keyword fields that exceed `32766` UTF-8 encoded bytes. + + To avoid any risk of document rejection, set this value to `8191` or less. Fields with strings exceeding this + length will be excluded from indexing. + + The defaults are complicated: + + | Index type | Default | Effect | + | ---------- | ------- | ------ | + | Standard indices | `2147483647` (effectively unbounded) | Documents will be rejected if this keyword exceeds `32766` UTF-8 encoded bytes. | + | `logsdb` indices | `8191` | This `keyword` field will never cause documents to be rejected. If this field is longer than `8191` characters it won't be indexed but its values are still available from `_source`. | + | [dynamic mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md) for string fields | `text` field with a [sub](/reference/elasticsearch/mapping-reference/multi-fields.md)-`keyword` field with an `ignore_above` of `256` | All string fields are available. Values longer than 256 characters are only available for full text search and won't have a value in their `.keyword` sub-field, so they can not be used for exact matching over _search. | [`index`](/reference/elasticsearch/mapping-reference/mapping-index.md) : Should the field be quickly searchable? Accepts `true` (default) and `false`. `keyword` fields that only have [`doc_values`](/reference/elasticsearch/mapping-reference/doc-values.md) enabled can still be queried, albeit slower. 
diff --git a/docs/reference/elasticsearch/rest-apis/index.md b/docs/reference/elasticsearch/rest-apis/index.md index e3dd8f9897986..871666cc99209 100644 --- a/docs/reference/elasticsearch/rest-apis/index.md +++ b/docs/reference/elasticsearch/rest-apis/index.md @@ -16,3 +16,667 @@ This section includes: - [Common options](/reference/elasticsearch/rest-apis/common-options.md) - [Compatibility](/reference/elasticsearch/rest-apis/compatibility.md) - [Examples](/reference/elasticsearch/rest-apis/api-examples.md) + +## API endpoints + +### [Autoscaling](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-autoscaling) + +The autoscaling APIs enable you to create and manage autoscaling policies and retrieve information about autoscaling capacity. Autoscaling adjusts resources based on demand. A deployment can use autoscaling to scale resources as needed, ensuring sufficient capacity to meet workload requirements. + +| API | Description | +| --- | ----------- | +| [Get Autoscaling Policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-autoscaling-get-autoscaling-policy) | Retrieves a specific autoscaling policy. | +| [Create or update an autoscaling policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-autoscaling-put-autoscaling-policy) | Creates or updates an autoscaling policy. | +| [Delete Autoscaling Policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-autoscaling-delete-autoscaling-policy) | Deletes an existing autoscaling policy. | +| [Get Autoscaling Capacity](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-autoscaling-get-autoscaling-capacity) | Estimates autoscaling capacity for current cluster state. 
| + +### [Behavioral analytics](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-analytics) + +```{applies_to} +stack: deprecated +``` + +The behavioral analytics APIs enable you to create and manage analytics collections and retrieve information about analytics collections. Behavioral Analytics is an analytics event collection platform. You can use it to analyze your users' searching and clicking behavior. Leverage this information to improve the relevance of your search results and identify gaps in your content. + +| API | Description | +| --- | ----------- | +| [Get Collections](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-get-behavioral-analytics) | Lists all behavioral analytics collections. | +| [Create Collection](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-put-behavioral-analytics) | Creates a new behavioral analytics collection. | +| [Delete Collection](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-delete-behavioral-analytics) | Deletes a behavioral analytics collection. | +| [Create Event](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-post-behavioral-analytics-event) | Sends a behavioral analytics event to a collection. | + +### [Compact and aligned text (CAT)](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-cat) + +The compact and aligned text (CAT) APIs return human-readable text as a response, instead of a JSON object. The CAT APIs are intended only for human consumption using the Kibana console or command line. They are not intended for use by applications. For application consumption, it's recommended to use a corresponding JSON API. + +| API | Description | +| --- | ----------- | +| [Get aliases](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-aliases) | Returns index aliases.
| +| [Get allocation](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-allocation) | Provides a snapshot of shard allocation across nodes. | +| [Get component templates](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-component-templates) | Returns information about component templates. | +| [Get count](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-count) | Returns document count for specified indices. | +| [Get fielddata](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-fielddata) | Shows fielddata memory usage by field. | +| [Get health](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-health) | Returns cluster health status. | +| [Get help](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-help) | Shows help for CAT APIs. | +| [Get index information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-indices) | Returns index statistics. | +| [Get master](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-master) | Returns information about the elected master node. | +| [Get ml data frame analytics](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-ml-data-frame-analytics) | Returns data frame analytics jobs. | +| [Get ml datafeeds](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-ml-datafeeds) | Returns information about datafeeds. | +| [Get ml jobs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-ml-jobs) | Returns anomaly detection jobs. | +| [Get ml trained models](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-ml-trained-models) | Returns trained machine learning models. | +| [Get nodeattrs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-nodeattrs) | Returns custom node attributes. 
| +| [Get node information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-nodes) | Returns cluster node info and statistics. | +| [Get pending tasks](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-pending-tasks) | Returns cluster pending tasks. | +| [Get plugins](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-plugins) | Returns information about installed plugins. | +| [Get recovery](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-recovery) | Returns shard recovery information. | +| [Get repositories](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-repositories) | Returns snapshot repository information. | +| [Get segments](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-segments) | Returns low-level segment information. | +| [Get shard information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-shards) | Returns shard allocation across nodes. | +| [Get snapshots](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-snapshots) | Returns snapshot information. | +| [Get tasks](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-tasks) | Returns information about running tasks. | +| [Get templates](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-templates) | Returns index template information. | +| [Get thread pool](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-thread-pool) | Returns thread pool statistics. | +| [Get transforms](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cat-transforms) | Returns transform information. | + +### [Cluster](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-cluster) + +The cluster APIs enable you to retrieve information about your infrastructure on cluster, node, or shard level. 
You can manage cluster settings and voting configuration exceptions, collect node statistics and retrieve node information. + +| API | Description | +| --- | ----------- | +| [Get cluster health](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-health) | Returns health status of the cluster. | +| [Get cluster info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-info) | Returns basic information about the cluster. | +| [Reroute cluster](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-reroute) | Manually reassigns shard allocations. | +| [Get cluster state](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-state) | Retrieves the current cluster state. | +| [Explain shard allocation](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-allocation-explain) | Get explanations for shard allocations in the cluster. | +| [Update cluster settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-put-settings) | Updates persistent or transient cluster settings. | +| [Get cluster stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-stats) | Returns cluster-wide statistics, including node, index, and shard metrics. | +| [Get cluster pending tasks](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-pending-tasks) | Lists cluster-level tasks that are pending execution. | +| [Get cluster settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-get-settings) | Retrieves the current cluster-wide settings, including persistent and transient settings. | +| [Get cluster remote info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-remote-info) | Returns information about configured remote clusters for cross-cluster search and replication. 
| +| [Update cluster voting config exclusions](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-post-voting-config-exclusions) | Update the cluster voting config exclusions by node IDs or node names. | +| [Delete voting config exclusions](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-delete-voting-config-exclusions) | Clears voting configuration exclusions, allowing previously excluded nodes to participate in master elections. | + +### [Cluster - Health](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-health_report) + +The cluster - health API provides you a report with the health status of an Elasticsearch cluster. + +| API | Description | +| --- | ----------- | +| [Get cluster health report](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-health-report) | Returns health status of the cluster, including index-level details. | + +### [Connector](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-connector) + +The connector and sync jobs APIs provide a convenient way to create and manage Elastic connectors and sync jobs in an internal index. + +| API | Description | +| --- | ----------- | +| [Get connector](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-get) | Retrieves a connector configuration. | +| [Put connector](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-put) | Creates or updates a connector configuration. | +| [Delete connector](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-delete) | Deletes a connector configuration. | +| [Start connector sync job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-sync-job-post) | Starts a sync job for a connector. | +| [Get connector sync job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-sync-job-get) | Retrieves sync job details for a connector. 
| +| [Get all connectors](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-list) | Retrieves a list of all connector configurations. | +| [Get all connector sync jobs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-sync-job-list) | Retrieves a list of all connector sync jobs. | +| [Delete connector sync job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-sync-job-delete) | Deletes a connector sync job. | + +### [Cross-cluster replication (CCR)](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-ccr) + +The cross-cluster replication (CCR) APIs enable you to run cross-cluster replication operations, such as creating and managing follower indices or auto-follow patterns. 
With CCR, you can replicate indices across clusters to continue handling search requests in the event of a datacenter outage, prevent search volume from impacting indexing throughput, and reduce search latency by processing search requests in geo-proximity to the user. + +| API | Description | +| --- | ----------- | +| [Create or update auto-follow pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-put-auto-follow-pattern) | Creates or updates an auto-follow pattern. | +| [Delete auto-follow pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-delete-auto-follow-pattern) | Deletes an auto-follow pattern. | +| [Get auto-follow pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-get-auto-follow-pattern) | Retrieves auto-follow pattern configuration. | +| [Pause auto-follow pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-pause-auto-follow-pattern) | Pauses an auto-follow pattern. | +| [Resume auto-follow pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-resume-auto-follow-pattern) | Resumes a paused auto-follow pattern. | +| [Forget follower](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-forget-follower) | Removes follower retention leases from leader index. | +| [Create follower](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-follow) | Creates a follower index. | +| [Get follower](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-follow-info) | Retrieves information about follower indices. | +| [Get follower stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-follow-stats) | Retrieves stats about follower indices. | +| [Pause follower](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-pause-follow) | Pauses replication of a follower index. 
| +| [Resume follower](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-resume-follow) | Resumes replication of a paused follower index. | +| [Unfollow index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-unfollow) | Converts a follower index into a regular index. | +| [Get CCR stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ccr-stats) | Retrieves overall CCR statistics for the cluster. | + +### [Data stream](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-data-stream) + +The data stream APIs enable you to create and manage data streams and data stream lifecycles. A data stream lets you store append-only time series data across multiple indices while giving you a single named resource for requests. Data streams are well-suited for logs, events, metrics, and other continuously generated data. + +| API | Description | +| --- | ----------- | +| [Create data stream](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-create-data-stream) | Creates a new data stream. | +| [Delete data stream](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-delete-data-stream) | Deletes an existing data stream. | +| [Get data stream](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-data-stream) | Retrieves one or more data streams. | +| [Modify data stream](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-modify-data-stream) | Updates the backing index configuration for a data stream. | +| [Promote data stream write index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-promote-data-stream) | Promotes a backing index to be the write index. | +| [Data streams stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-data-streams-stats) | Returns statistics about data streams. 
| +| [Migrate to data stream](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-migrate-to-data-stream) | Migrates an index or indices to a data stream. | + +### [Document](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-document) + +The document APIs enable you to create and manage documents in an {{es}} index. + +| API | Description | +| --- | ----------- | +| [Index document](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-create) | Indexes a document into a specific index. | +| [Get document](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-get) | Retrieves a document by ID. | +| [Delete document](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-delete) | Deletes a document by ID. | +| [Update document](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-update) | Updates a document using a script or partial doc. | +| [Bulk](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-bulk) | Performs multiple indexing or delete operations in a single API call. | +| [Multi-get document](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-mget) | Retrieves multiple documents by ID in one request. | +| [Update documents by query](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-update-by-query) | Updates documents that match a query. | +| [Delete documents by query](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-delete-by-query) | Deletes documents that match a query. | +| [Get term vectors](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-termvectors) | Retrieves term vectors for a document. | +| [Multi-termvectors](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-mtermvectors) | Retrieves term vectors for multiple documents. 
| +| [Reindex](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-reindex) | Copies documents from one index to another. | +| [Reindex Rethrottle](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-reindex-rethrottle) | Changes the throttle for a running reindex task. | +| [Explain](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-explain) | Explains how a document matches (or doesn't match) a query. | +| [Get source](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-get-source) | Retrieves the source of a document by ID. | +| [Exists](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-exists) | Checks if a document exists by ID. | + +### [Enrich](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-enrich) + +The enrich APIs enable you to manage enrich policies. An enrich policy is a set of configuration options used to add the right enrich data to the right incoming documents. + +| API | Description | +| --- | ----------- | +| [Create or update enrich policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-enrich-put-policy) | Creates or updates an enrich policy. | +| [Get enrich policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-enrich-get-policy) | Retrieves enrich policy definitions. | +| [Delete enrich policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-enrich-delete-policy) | Deletes an enrich policy. | +| [Execute enrich policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-enrich-execute-policy) | Executes an enrich policy to create an enrich index. | +| [Get enrich stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-enrich-stats) | Returns enrich coordinator and policy execution statistics. 
| + +### [EQL](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-eql) + +The EQL APIs enable you to run EQL-related operations. Event Query Language (EQL) is a query language for event-based time series data, such as logs, metrics, and traces. + +| API | Description | +| --- | ----------- | +| [Submit EQL search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-eql-search) | Runs an EQL search. | +| [Get EQL search status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-eql-get-status) | Retrieves the status of an asynchronous EQL search. | +| [Get EQL search results](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-eql-get) | Retrieves results of an asynchronous EQL search. | +| [Delete EQL search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-eql-delete) | Cancels an asynchronous EQL search. | + +### [ES|QL](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-esql) + +The ES|QL APIs enable you to run ES|QL-related operations. The Elasticsearch Query Language (ES|QL) provides a powerful way to filter, transform, and analyze data stored in Elasticsearch, and in the future in other runtimes. + +| API | Description | +| --- | ----------- | +| [ES|QL Query](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-esql-query) | Executes an ES|QL query using a piped query syntax. | +| [ES|QL Async Submit](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-esql-async-query) | Submits an ES|QL query to run asynchronously. | +| [ES|QL Async Get](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-esql-async-query-get) | Retrieves results of an asynchronous ES|QL query. | +| [ES|QL Async Delete](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-esql-async-query-delete) | Cancels an asynchronous ES|QL query. 
| + +### [Features](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-features) + +The feature APIs enable you to introspect and manage features provided by {{es}} and {{es}} plugins. + +| API | Description | +| --- | ----------- | +| [Get Features](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-features-get-features) | Lists all available features in the cluster. | +| [Reset Features](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-features-reset-features) | Resets internal state for system features. | + +### [Fleet](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-fleet) + +The Fleet APIs support Fleet’s use of Elasticsearch as a data store for internal agent and action data. + +| API | Description | +| --- | ----------- | +| [Run Multiple Fleet Searches](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-fleet-msearch) | Runs several Fleet searches with a single API request. | +| [Run a Fleet Search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-fleet-search) | Runs a Fleet search. | +| [Get global checkpoints](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-fleet-global-checkpoints) | Get the current global checkpoints for an index. | + +### [Graph explore](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-graph) + +The graph explore APIs enable you to extract and summarize information about the documents and terms in an {{es}} data stream or index. + +| API | Description | +| --- | ----------- | +| [Graph Explore](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-graph-explore) | Discovers relationships between indexed terms using relevance-based graph exploration. | + +### [Index](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-indices) + +The index APIs enable you to manage individual indices, index settings, aliases, mappings, and index templates. 
+ +| API | Description | +| --- | ----------- | +| [Create index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-create) | Creates a new index with optional settings and mappings. | +| [Delete index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-delete) | Deletes an existing index. | +| [Get index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get) | Retrieves information about one or more indices. | +| [Open index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-open) | Opens a closed index to make it available for operations. | +| [Close index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-close) | Closes an index to free up resources. | +| [Shrink index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-shrink) | Shrinks an existing index into a new index with fewer primary shards. | +| [Split index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-split) | Splits an existing index into a new index with more primary shards. | +| [Clone index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-clone) | Clones an existing index into a new index. | +| [Manage index aliases](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-alias) | Manages index aliases. | +| [Update field mappings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) | Updates index mappings. | +| [Get field mappings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-mapping) | Retrieves index mappings. | +| [Get index settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-settings) | Retrieves settings for one or more indices. 
| +| [Update index settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-settings) | Updates index-level settings dynamically. | +| [Get index templates](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-template) | Retrieves legacy index templates. | +| [Put index template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-template) | Creates or updates a legacy index template. | +| [Delete index template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-delete-template) | Deletes a legacy index template. | +| [Get composable index templates](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-index-template) | Retrieves composable index templates. | +| [Put composable index template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-index-template) | Creates or updates a composable index template. | +| [Delete composable index template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-delete-index-template) | Deletes a composable index template. | +| [Get index alias](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-alias) | Retrieves index aliases. | +| [Delete index alias](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-delete-alias) | Deletes index aliases. | +| [Refresh index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-refresh) | Refreshes one or more indices, making recent changes searchable. | +| [Flush index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-flush) | Performs a flush operation on one or more indices. | +| [Clear index cache](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-clear-cache) | Clears caches associated with one or more indices. 
| +| [Force merge index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-forcemerge) | Merges index segments to reduce their number and improve performance. | +| [Freeze index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-freeze) | Freezes an index, making it read-only and minimizing its resource usage. | +| [Unfreeze index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-unfreeze) | Unfreezes a frozen index, making it writeable and fully functional. | +| [Rollover index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-rollover) | Rolls over an alias to a new index when conditions are met. | +| [Resolve index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-resolve) | Resolves expressions to index names, aliases, and data streams. | +| [Simulate index template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-simulate-index-template) | Simulates the application of a composable index template. | +| [Simulate template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-simulate-template) | Simulates the application of a legacy index template. | +| [Get mapping](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-mapping) | Retrieves mapping definitions for one or more indices. | +| [Put mapping](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) | Updates mapping definitions for one or more indices. | +| [Reload search analyzers](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-reload-search-analyzers) | Reloads search analyzers for one or more indices. | +| [Shrink index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-shrink) | Shrinks an existing index into a new index with fewer primary shards. 
| +| [Split index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-split) | Splits an existing index into a new index with more primary shards. | +| [Clone index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-clone) | Clones an existing index into a new index. | + +### [Index lifecycle management](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ilm) + +The index lifecycle management APIs enable you to set up policies to automatically manage the index lifecycle. + +| API | Description | +| --- | ----------- | +| [Put Lifecycle Policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-put-lifecycle) | Creates or updates an ILM policy. | +| [Get Lifecycle Policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-get-lifecycle) | Retrieves one or more ILM policies. | +| [Delete Lifecycle Policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-delete-lifecycle) | Deletes an ILM policy. | +| [Explain Lifecycle](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-explain-lifecycle) | Shows the current lifecycle step for indices. | +| [Move to Step](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-move-to-step) | Manually moves an index to the next step in its lifecycle. | +| [Retry Lifecycle Step](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-retry) | Retries the current lifecycle step for failed indices. | +| [Start ILM](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-start) | Starts the ILM plugin. | +| [Stop ILM](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-stop) | Stops the ILM plugin. | +| [Get ILM Status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ilm-get-status) | Returns the status of the ILM plugin. 
| + +### [Inference](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-inference) + +The inference APIs enable you to create inference endpoints and integrate with machine learning models of different services - such as Amazon Bedrock, Anthropic, Azure AI Studio, Cohere, Google AI, Mistral, OpenAI, or HuggingFace. + +| API | Description | +| --- | ----------- | +| [Put Inference Endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) | Creates an inference endpoint. | +| [Get Inference Endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-get) | Retrieves one or more inference endpoints. | +| [Delete Inference Endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-delete) | Deletes an inference endpoint. | +| [Infer](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-inference) | Runs inference using a deployed model. | + +### [Info](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-info) + +The info API provides basic build, version, and cluster information. + +| API | Description | +| --- | ----------- | +| [Get cluster information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-info) | Returns basic information about the cluster. | + +### [Ingest](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ingest) + +The ingest APIs enable you to manage tasks and resources related to ingest pipelines and processors. + +| API | Description | +| --- | ----------- | +| [Create or update pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-put-pipeline) | Creates or updates an ingest pipeline. | +| [Get pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-get-pipeline) | Retrieves one or more ingest pipelines. 
| +| [Delete pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-delete-pipeline) | Deletes an ingest pipeline. | +| [Simulate pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-simulate) | Simulates a document through an ingest pipeline. | +| [Get built-in grok patterns](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-processor-grok) | Returns a list of built-in grok patterns. | +| [Get processor types](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-processor-types) | Returns a list of available processor types. | +| [Put pipeline processor](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-put-processor) | Creates or updates a custom pipeline processor. | +| [Delete pipeline processor](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ingest-delete-processor) | Deletes a custom pipeline processor. | + +### [Licensing](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-license) + +The licensing APIs enable you to manage your licenses. + +| API | Description | +| --- | ----------- | +| [Get license](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-get) | Retrieves the current license for the cluster. | +| [Update license](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-post) | Updates the license for the cluster. | +| [Delete license](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-delete) | Removes the current license. | +| [Start basic license](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-post-start-basic) | Starts a basic license. | +| [Start trial license](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-post-start-trial) | Starts a trial license. 
| +| [Get the trial status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-license-get-trial-status) | Returns the status of the current trial license. | + +### [Logstash](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-logstash) + +The logstash APIs enable you to manage pipelines that are used by Logstash Central Management. + +| API | Description | +| --- | ----------- | +| [Create or update Logstash pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-logstash-put-pipeline) | Creates or updates a Logstash pipeline. | +| [Get Logstash pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-logstash-get-pipeline) | Retrieves one or more Logstash pipelines. | +| [Delete Logstash pipeline](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-logstash-delete-pipeline) | Deletes a Logstash pipeline. | + +### [Machine learning](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ml) + +The machine learning APIs enable you to retrieve information related to the {{stack}} {{ml}} features. + +| API | Description | +| --- | ----------- | +| [Get machine learning memory stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-memory-stats) | Gets information about how machine learning jobs and trained models are using memory. | +| [Get machine learning info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-info) | Gets defaults and limits used by machine learning. | +| [Set upgrade mode](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-set-upgrade-mode) | Sets a cluster wide upgrade_mode setting that prepares machine learning indices for an upgrade. | +| [Get ML job stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-job-stats) | Retrieves usage statistics for ML jobs. 
| +| [Get ML calendar events](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-calendar-events) | Retrieves scheduled events for ML calendars. | +| [Get ML filters](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-filters) | Retrieves ML filters. | +| [Put ML filter](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-filter) | Creates or updates an ML filter. | +| [Delete ML filter](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-filter) | Deletes an ML filter. | +| [Get ML info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-info) | Gets overall ML info. | +| [Get ML model snapshots](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-model-snapshots) | Retrieves model snapshots for ML jobs. | +| [Revert ML model snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-revert-model-snapshot) | Reverts an ML job to a previous model snapshot. | +| [Delete expired ML data](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-expired-data) | Deletes expired ML results and model snapshots. | + +### [Machine learning anomaly detection](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ml-anomaly) + +The machine learning anomaly detection APIs enable you to perform anomaly detection activities. + + +| API | Description | +| --- | ----------- | +| [Put Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-job) | Creates an anomaly detection job. | +| [Get Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-jobs) | Retrieves configuration info for anomaly detection jobs. | +| [Delete Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-job) | Deletes an anomaly detection job. 
| +| [Open Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-open-job) | Opens an existing anomaly detection job. | +| [Close anomaly detection jobs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-close-job) | Closes an anomaly detection job. | +| [Flush Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-flush-job) | Forces any buffered data to be processed. | +| [Forecast Job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-forecast) | Generates forecasts for anomaly detection jobs. | +| [Get Buckets](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-buckets) | Retrieves bucket results from a job. | +| [Get Records](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-records) | Retrieves anomaly records for a job. | +| [Get calendar configuration info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-calendars) | Gets calendar configuration information. | +| [Create a calendar](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-calendar) | Create a calendar. | +| [Delete a calendar](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-calendar) | Delete a calendar. | +| [Delete events from a calendar](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-calendar-event) | Delete events from a calendar. | +| [Add anomaly detection job to calendar](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-calendar-job) | Add an anomaly detection job to a calendar. | +| [Delete anomaly detection jobs from calendar](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-calendar-job) | Deletes anomaly detection jobs from a calendar. 
| +| [Get datafeeds configuration info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-datafeeds) | Get configuration information for a datafeed. | +| [Create datafeed](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-datafeed) | Creates a datafeed. | +| [Delete a datafeed](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-datafeed) | Deletes a datafeed. | +| [Delete expired ML data](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-expired-data) | Delete all job results, model snapshots and forecast data that have exceeded their retention days period. | +| [Get filters](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-filters) | Get a single filter or all filters. | +| [Get anomaly detection job results for influencers](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-influencers) | Get anomaly detection job results for entities that contributed to or are to blame for anomalies. | +| [Get anomaly detection job stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-job-stats) | Get anomaly detection job stats. | +| [Get anomaly detection jobs configuration info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-jobs) | You can get information for multiple anomaly detection jobs in a single API request by using a group name, a comma-separated list of jobs, or a wildcard expression. 
| + +### [Machine learning data frame analytics](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ml-data-frame) + +The machine learning data frame analytics APIs enable you to perform data frame analytics activities. + +| API | Description | +| --- | ----------- | +| [Create a data frame analytics job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-data-frame-analytics) | Creates a data frame analytics job. | +| [Get data frame analytics job configuration info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-data-frame-analytics) | Retrieves configuration and results for analytics jobs. | +| [Delete a data frame analytics job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-data-frame-analytics) | Deletes a data frame analytics job. | +| [Start a data frame analytics job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-start-data-frame-analytics) | Starts a data frame analytics job. | +| [Stop data frame analytics jobs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-stop-data-frame-analytics) | Stops a running data frame analytics job. | + +### [Machine learning trained models](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-ml-trained-model) + +The machine learning trained models APIs enable you to perform model management operations. + +| API | Description | +| --- | ----------- | +| [Put Trained Model](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-put-trained-model) | Uploads a trained model for inference. | +| [Get Trained Models](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-trained-models) | Retrieves configuration and stats for trained models. | +| [Delete Trained Model](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-delete-trained-model) | Deletes a trained model. 
| +| [Start Deployment](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-start-trained-model-deployment) | Starts a trained model deployment. | +| [Stop Deployment](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-stop-trained-model-deployment) | Stops a trained model deployment. | +| [Get Deployment Stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-ml-get-trained-models-stats) | Retrieves stats for deployed models. | + +### [Migration](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-migration) + +The migration APIs power {{kib}}'s Upgrade Assistant feature. + + +| API | Description | +| --- | ----------- | +| [Deprecation Info](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-migration-deprecations) | Retrieves deprecation warnings for cluster and indices. | +| [Get Feature Upgrade Status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-migration-get-feature-upgrade-status) | Checks upgrade status of system features. | +| [Post Feature Upgrade](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-migration-post-feature-upgrade) | Upgrades internal system features after a version upgrade. | + +### [Node lifecycle](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-shutdown) + +The node lifecycle APIs enable you to prepare nodes for temporary or permanent shutdown, monitor the shutdown status, and enable a previously shut-down node to resume normal operations. + +| API | Description | +| --- | ----------- | +| [Exclude nodes from voting](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-post-voting-config-exclusions) | Excludes nodes from voting in master elections. | +| [Clear voting config exclusions](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-delete-voting-config-exclusions) | Clears voting config exclusions. 
| + +### [Query rules](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-query_rules) + +Query rules enable you to configure per-query rules that are applied at query time to queries that match the specific rule. Query rules are organized into rulesets, collections of query rules that are matched against incoming queries. Query rules are applied using the rule query. If a query matches one or more rules in the ruleset, the query is re-written to apply the rules before searching. This allows pinning documents for only queries that match a specific term. + +| API | Description | +| --- | ----------- | +| [Create or update query ruleset](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-query-ruleset-put-query-ruleset) | Creates or updates a query ruleset. | +| [Get query ruleset](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-query-ruleset-get-query-ruleset) | Retrieves one or more query rulesets. | +| [Delete query ruleset](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-query-ruleset-delete-query-ruleset) | Deletes a query ruleset. | + +### [Rollup](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-rollup) + +The rollup APIs enable you to create, manage, and retrieve information about rollup jobs. + +| API | Description | +| --- | ----------- | +| [Create or update rollup job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-put-job) | Creates or updates a rollup job for summarizing historical data. | +| [Get rollup jobs](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-get-jobs) | Retrieves configuration for one or more rollup jobs. | +| [Delete rollup job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-delete-job) | Deletes a rollup job. | +| [Start rollup job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-start-job) | Starts a rollup job. 
| +| [Stop rollup job](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-stop-job) | Stops a running rollup job. | +| [Get rollup capabilities](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-get-rollup-caps) | Returns the capabilities of rollup jobs. | +| [Search rollup data](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-rollup-rollup-search) | Searches rolled-up data using a rollup index. | + +### [Script](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-script) + +Use the script support APIs to get a list of supported script contexts and languages. Use the stored script APIs to manage stored scripts and search templates. + + +| API | Description | +| --- | ----------- | +| [Add or update stored script](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-script-put-script) | Adds or updates a stored script. | +| [Get stored script](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-script-get-script) | Retrieves a stored script. | +| [Delete stored script](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-script-delete-script) | Deletes a stored script. | +| [Execute Painless script](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-script-painless-execute) | Executes a script using the Painless language. | +| [Get script contexts](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-get-script-context) | Returns available script execution contexts. | +| [Get script languages](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-get-script-languages) | Returns available scripting languages. | + +### [Search](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-search) + +The search APIs enable you to search and aggregate data stored in {{es}} indices and data streams. 
+ +| API | Description | +| --- | ----------- | +| [Search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search) | Executes a search query on one or more indices. | +| [Multi search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-msearch) | Executes multiple search requests in a single API call. | +| [Search template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-template) | Executes a search using a stored or inline template. | +| [Render search template](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-render-search-template) | Renders a search template with parameters. | +| [Explain search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-explain) | Explains how a document scores against a query. | +| [Validate query](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-validate-query) | Validates a query without executing it. | +| [Get field capabilities](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-field-caps) | Returns the capabilities of fields across indices. | +| [Scroll search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-scroll) | Efficiently retrieves large numbers of results (pagination). | +| [Clear scroll](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-clear-scroll) | Clears search contexts for scroll requests. | + +### [Search application](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-search_application) + +The search application APIs enable you to manage tasks and resources related to Search Applications. + +| API | Description | +| --- | ----------- | +| [Create or update search application](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-put) | Creates or updates a search application. 
| +| [Get search application](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-get) | Retrieves a search application by name. | +| [Delete search application](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-delete) | Deletes a search application. | +| [Search search application](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search-application-search) | Executes a search using a search application. | + +### [Searchable snapshots](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-searchable_snapshots) + +The searchable snapshots APIs enable you to perform searchable snapshots operations. + +| API | Description | +| --- | ----------- | +| [Mount searchable snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-searchable-snapshots-mount) | Mounts a snapshot as a searchable index. | +| [Clear searchable snapshot cache](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-searchable-snapshots-clear-cache) | Clears the cache of searchable snapshots. | +| [Get searchable snapshot stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-searchable-snapshots-stats) | Returns stats about searchable snapshots. | + +### [Security](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-security) + +The security APIs enable you to perform security activities, and add, update, retrieve, and remove application privileges, role mappings, and roles. You can also create and update API keys and create and invalidate bearer tokens. + + +| API | Description | +| --- | ----------- | +| [Create or update user](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-put-user) | Creates or updates a user in the native realm. | +| [Get user](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-get-user) | Retrieves one or more users. 
| +| [Delete user](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-delete-user) | Deletes a user from the native realm. | +| [Create or update role](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-put-role) | Creates or updates a role. | +| [Get role](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-get-role) | Retrieves one or more roles. | +| [Delete role](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-delete-role) | Deletes a role. | +| [Create API key](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-create-api-key) | Creates an API key for access without basic auth. | +| [Invalidate API key](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-invalidate-api-key) | Invalidates one or more API keys. | +| [Authenticate](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-authenticate) | Retrieves information about the authenticated user. | + +### [Snapshot and restore](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-snapshot) + +The snapshot and restore APIs enable you to set up snapshot repositories, manage snapshot backups, and restore snapshots to a running cluster. + +| API | Description | +| --- | ----------- | +| [Clean up snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-cleanup-repository) | Removes stale data from a repository. | +| [Clone snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-clone) | Clones indices from a snapshot into a new snapshot. | +| [Get snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-get) | Retrieves information about snapshots. | +| [Create snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-create) | Creates a snapshot of one or more indices. 
| +| [Delete snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-delete) | Deletes a snapshot from a repository. | +| [Get snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-get-repository) | Retrieves information about snapshot repositories. | +| [Create or update snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-create-repository) | Registers or updates a snapshot repository. | +| [Delete snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-delete-repository) | Deletes a snapshot repository. | +| [Restore snapshot](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-restore) | Restores a snapshot. | +| [Analyze snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-repository-analyze) | Analyzes a snapshot repository for correctness and performance. | +| [Verify snapshot repository](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-repository-verify-integrity) | Verifies access to a snapshot repository. | +| [Get snapshot status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-snapshot-status) | Gets the status of a snapshot. | + +### [Snapshot lifecycle management](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-slm) + +The snapshot lifecycle management APIs enable you to set up policies to automatically take snapshots and control how long they are retained. + +| API | Description | +| --- | ----------- | +| [Get snapshot lifecycle policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-get-lifecycle) | Retrieves one or more snapshot lifecycle policies. 
| +| [Create or update snapshot lifecycle policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-put-lifecycle) | Creates or updates a snapshot lifecycle policy. | +| [Delete snapshot lifecycle policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-delete-lifecycle) | Deletes a snapshot lifecycle policy. | +| [Execute snapshot lifecycle policy](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-execute-lifecycle) | Triggers a snapshot lifecycle policy manually. | +| [Execute snapshot retention](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-execute-retention) | Manually apply the retention policy to force immediate removal of snapshots that are expired according to the snapshot lifecycle policy retention rules. | +| [Get snapshot lifecycle stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-get-stats) | Returns statistics about snapshot lifecycle executions. | +| [Get snapshot lifecycle status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-get-status) | Returns the status of the snapshot lifecycle management feature. | +| [Start snapshot lifecycle management](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-start) | Starts the snapshot lifecycle management feature. | +| [Stop snapshot lifecycle management](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-slm-stop) | Stops the snapshot lifecycle management feature. | + +### [SQL](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-sql) + +The SQL APIs enable you to run SQL queries on Elasticsearch indices and data streams. + +| API | Description | +| --- | ----------- | +| [Clear SQL cursor](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-clear-cursor) | Clears the server-side cursor for an SQL search. 
| +| [Delete async SQL search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-delete-async) | Deletes an async SQL search. | +| [Get async SQL search results](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-get-async) | Retrieves results of an async SQL query. | +| [Get async SQL search status](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-get-async-status) | Gets the current status of an async SQL search or a stored synchronous SQL search. | +| [SQL query](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-query) | Executes an SQL query. | +| [Translate SQL](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-sql-translate) | Translates SQL into Elasticsearch DSL. | + +### [Synonyms](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-synonyms) + +The synonyms management APIs provide a convenient way to define and manage synonyms in an internal system index. Related synonyms can be grouped in a "synonyms set". + +| API | Description | +| --- | ----------- | +| [Get synonym set](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-synonyms-get-synonym) | Retrieves a synonym set by ID. | +| [Create or update synonym set](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-synonyms-put-synonym) | Creates or updates a synonym set. | +| [Delete synonym set](https://www.elastic.co/docs/api/doc/elasticsearch/endpoint/synonyms.delete_synonym) | Deletes a synonym set. | +| [Get synonym rule](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-synonyms-get-synonym-rule) | Retrieves a synonym rule from a synonym set. | +| [Get synonym sets](https://www.elastic.co/docs/api/doc/elasticsearch/endpoint/synonyms.get_synonyms) | Lists all synonym sets. 
| + +### [Task management](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-tasks) + +The task management APIs enable you to retrieve information about tasks or cancel tasks running in a cluster. + +| API | Description | +| --- | ----------- | +| [Cancel a task](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-tasks-cancel) | Cancels a running task. | +| [Get task information](https://www.elastic.co/docs/api/doc/elasticsearch/v9/operation/operation-tasks-get) | Retrieves information about a specific task. | +| [Get all tasks](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-tasks-list) | Retrieves information about running tasks. | + +### [Text structure](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-text_structure) + +The text structure APIs enable you to find the structure of a text field in an {{es}} index. + +| API | Description | +| --- | ----------- | +| [Find the structure of a text field](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-text-structure-find-field-structure) | Analyzes a text field in an index and returns its structure. | +| [Find the structure of a text message](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-text-structure-find-message-structure) | Analyzes a list of text messages and returns their structure. | +| [Find the structure of a text file](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-text-structure-find-structure) | Analyzes a text file and returns its structure. | +| [Test a Grok pattern](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-text-structure-test-grok-pattern) | Tests a Grok pattern against sample text. | + +### [Transform](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-transform) + +The transform APIs enable you to create and manage transforms. + +| API | Description | +| --- | ----------- | +| [Get transforms](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-get-transform) | Retrieves configuration for one or more transforms. 
| +| [Create a transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-put-transform) | Creates or updates a transform job. | +| [Get transform stats](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-get-transform-stats) | Get usage information for transforms. | +| [Preview transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-preview-transform) | Previews the results of a transform job. | +| [Reset a transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-reset-transform) | Resets a transform, clearing its state and checkpoints. | +| [Delete transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-delete-transform) | Deletes a transform job. | +| [Schedule a transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-schedule-now-transform) | Schedules a transform to run now, without waiting for its next scheduled check. | +| [Start transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-start-transform) | Starts a transform job. | +| [Stop transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-stop-transform) | Stops a running transform job. | +| [Update transform](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-update-transform) | Updates certain properties of a transform. | +| [Upgrade all transforms](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-transform-upgrade-transforms) | Upgrades all transforms to the latest configuration format. | + +### [Usage](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-xpack) + +The usage API provides usage information about the installed X-Pack features. 
+ +| API | Description | +| --- | ----------- | +| [Get information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-xpack-info) | Gets information about build details, license status, and a list of features currently available under the installed license. | +| [Get usage information](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-xpack-usage) | Get information about the features that are currently enabled and available under the current license. | + + +### [Watcher](https://www.elastic.co/docs/api/doc/elasticsearch/v9/group/endpoint-watcher) + +You can use Watcher to watch for changes or anomalies in your data and perform the necessary actions in response. + +| API | Description | +| --- | ----------- | +| [Acknowledge a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-ack-watch) | Acknowledges a watch action. | +| [Activate a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-activate-watch) | Activates a watch. | +| [Deactivates a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-deactivate-watch) | Deactivates a watch. | +| [Get a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-get-watch) | Retrieves a watch by ID. | +| [Create or update a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-put-watch) | Creates or updates a watch. | +| [Delete a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-delete-watch) | Deletes a watch. | +| [Run a watch](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-execute-watch) | Executes a watch manually. 
| +| [Get Watcher index settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-get-settings) | Get settings for the Watcher internal index | +| [Update Watcher index settings](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-update-settings) | Update settings for the Watcher internal index | +| [Query watches](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-query-watches) | Get all registered watches in a paginated manner and optionally filter watches by a query. | +| [Start the watch service](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-start) | Starts the Watcher service. | +| [Get Watcher statistics](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-stats) | Returns statistics about the Watcher service. | +| [Stop the watch service](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-watcher-stop) | Stops the Watcher service. | diff --git a/docs/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md b/docs/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md index dc7f02a119397..35be4531e3ac9 100644 --- a/docs/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md +++ b/docs/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md @@ -34,7 +34,7 @@ For the most up-to-date API details, refer to [Search APIs](https://www.elastic. :::: -You can use RRF as part of a [search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search) to combine and rank documents using separate sets of top documents (result sets) from a combination of [child retrievers](/reference/elasticsearch/rest-apis/retrievers.md) using an [RRF retriever](/reference/elasticsearch/rest-apis/retrievers.md#rrf-retriever). A minimum of **two** child retrievers is required for ranking. 
+You can use RRF as part of a [search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search) to combine and rank documents using separate sets of top documents (result sets) from a combination of [child retrievers](/reference/elasticsearch/rest-apis/retrievers.md) using an [RRF retriever](/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md). A minimum of **two** child retrievers is required for ranking. An RRF retriever is an optional object defined as part of a search request’s [retriever parameter](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#request-body-retriever). The RRF retriever object contains the following parameters: diff --git a/docs/reference/elasticsearch/rest-apis/retrievers.md b/docs/reference/elasticsearch/rest-apis/retrievers.md index 01472c45ff2f8..fa3810a435a01 100644 --- a/docs/reference/elasticsearch/rest-apis/retrievers.md +++ b/docs/reference/elasticsearch/rest-apis/retrievers.md @@ -3,6 +3,7 @@ mapped_pages: - https://www.elastic.co/guide/en/elasticsearch/reference/current/retriever.html applies_to: stack: all + serverless: ga --- # Retrievers [retriever] @@ -14,1026 +15,63 @@ Refer to [*Retrievers*](docs-content://solutions/search/retrievers-overview.md) :::: - -::::{admonition} New API reference -For the most up-to-date API details, refer to [Search APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-search). - -:::: - - The following retrievers are available: -`standard` -: A [retriever](#standard-retriever) that replaces the functionality of a traditional [query](/reference/query-languages/querydsl.md). - `knn` -: A [retriever](#knn-retriever) that replaces the functionality of a [knn search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-api-knn). 
+: The [knn](retrievers/knn-retriever.md) retriever replaces the functionality of a [knn search](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-api-knn). `linear` -: A [retriever](#linear-retriever) that linearly combines the scores of other retrievers for the top documents. +: The [linear](retrievers/linear-retriever.md) retriever linearly combines the scores of other retrievers for the top documents. + +`pinned` {applies_to}`stack: GA 9.1` +: The [pinned](retrievers/pinned-retriever.md) retriever always places specified documents at the top of the results, with the remaining hits provided by a secondary retriever. `rescorer` -: A [retriever](#rescorer-retriever) that replaces the functionality of the [query rescorer](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore). +: The [rescorer](retrievers/rescorer-retriever.md) retriever replaces the functionality of the [query rescorer](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore). `rrf` -: A [retriever](#rrf-retriever) that produces top documents from [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md). - -`text_similarity_reranker` -: A [retriever](#text-similarity-reranker-retriever) that enhances search results by re-ranking documents based on semantic similarity to a specified inference text, using a machine learning model. - -`pinned` {applies_to}`stack: GA 9.1` -: A [retriever](#pinned-retriever) that always places specified documents at the top of the results, with the remaining hits provided by a secondary retriever. +: The [rrf](retrievers/rrf-retriever.md) retriever produces top documents from [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md). 
`rule` -: A [retriever](#rule-retriever) that applies contextual [Searching with query rules](/reference/elasticsearch/rest-apis/searching-with-query-rules.md#query-rules) to pin or exclude documents for specific queries. - -## Standard Retriever [standard-retriever] - -A standard retriever returns top documents from a traditional [query](/reference/query-languages/querydsl.md). - - -#### Parameters: [standard-retriever-parameters] - -`query` -: (Optional, [query object](/reference/query-languages/querydsl.md)) - - Defines a query to retrieve a set of top documents. - - -`filter` -: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Applies a [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to this retriever, where all documents must match this query but do not contribute to the score. - - -`search_after` -: (Optional, [search after object](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after)) - - Defines a search after object parameter used for pagination. - - -`terminate_after` -: (Optional, integer) Maximum number of documents to collect for each shard. If a query reaches this limit, {{es}} terminates the query early. {{es}} collects documents before sorting. - - ::::{important} - Use with caution. {{es}} applies this parameter to each shard handling the request. When possible, let {{es}} perform early termination automatically. Avoid specifying this parameter for requests that target data streams with backing indices across multiple data tiers. - :::: - - -`sort` -: (Optional, [sort object](/reference/elasticsearch/rest-apis/sort-search-results.md)) A sort object that specifies the order of matching documents. - - -`min_score` -: (Optional, `float`) - - Minimum [`_score`](/reference/query-languages/query-dsl/query-filter-context.md#relevance-scores) for matching documents. Documents with a lower `_score` are not included in the top documents. 
- - -`collapse` -: (Optional, [collapse object](/reference/elasticsearch/rest-apis/collapse-search-results.md)) - - Collapses the top documents by a specified key into a single top document per key. - - -### Restrictions [_restrictions] - -When a retriever tree contains a compound retriever (a retriever with two or more child retrievers) the [search after](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after) parameter is not supported. - - -### Example [standard-retriever-example] - -```console -GET /restaurants/_search -{ - "retriever": { <1> - "standard": { <2> - "query": { <3> - "bool": { <4> - "should": [ <5> - { - "match": { <6> - "region": "Austria" - } - } - ], - "filter": [ <7> - { - "term": { <8> - "year": "2019" <9> - } - } - ] - } - } - } - } -} -``` - -1. Opens the `retriever` object. -2. The `standard` retriever is used for defining traditional {{es}} queries. -3. The entry point for defining the search query. -4. The `bool` object allows for combining multiple query clauses logically. -5. The `should` array indicates conditions under which a document will match. Documents matching these conditions will have increased relevancy scores. -6. The `match` object finds documents where the `region` field contains the word "Austria." -7. The `filter` array provides filtering conditions that must be met but do not contribute to the relevancy score. -8. The `term` object is used for exact matches, in this case, filtering documents by the `year` field. -9. The exact value to match in the `year` field. - - - - -## kNN Retriever [knn-retriever] - -A kNN retriever returns top documents from a [k-nearest neighbor search (kNN)](docs-content://solutions/search/vector/knn.md). - - -#### Parameters [knn-retriever-parameters] - -`field` -: (Required, string) - - The name of the vector field to search against. Must be a [`dense_vector` field with indexing enabled](/reference/elasticsearch/mapping-reference/dense-vector.md#index-vectors-knn-search). 
- - -`query_vector` -: (Required if `query_vector_builder` is not defined, array of `float`) - - Query vector. Must have the same number of dimensions as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector. - - -`query_vector_builder` -: (Required if `query_vector` is not defined, query vector builder object) - - Defines a [model](docs-content://solutions/search/vector/knn.md#knn-semantic-search) to build a query vector. - - -`k` -: (Required, integer) - - Number of nearest neighbors to return as top hits. This value must be fewer than or equal to `num_candidates`. - - -`num_candidates` -: (Required, integer) - - The number of nearest neighbor candidates to consider per shard. Needs to be greater than `k`, or `size` if `k` is omitted, and cannot exceed 10,000. {{es}} collects `num_candidates` results from each shard, then merges them to find the top `k` results. Increasing `num_candidates` tends to improve the accuracy of the final `k` results. Defaults to `Math.min(1.5 * k, 10_000)`. - - -`filter` -: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Query to filter the documents that can match. The kNN search will return the top `k` documents that also match this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents are allowed to match. - - -`similarity` -: (Optional, float) - - The minimum similarity required for a document to be considered a match. The similarity value calculated relates to the raw [`similarity`](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-similarity) used. Not the document score. The matched documents are then scored according to [`similarity`](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-similarity) and the provided `boost` is applied. - - The `similarity` parameter is the direct vector similarity calculation. 
- - * `l2_norm`: also known as Euclidean, will include documents where the vector is within the `dims` dimensional hypersphere with radius `similarity` with origin at `query_vector`. - * `cosine`, `dot_product`, and `max_inner_product`: Only return vectors where the cosine similarity or dot-product are at least the provided `similarity`. - - Read more here: [knn similarity search](docs-content://solutions/search/vector/knn.md#knn-similarity-search) - - -`rescore_vector` -: (Optional, object) Apply oversampling and rescoring to quantized vectors. - -::::{note} -Rescoring only makes sense for quantized vectors; when [quantization](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-quantization) is not used, the original vectors are used for scoring. Rescore option will be ignored for non-quantized `dense_vector` fields. -:::: - - -`oversample` -: (Required, float) - - Applies the specified oversample factor to `k` on the approximate kNN search. The approximate kNN search will: - - * Retrieve `num_candidates` candidates per shard. - * From these candidates, the top `k * oversample` candidates per shard will be rescored using the original vectors. - * The top `k` rescored candidates will be returned. - - -See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details. - - -### Restrictions [_restrictions_2] - -The parameters `query_vector` and `query_vector_builder` cannot be used together. - - -### Example [knn-retriever-example] - -```console -GET /restaurants/_search -{ - "retriever": { - "knn": { <1> - "field": "vector", <2> - "query_vector": [10, 22, 77], <3> - "k": 10, <4> - "num_candidates": 10 <5> - } - } -} -``` - -1. Configuration for k-nearest neighbor (knn) search, which is based on vector similarity. -2. Specifies the field name that contains the vectors. -3. The query vector against which document vectors are compared in the `knn` search. -4. 
The number of nearest neighbors to return as top hits. This value must be fewer than or equal to `num_candidates`. -5. The size of the initial candidate set from which the final `k` nearest neighbors are selected. - - - - -## Linear Retriever [linear-retriever] - -A retriever that normalizes and linearly combines the scores of other retrievers. - - -#### Parameters [linear-retriever-parameters] - -::::{note} -Either `query` or `retrievers` must be specified. -Combining `query` and `retrievers` is not supported. -:::: - -`query` {applies_to}`stack: ga 9.1` -: (Optional, String) - - The query to use when using the [multi-field query format](#multi-field-query-format). - -`fields` {applies_to}`stack: ga 9.1` -: (Optional, array of strings) - - The fields to query when using the [multi-field query format](#multi-field-query-format). - Fields can include boost values using the `^` notation (e.g., `"field^2"`). - If not specified, uses the index's default fields from the `index.query.default_field` index setting, which is `*` by default. - -`normalizer` {applies_to}`stack: ga 9.1` -: (Optional, String) - - The normalizer to use when using the [multi-field query format](#multi-field-query-format). - See [normalizers](#linear-retriever-normalizers) for supported values. - Required when `query` is specified. - - ::::{warning} - Avoid using `none` as that will disable normalization and may bias the result set towards lexical matches. - See [field grouping](#multi-field-field-grouping) for more information. - :::: - -`retrievers` -: (Optional, array of objects) - - A list of the sub-retrievers' configuration, that we will take into account and whose result sets we will merge through a weighted sum. - Each configuration can have a different weight and normalization depending on the specified retriever. - -`rank_window_size` -: (Optional, integer) - - This value determines the size of the individual result sets per query. 
A higher value will improve result relevance at the cost of performance. - The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). - `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. - Defaults to 10. - -`filter` -: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. - -Each entry in the `retrievers` array specifies the following parameters: - -`retriever` -: (Required, a `retriever` object) - - Specifies the retriever for which we will compute the top documents for. The retriever will produce `rank_window_size` results, which will later be merged based on the specified `weight` and `normalizer`. - -`weight` -: (Optional, float) - - The weight that each score of this retriever’s top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. - -`normalizer` -: (Optional, String) - - Specifies how the retriever’s score will be normalized before applying the specified `weight`. - See [normalizers](#linear-retriever-normalizers) for supported values. - Defaults to `none`. - -See also [this hybrid search example](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-retriever) using a linear retriever on how to independently configure and apply normalizers to retrievers. 
- -#### Normalizers [linear-retriever-normalizers] - -The `linear` retriever supports the following normalizers: - -* `none`: No normalization -* `minmax`: Normalizes scores based on the following formula: - - ``` - score = (score - min) / (max - min) - ``` -* `l2_norm`: Normalizes scores using the L2 norm of the score values - - -## RRF Retriever [rrf-retriever] - -An [RRF](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md) retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers. -Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set. - - -#### Parameters [rrf-retriever-parameters] - -::::{note} -Either `query` or `retrievers` must be specified. -Combining `query` and `retrievers` is not supported. -:::: - -`query` {applies_to}`stack: ga 9.1` -: (Optional, String) - - The query to use when using the [multi-field query format](#multi-field-query-format). - -`fields` {applies_to}`stack: ga 9.1` -: (Optional, array of strings) - - The fields to query when using the [multi-field query format](#multi-field-query-format). - If not specified, uses the index's default fields from the `index.query.default_field` index setting, which is `*` by default. - -`retrievers` -: (Optional, array of retriever objects) - - A list of child retrievers to specify which sets of returned top documents will have the RRF formula applied to them. - Each child retriever carries an equal weight as part of the RRF formula. Two or more child retrievers are required. - -`rank_constant` -: (Optional, integer) - - This value determines how much influence documents in individual result sets per query have over the final ranked result set. A higher value indicates that lower ranked documents have more influence. This value must be greater than or equal to `1`. Defaults to `60`. 
- -`rank_window_size` -: (Optional, integer) - - This value determines the size of the individual result sets per query. - A higher value will improve result relevance at the cost of performance. - The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). - `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. - Defaults to 10. - -`filter` -: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. - -### Example: Hybrid search [rrf-retriever-example-hybrid] - -A simple hybrid search example (lexical search + dense vector search) combining a `standard` retriever with a `knn` retriever using RRF: - -```console -GET /restaurants/_search -{ - "retriever": { - "rrf": { <1> - "retrievers": [ <2> - { - "standard": { <3> - "query": { - "multi_match": { - "query": "Austria", - "fields": [ - "city", - "region" - ] - } - } - } - }, - { - "knn": { <4> - "field": "vector", - "query_vector": [10, 22, 77], - "k": 10, - "num_candidates": 10 - } - } - ], - "rank_constant": 1, <5> - "rank_window_size": 50 <6> - } - } -} -``` - -1. Defines a retriever tree with an RRF retriever. -2. The sub-retriever array. -3. The first sub-retriever is a `standard` retriever. -4. The second sub-retriever is a `knn` retriever. -5. The rank constant for the RRF retriever. -6. The rank window size for the RRF retriever. 
- - - -### Example: Hybrid search with sparse vectors [rrf-retriever-example-hybrid-sparse] - -A more complex hybrid search example (lexical search + ELSER sparse vector search + dense vector search) using RRF: - -```console -GET movies/_search -{ - "retriever": { - "rrf": { - "retrievers": [ - { - "standard": { - "query": { - "sparse_vector": { - "field": "plot_embedding", - "inference_id": "my-elser-model", - "query": "films that explore psychological depths" - } - } - } - }, - { - "standard": { - "query": { - "multi_match": { - "query": "crime", - "fields": [ - "plot", - "title" - ] - } - } - } - }, - { - "knn": { - "field": "vector", - "query_vector": [10, 22, 77], - "k": 10, - "num_candidates": 10 - } - } - ] - } - } -} -``` - - -## Rescorer Retriever [rescorer-retriever] - -The `rescorer` retriever re-scores only the results produced by its child retriever. For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard. - -For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally. - -When using the `rescorer`, an error is returned if the following conditions are not met: - -* The minimum configured rescore’s `window_size` is: - - * Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups. - * Greater than or equal to the `size` of the search request when used as the primary retriever in the tree. - -* And the maximum rescore’s `window_size` is: - - * Smaller than or equal to the `size` or `rank_window_size` of the child retriever. - - - -#### Parameters [rescorer-retriever-parameters] - -`rescore` -: (Required. [A rescorer definition or an array of rescorer definitions](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore)) - - Defines the [rescorers](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore) applied sequentially to the top documents returned by the child retriever. 
- - -`retriever` -: (Required. `retriever`) - - Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked. - - -`filter` -: (Optional. [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Applies a [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to the retriever, ensuring that all documents match the filter criteria without affecting their scores. - - - -### Example [rescorer-retriever-example] - -The `rescorer` retriever can be placed at any level within the retriever tree. The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever: - -```console -GET movies/_search -{ - "size": 10, <1> - "retriever": { - "rescorer": { <2> - "rescore": { - "window_size": 50, <3> - "query": { <4> - "rescore_query": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0", - "params": { - "queryVector": [-0.5, 90.0, -10, 14.8, -156.0] - } - } - } - } - } - }, - "retriever": { <5> - "rrf": { - "rank_window_size": 100, <6> - "retrievers": [ - { - "standard": { - "query": { - "sparse_vector": { - "field": "plot_embedding", - "inference_id": "my-elser-model", - "query": "films that explore psychological depths" - } - } - } - }, - { - "standard": { - "query": { - "multi_match": { - "query": "crime", - "fields": [ - "plot", - "title" - ] - } - } - } - }, - { - "knn": { - "field": "vector", - "query_vector": [10, 22, 77], - "k": 10, - "num_candidates": 10 - } - } - ] - } - } - } - } -} -``` - -1. Specifies the number of top documents to return in the final response. -2. A `rescorer` retriever applied as the final step. -3. Defines the number of documents to rescore from the child retriever. -4. The definition of the `query` rescorer. -5. Specifies the child retriever definition. -6. 
Defines the number of documents returned by the `rrf` retriever, which limits the available documents to - - - -## Text Similarity Re-ranker Retriever [text-similarity-reranker-retriever] - -The `text_similarity_reranker` retriever uses an NLP model to improve search results by reordering the top-k documents based on their semantic similarity to the query. - -::::{tip} -Refer to [*Semantic re-ranking*](docs-content://solutions/search/ranking/semantic-reranking.md) for a high level overview of semantic re-ranking. - -:::: - - -### Prerequisites [_prerequisites_15] - -To use `text_similarity_reranker`, you can rely on the preconfigured `.rerank-v1-elasticsearch` inference endpoint, which uses the [Elastic Rerank model](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) and serves as the default if no `inference_id` is provided. This model is optimized for reranking based on text similarity. If you'd like to use a different model, you can set up a custom inference endpoint for the `rerank` task using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). The endpoint should be configured with a machine learning model capable of computing text similarity. Refer to [the Elastic NLP model reference](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-model-ref.md#ml-nlp-model-ref-text-similarity) for a list of third-party text similarity models supported by {{es}}. - -You have the following options: - -* Use the built-in [Elastic Rerank](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) cross-encoder model via the inference API’s {{es}} service. See [this example](https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-service-elasticsearch.html#inference-example-elastic-reranker) for creating an endpoint using the Elastic Rerank model. 
-* Use the [Cohere Rerank inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. -* Use the [Google Vertex AI inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. -* Upload a model to {{es}} with [Eland](eland://reference/machine-learning.md#ml-nlp-pytorch) using the `text_similarity` NLP task type. - - * Then set up an [{{es}} service inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. - * Refer to the [example](#text-similarity-reranker-retriever-example-eland) on this page for a step-by-step guide. - - -::::{important} -Scores from the re-ranking process are normalized using the following formula before returned to the user, to avoid having negative scores. - -```text -score = max(score, 0) + min(exp(score), 1) -``` - -Using the above, any initially negative scores are projected to (0, 1) and positive scores to [1, infinity). To revert back if needed, one can use: - -```text -score = score - 1, if score >= 0 -score = ln(score), if score < 0 -``` - -:::: - - - -#### Parameters [text-similarity-reranker-retriever-parameters] - -`retriever` -: (Required, `retriever`) - - The child retriever that generates the initial set of top documents to be re-ranked. - - -`field` -: (Required, `string`) - - The document field to be used for text similarity comparisons. This field should contain the text that will be evaluated against the `inferenceText`. - - -`inference_id` -: (Optional, `string`) - - Unique identifier of the inference endpoint created using the {{infer}} API. If you don’t specify an inference endpoint, the `inference_id` field defaults to `.rerank-v1-elasticsearch`, a preconfigured endpoint for the elasticsearch `.rerank-v1` model. 
- - -`inference_text` -: (Required, `string`) - - The text snippet used as the basis for similarity comparison. - - -`rank_window_size` -: (Optional, `int`) - - The number of top documents to consider in the re-ranking process. Defaults to `10`. - - -`min_score` -: (Optional, `float`) - - Sets a minimum threshold score for including documents in the re-ranked results. Documents with similarity scores below this threshold will be excluded. Note that score calculations vary depending on the model used. - - -`filter` -: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - - Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to the child `retriever`. If the child retriever already specifies any filters, then this top-level filter is applied in conjuction with the filter defined in the child retriever. - - - -### Example: Elastic Rerank [text-similarity-reranker-retriever-example-elastic-rerank] - -::::{tip} -Refer to this [Python notebook](https://github.com/elastic/elasticsearch-labs/blob/main/notebooks/search/12-semantic-reranking-elastic-rerank.ipynb) for an end-to-end example using Elastic Rerank. - -:::: - - -This example demonstrates how to deploy the [Elastic Rerank](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) model and use it to re-rank search results using the `text_similarity_reranker` retriever. - -Follow these steps: - -1. Create an inference endpoint for the `rerank` task using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). - - ```console - PUT _inference/rerank/my-elastic-rerank - { - "service": "elasticsearch", - "service_settings": { - "model_id": ".rerank-v1", - "num_threads": 1, - "adaptive_allocations": { <1> - "enabled": true, - "min_number_of_allocations": 1, - "max_number_of_allocations": 10 - } - } - } - ``` - - 1. 
[Adaptive allocations](docs-content://deploy-manage/autoscaling/trained-model-autoscaling.md#enabling-autoscaling-through-apis-adaptive-allocations) will be enabled with the minimum of 1 and the maximum of 10 allocations. - -2. Define a `text_similarity_rerank` retriever: - - ```console - POST _search - { - "retriever": { - "text_similarity_reranker": { - "retriever": { - "standard": { - "query": { - "match": { - "text": "How often does the moon hide the sun?" - } - } - } - }, - "field": "text", - "inference_id": "my-elastic-rerank", - "inference_text": "How often does the moon hide the sun?", - "rank_window_size": 100, - "min_score": 0.5 - } - } - } - ``` - - - -### Example: Cohere Rerank [text-similarity-reranker-retriever-example-cohere] - -This example enables out-of-the-box semantic search by re-ranking top documents using the Cohere Rerank API. This approach eliminates the need to generate and store embeddings for all indexed documents. This requires a [Cohere Rerank inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) that is set up for the `rerank` task type. - -```console -GET /index/_search -{ - "retriever": { - "text_similarity_reranker": { - "retriever": { - "standard": { - "query": { - "match_phrase": { - "text": "landmark in Paris" - } - } - } - }, - "field": "text", - "inference_id": "my-cohere-rerank-model", - "inference_text": "Most famous landmark in Paris", - "rank_window_size": 100, - "min_score": 0.5 - } - } -} -``` - - -### Example: Semantic re-ranking with a Hugging Face model [text-similarity-reranker-retriever-example-eland] - -The following example uses the `cross-encoder/ms-marco-MiniLM-L-6-v2` model from Hugging Face to rerank search results based on semantic similarity. The model must be uploaded to {{es}} using [Eland](eland://reference/machine-learning.md#ml-nlp-pytorch). 
- -::::{tip} -Refer to [the Elastic NLP model reference](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-model-ref.md#ml-nlp-model-ref-text-similarity) for a list of third party text similarity models supported by {{es}}. - -:::: - - -Follow these steps to load the model and create a semantic re-ranker. - -1. Install Eland using `pip` - - ```sh - python -m pip install eland[pytorch] - ``` - -2. Upload the model to {{es}} using Eland. This example assumes you have an Elastic Cloud deployment and an API key. Refer to the [Eland documentation](eland://reference/machine-learning.md#ml-nlp-pytorch-auth) for more authentication options. - - ```sh - eland_import_hub_model \ - --cloud-id $CLOUD_ID \ - --es-api-key $ES_API_KEY \ - --hub-model-id cross-encoder/ms-marco-MiniLM-L-6-v2 \ - --task-type text_similarity \ - --clear-previous \ - --start - ``` - -3. Create an inference endpoint for the `rerank` task - - ```console - PUT _inference/rerank/my-msmarco-minilm-model - { - "service": "elasticsearch", - "service_settings": { - "num_allocations": 1, - "num_threads": 1, - "model_id": "cross-encoder__ms-marco-minilm-l-6-v2" - } - } - ``` - -4. Define a `text_similarity_rerank` retriever. - - ```console - POST movies/_search - { - "retriever": { - "text_similarity_reranker": { - "retriever": { - "standard": { - "query": { - "match": { - "genre": "drama" - } - } - } - }, - "field": "plot", - "inference_id": "my-msmarco-minilm-model", - "inference_text": "films that explore psychological depths" - } - } - } - ``` +: The [rule](retrievers/rule-retriever.md) retriever applies contextual [Searching with query rules](/reference/elasticsearch/rest-apis/searching-with-query-rules.md#query-rules) to pin or exclude documents for specific queries. - This retriever uses a standard `match` query to search the `movie` index for films tagged with the genre "drama". 
It then re-ranks the results based on semantic similarity to the text in the `inference_text` parameter, using the model we uploaded to {{es}}. - - - - -## Query Rules Retriever [rule-retriever] - -The `rule` retriever enables fine-grained control over search results by applying contextual [query rules](/reference/elasticsearch/rest-apis/searching-with-query-rules.md#query-rules) to pin or exclude documents for specific queries. This retriever has similar functionality to the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md), but works out of the box with other retrievers. - -### Prerequisites [_prerequisites_16] - -To use the `rule` retriever you must first create one or more query rulesets using the [query rules management APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules). - - -#### Parameters [rule-retriever-parameters] - -`retriever` -: (Required, `retriever`) - - The child retriever that returns the results to apply query rules on top of. This can be a standalone retriever such as the [standard](#standard-retriever) or [knn](#knn-retriever) retriever, or it can be a compound retriever. - - -`ruleset_ids` -: (Required, `array`) - - An array of one or more unique [query ruleset](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules) IDs with query-based rules to match and apply as applicable. Rulesets and their associated rules are evaluated in the order in which they are specified in the query and ruleset. The maximum number of rulesets to specify is 10. - - -`match_criteria` -: (Required, `object`) - - Defines the match criteria to apply to rules in the given query ruleset(s). Match criteria should match the keys defined in the `criteria.metadata` field of the rule. - - -`rank_window_size` -: (Optional, `int`) - - The number of top documents to return from the `rule` retriever. Defaults to `10`. 
- - - -### Example: Rule retriever [rule-retriever-example] - -This example shows the rule retriever executed without any additional retrievers. It runs the query defined by the `retriever` and applies the rules from `my-ruleset` on top of the returned results. - -```console -GET movies/_search -{ - "retriever": { - "rule": { - "match_criteria": { - "query_string": "harry potter" - }, - "ruleset_ids": [ - "my-ruleset" - ], - "retriever": { - "standard": { - "query": { - "query_string": { - "query": "harry potter" - } - } - } - } - } - } -} -``` - - -### Example: Rule retriever combined with RRF [rule-retriever-example-rrf] - -This example shows how to combine the `rule` retriever with other rerank retrievers such as [rrf](#rrf-retriever) or [text_similarity_reranker](#text-similarity-reranker-retriever). - -::::{warning} -The `rule` retriever will apply rules to any documents returned from its defined `retriever` or any of its sub-retrievers. This means that for the best results, the `rule` retriever should be the outermost defined retriever. Nesting a `rule` retriever as a sub-retriever under a reranker such as `rrf` or `text_similarity_reranker` may not produce the expected results. - -:::: - - -```console -GET movies/_search -{ - "retriever": { - "rule": { <1> - "match_criteria": { - "query_string": "harry potter" - }, - "ruleset_ids": [ - "my-ruleset" - ], - "retriever": { - "rrf": { <2> - "retrievers": [ - { - "standard": { - "query": { - "query_string": { - "query": "sorcerer's stone" - } - } - } - }, - { - "standard": { - "query": { - "query_string": { - "query": "chamber of secrets" - } - } - } - } - ] - } - } - } - } -} -``` - -1. The `rule` retriever is the outermost retriever, applying rules to the search results that were previously reranked using the `rrf` retriever. -2. The `rrf` retriever returns results from all of its sub-retrievers, and the output of the `rrf` retriever is used as input to the `rule` retriever. 
+`standard` +: The [standard](retrievers/standard-retriever.md) retriever replaces the functionality of a traditional [query](/reference/query-languages/querydsl.md). -## Pinned Retriever [pinned-retriever] -```yaml {applies_to} -stack: ga 9.1 -``` +`text_similarity_reranker` +: The [text_similarity_reranker](retrievers/text-similarity-reranker-retriever.md) retriever enhances search results by re-ranking documents based on semantic similarity to a specified inference text, using a machine learning model. +## Common usage guidelines [retriever-common-parameters] -A `pinned` retriever returns top documents by always placing specific documents at the top of the results, with the remaining hits provided by a secondary retriever. This retriever offers similar functionality to the [pinned query](/reference/query-languages/query-dsl/query-dsl-pinned-query.md), but works seamlessly with other retrievers. This is useful for promoting certain documents for particular queries, regardless of their relevance score. -#### Parameters [pinned-retriever-parameters] +### Using `from` and `size` with a retriever tree [retriever-size-pagination] -`ids` -: (Optional, array of strings) +The [`from`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-from-param) and [`size`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param) parameters are provided globally as part of the general [search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). They are applied to all retrievers in a retriever tree, unless a specific retriever overrides the `size` parameter using a different parameter such as `rank_window_size`. Though, the final search hits are always limited to `size`. - A list of document IDs to pin at the top of the results, in the order provided. 
-`docs` -: (Optional, array of objects) +### Using aggregations with a retriever tree [retriever-aggregations] - A list of objects specifying documents to pin. Each object must contain at least an `_id` field, and may also specify `_index` if pinning documents across multiple indices. +[Aggregations](/reference/aggregations/index.md) are globally specified as part of a search request. The query used for an aggregation is the combination of all leaf retrievers as `should` clauses in a [boolean query](/reference/query-languages/query-dsl/query-dsl-bool-query.md). -`retriever` -: (Optional, retriever object) - A retriever (for example a `standard` retriever or a specialized retriever such as `rrf` retriever) used to retrieve the remaining documents after the pinned ones. +### Restrictions on search parameters when specifying a retriever [retriever-restrictions] -Either `ids` or `docs` must be specified. +When a retriever is specified as part of a search, the following elements are not allowed at the top-level: -### Example using `docs` [pinned-retriever-example-documents] +* [`query`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#request-body-search-query) +* [`knn`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-api-knn) +* [`search_after`](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after) +* [`terminate_after`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#request-body-search-terminate-after) +* [`sort`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-sort-param) +* [`rescore`](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore) use a [rescorer retriever](retrievers/rescorer-retriever.md) instead -```console -GET /restaurants/_search -{ - "retriever": { - "pinned": { - "docs": [ - { "_id": "doc1", "_index": "my-index" }, - { "_id": "doc2" } - ], - "retriever": { - "standard": { - "query": { 
- "match": { - "title": "elasticsearch" - } - } - } - } - } - } -} -``` ## Multi-field query format [multi-field-query-format] ```yaml {applies_to} stack: ga 9.1 ``` -The `linear` and `rrf` retrievers support a multi-field query format that provides a simplified way to define searches across multiple fields without explicitly specifying inner retrievers. +The [`linear`](retrievers/linear-retriever.md) and [`rrf`](retrievers/rrf-retriever.md) retrievers support a multi-field query format that provides a simplified way to define searches across multiple fields without explicitly specifying inner retrievers. This format automatically generates appropriate inner retrievers based on the field types and query parameters. This is a great way to search an index, knowing little to nothing about its schema, while also handling normalization across lexical and semantic matches. @@ -1201,28 +239,4 @@ Note, however, that wildcard field patterns will only resolve to fields that eit ### Examples - [RRF with the multi-field query format](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-rrf-multi-field-query-format) -- [Linear retriever with the multi-field query format](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-multi-field-query-format) - -## Common usage guidelines [retriever-common-parameters] - - -### Using `from` and `size` with a retriever tree [retriever-size-pagination] - -The [`from`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-from-param) and [`size`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param) parameters are provided globally as part of the general [search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). They are applied to all retrievers in a retriever tree, unless a specific retriever overrides the `size` parameter using a different parameter such as `rank_window_size`. 
Though, the final search hits are always limited to `size`. - - -### Using aggregations with a retriever tree [retriever-aggregations] - -[Aggregations](/reference/aggregations/index.md) are globally specified as part of a search request. The query used for an aggregation is the combination of all leaf retrievers as `should` clauses in a [boolean query](/reference/query-languages/query-dsl/query-dsl-bool-query.md). - - -### Restrictions on search parameters when specifying a retriever [retriever-restrictions] - -When a retriever is specified as part of a search, the following elements are not allowed at the top-level: - -* [`query`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#request-body-search-query) -* [`knn`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-api-knn) -* [`search_after`](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after) -* [`terminate_after`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#request-body-search-terminate-after) -* [`sort`](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-sort-param) -* [`rescore`](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore) use a [rescorer retriever](#rescorer-retriever) instead +- [Linear retriever with the multi-field query format](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-multi-field-query-format) \ No newline at end of file diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/knn-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/knn-retriever.md new file mode 100644 index 0000000000000..817f4d3fab681 --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/knn-retriever.md @@ -0,0 +1,109 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# kNN retriever [knn-retriever] + +A kNN retriever returns top documents from a [k-nearest neighbor search 
(kNN)](docs-content://solutions/search/vector/knn.md). + + +## Parameters [knn-retriever-parameters] + +`field` +: (Required, string) + + The name of the vector field to search against. Must be a [`dense_vector` field with indexing enabled](/reference/elasticsearch/mapping-reference/dense-vector.md#index-vectors-knn-search). + + +`query_vector` +: (Required if `query_vector_builder` is not defined, array of `float`) + + Query vector. Must have the same number of dimensions as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector. + + +`query_vector_builder` +: (Required if `query_vector` is not defined, query vector builder object) + + Defines a [model](docs-content://solutions/search/vector/knn.md#knn-semantic-search) to build a query vector. + + +`k` +: (Required, integer) + + Number of nearest neighbors to return as top hits. This value must be fewer than or equal to `num_candidates`. + + +`num_candidates` +: (Required, integer) + + The number of nearest neighbor candidates to consider per shard. Needs to be greater than `k`, or `size` if `k` is omitted, and cannot exceed 10,000. {{es}} collects `num_candidates` results from each shard, then merges them to find the top `k` results. Increasing `num_candidates` tends to improve the accuracy of the final `k` results. Defaults to `Math.min(1.5 * k, 10_000)`. + + +`filter` +: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Query to filter the documents that can match. The kNN search will return the top `k` documents that also match this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents are allowed to match. + + +`similarity` +: (Optional, float) + + The minimum similarity required for a document to be considered a match. 
The similarity value calculated relates to the raw [`similarity`](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-similarity) used. Not the document score. The matched documents are then scored according to [`similarity`](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-similarity) and the provided `boost` is applied. + + The `similarity` parameter is the direct vector similarity calculation. + + * `l2_norm`: also known as Euclidean, will include documents where the vector is within the `dims` dimensional hypersphere with radius `similarity` with origin at `query_vector`. + * `cosine`, `dot_product`, and `max_inner_product`: Only return vectors where the cosine similarity or dot-product are at least the provided `similarity`. + + Read more here: [knn similarity search](docs-content://solutions/search/vector/knn.md#knn-similarity-search) + + +`rescore_vector` +: (Optional, object) Apply oversampling and rescoring to quantized vectors. + +::::{note} +Rescoring only makes sense for quantized vectors; when [quantization](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-quantization) is not used, the original vectors are used for scoring. Rescore option will be ignored for non-quantized `dense_vector` fields. +:::: + + +`oversample` +: (Required, float) + + Applies the specified oversample factor to `k` on the approximate kNN search. The approximate kNN search will: + + * Retrieve `num_candidates` candidates per shard. + * From these candidates, the top `k * oversample` candidates per shard will be rescored using the original vectors. + * The top `k` rescored candidates will be returned. + + +See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details. + + +## Restrictions [_restrictions_2] + +The parameters `query_vector` and `query_vector_builder` cannot be used together. 
+ + +## Example [knn-retriever-example] + +```console +GET /restaurants/_search +{ + "retriever": { + "knn": { <1> + "field": "vector", <2> + "query_vector": [10, 22, 77], <3> + "k": 10, <4> + "num_candidates": 10 <5> + } + } +} +``` + +1. Configuration for k-nearest neighbor (knn) search, which is based on vector similarity. +2. Specifies the field name that contains the vectors. +3. The query vector against which document vectors are compared in the `knn` search. +4. The number of nearest neighbors to return as top hits. This value must be fewer than or equal to `num_candidates`. +5. The size of the initial candidate set from which the final `k` nearest neighbors are selected. diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md new file mode 100644 index 0000000000000..119438a60057a --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md @@ -0,0 +1,93 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# Linear retriever [linear-retriever] + +A retriever that normalizes and linearly combines the scores of other retrievers. + + +## Parameters [linear-retriever-parameters] + +::::{note} +Either `query` or `retrievers` must be specified. +Combining `query` and `retrievers` is not supported. +:::: + +`query` {applies_to}`stack: ga 9.1` +: (Optional, String) + + The query to use when using the [multi-field query format](../retrievers.md#multi-field-query-format). + +`fields` {applies_to}`stack: ga 9.1` +: (Optional, array of strings) + + The fields to query when using the [multi-field query format](../retrievers.md#multi-field-query-format). + Fields can include boost values using the `^` notation (e.g., `"field^2"`). + If not specified, uses the index's default fields from the `index.query.default_field` index setting, which is `*` by default. 
+ +`normalizer` {applies_to}`stack: ga 9.1` +: (Optional, String) + + The normalizer to use when using the [multi-field query format](../retrievers.md#multi-field-query-format). + See [normalizers](#linear-retriever-normalizers) for supported values. + Required when `query` is specified. + + ::::{warning} + Avoid using `none` as that will disable normalization and may bias the result set towards lexical matches. + See [field grouping](../retrievers.md#multi-field-field-grouping) for more information. + :::: + +`retrievers` +: (Optional, array of objects) + + A list of the sub-retrievers' configuration, that we will take into account and whose result sets we will merge through a weighted sum. + Each configuration can have a different weight and normalization depending on the specified retriever. + +`rank_window_size` +: (Optional, integer) + + This value determines the size of the individual result sets per query. A higher value will improve result relevance at the cost of performance. + The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). + `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. + Defaults to 10. + +`filter` +: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. + +Each entry in the `retrievers` array specifies the following parameters: + +`retriever` +: (Required, a `retriever` object) + + Specifies the retriever for which we will compute the top documents. The retriever will produce `rank_window_size` results, which will later be merged based on the specified `weight` and `normalizer`. 
+ +`weight` +: (Optional, float) + + The weight that each score of this retriever’s top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. + +`normalizer` +: (Optional, String) + + Specifies how the retriever’s score will be normalized before applying the specified `weight`. + See [normalizers](#linear-retriever-normalizers) for supported values. + Defaults to `none`. + +See also [this hybrid search example](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-retriever) using a linear retriever on how to independently configure and apply normalizers to retrievers. + +## Normalizers [linear-retriever-normalizers] + +The `linear` retriever supports the following normalizers: + +* `none`: No normalization +* `minmax`: Normalizes scores based on the following formula: + + ``` + score = (score - min) / (max - min) + ``` +* `l2_norm`: Normalizes scores using the L2 norm of the score values diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/pinned-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/pinned-retriever.md new file mode 100644 index 0000000000000..572fae25b333f --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/pinned-retriever.md @@ -0,0 +1,55 @@ +--- +applies_to: + stack: ga 9.1 + serverless: ga +--- + +# Pinned retriever [pinned-retriever] + +A `pinned` retriever returns top documents by always placing specific documents at the top of the results, with the remaining hits provided by a secondary retriever. + +This retriever offers similar functionality to the [pinned query](/reference/query-languages/query-dsl/query-dsl-pinned-query.md), but works seamlessly with other retrievers. This is useful for promoting certain documents for particular queries, regardless of their relevance score. + +## Parameters [pinned-retriever-parameters] + +`ids` +: (Optional, array of strings) + + A list of document IDs to pin at the top of the results, in the order provided. 
+ +`docs` +: (Optional, array of objects) + + A list of objects specifying documents to pin. Each object must contain at least an `_id` field, and may also specify `_index` if pinning documents across multiple indices. + +`retriever` +: (Optional, retriever object) + + A retriever (for example a `standard` retriever or a specialized retriever such as `rrf` retriever) used to retrieve the remaining documents after the pinned ones. + +Either `ids` or `docs` must be specified. + +## Example using `docs` [pinned-retriever-example-documents] + +```console +GET /restaurants/_search +{ + "retriever": { + "pinned": { + "docs": [ + { "_id": "doc1", "_index": "my-index" }, + { "_id": "doc2" } + ], + "retriever": { + "standard": { + "query": { + "match": { + "title": "elasticsearch" + } + } + } + } + } + } +} +``` diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/rescorer-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/rescorer-retriever.md new file mode 100644 index 0000000000000..b9f12e6dce795 --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/rescorer-retriever.md @@ -0,0 +1,123 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# Rescorer retriever [rescorer-retriever] + +The `rescorer` retriever re-scores only the results produced by its child retriever. For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard. + +For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally. + +When using the `rescorer`, an error is returned if the following conditions are not met: + +* The minimum configured rescore’s `window_size` is: + + * Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups. + * Greater than or equal to the `size` of the search request when used as the primary retriever in the tree. 
+ +* And the maximum rescore’s `window_size` is: + + * Smaller than or equal to the `size` or `rank_window_size` of the child retriever. + +## Parameters [rescorer-retriever-parameters] + +`rescore` +: (Required. [A rescorer definition or an array of rescorer definitions](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore)) + + Defines the [rescorers](/reference/elasticsearch/rest-apis/filter-search-results.md#rescore) applied sequentially to the top documents returned by the child retriever. + + +`retriever` +: (Required. `retriever`) + + Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked. + + +`filter` +: (Optional. [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Applies a [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to the retriever, ensuring that all documents match the filter criteria without affecting their scores. + + + +## Example [rescorer-retriever-example] + +The `rescorer` retriever can be placed at any level within the retriever tree. 
The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever: + +```console +GET movies/_search +{ + "size": 10, <1> + "retriever": { + "rescorer": { <2> + "rescore": { + "window_size": 50, <3> + "query": { <4> + "rescore_query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0", + "params": { + "queryVector": [-0.5, 90.0, -10, 14.8, -156.0] + } + } + } + } + } + }, + "retriever": { <5> + "rrf": { + "rank_window_size": 100, <6> + "retrievers": [ + { + "standard": { + "query": { + "sparse_vector": { + "field": "plot_embedding", + "inference_id": "my-elser-model", + "query": "films that explore psychological depths" + } + } + } + }, + { + "standard": { + "query": { + "multi_match": { + "query": "crime", + "fields": [ + "plot", + "title" + ] + } + } + } + }, + { + "knn": { + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + } + ] + } + } + } + } +} +``` + +1. Specifies the number of top documents to return in the final response. +2. A `rescorer` retriever applied as the final step. +3. Defines the number of documents to rescore from the child retriever. +4. The definition of the `query` rescorer. +5. Specifies the child retriever definition. +6. 
Defines the number of documents returned by the `rrf` retriever, which limits the available documents to rescore. + diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md new file mode 100644 index 0000000000000..622f8881cf84f --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md @@ -0,0 +1,149 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# RRF retriever [rrf-retriever] + +An [RRF](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md) retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers. +Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set. + + +## Parameters [rrf-retriever-parameters] + +::::{note} +Either `query` or `retrievers` must be specified. +Combining `query` and `retrievers` is not supported. +:::: + +`query` {applies_to}`stack: ga 9.1` +: (Optional, String) + + The query to use when using the [multi-field query format](../retrievers.md#multi-field-query-format). + +`fields` {applies_to}`stack: ga 9.1` +: (Optional, array of strings) + + The fields to query when using the [multi-field query format](../retrievers.md#multi-field-query-format). + If not specified, uses the index's default fields from the `index.query.default_field` index setting, which is `*` by default. + +`retrievers` +: (Optional, array of retriever objects) + + A list of child retrievers to specify which sets of returned top documents will have the RRF formula applied to them. + Each child retriever carries an equal weight as part of the RRF formula. Two or more child retrievers are required. + +`rank_constant` +: (Optional, integer) + + This value determines how much influence documents in individual result sets per query have over the final ranked result set. 
A higher value indicates that lower ranked documents have more influence. This value must be greater than or equal to `1`. Defaults to `60`. + +`rank_window_size` +: (Optional, integer) + + This value determines the size of the individual result sets per query. + A higher value will improve result relevance at the cost of performance. + The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). + `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. + Defaults to 10. + +`filter` +: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. + +## Example: Hybrid search [rrf-retriever-example-hybrid] + +A simple hybrid search example (lexical search + dense vector search) combining a `standard` retriever with a `knn` retriever using RRF: + +```console +GET /restaurants/_search +{ + "retriever": { + "rrf": { <1> + "retrievers": [ <2> + { + "standard": { <3> + "query": { + "multi_match": { + "query": "Austria", + "fields": [ + "city", + "region" + ] + } + } + } + }, + { + "knn": { <4> + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + } + ], + "rank_constant": 1, <5> + "rank_window_size": 50 <6> + } + } +} +``` + +1. Defines a retriever tree with an RRF retriever. +2. The sub-retriever array. +3. The first sub-retriever is a `standard` retriever. +4. The second sub-retriever is a `knn` retriever. +5. The rank constant for the RRF retriever. +6. The rank window size for the RRF retriever. 
+ +## Example: Hybrid search with sparse vectors [rrf-retriever-example-hybrid-sparse] + +A more complex hybrid search example (lexical search + ELSER sparse vector search + dense vector search) using RRF: + +```console +GET movies/_search +{ + "retriever": { + "rrf": { + "retrievers": [ + { + "standard": { + "query": { + "sparse_vector": { + "field": "plot_embedding", + "inference_id": "my-elser-model", + "query": "films that explore psychological depths" + } + } + } + }, + { + "standard": { + "query": { + "multi_match": { + "query": "crime", + "fields": [ + "plot", + "title" + ] + } + } + } + }, + { + "knn": { + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + } + ] + } + } +} +``` + diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md new file mode 100644 index 0000000000000..e9bd4ac78d2cc --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md @@ -0,0 +1,121 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# Query rules retriever [rule-retriever] + +The `rule` retriever enables fine-grained control over search results by applying contextual [query rules](/reference/elasticsearch/rest-apis/searching-with-query-rules.md#query-rules) to pin or exclude documents for specific queries. This retriever has similar functionality to the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md), but works out of the box with other retrievers. + +## Prerequisites [_prerequisites_16] + +To use the `rule` retriever you must first create one or more query rulesets using the [query rules management APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules). + +## Parameters [rule-retriever-parameters] + +`retriever` +: (Required, `retriever`) + + The child retriever that returns the results to apply query rules on top of. 
This can be a standalone retriever such as the [standard](standard-retriever.md) or [knn](knn-retriever.md) retriever, or it can be a compound retriever. + + +`ruleset_ids` +: (Required, `array`) + + An array of one or more unique [query ruleset](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules) IDs with query-based rules to match and apply as applicable. Rulesets and their associated rules are evaluated in the order in which they are specified in the query and ruleset. The maximum number of rulesets to specify is 10. + + +`match_criteria` +: (Required, `object`) + + Defines the match criteria to apply to rules in the given query ruleset(s). Match criteria should match the keys defined in the `criteria.metadata` field of the rule. + + +`rank_window_size` +: (Optional, `int`) + + The number of top documents to return from the `rule` retriever. Defaults to `10`. + +## Example: Rule retriever [rule-retriever-example] + +This example shows the rule retriever executed without any additional retrievers. It runs the query defined by the `retriever` and applies the rules from `my-ruleset` on top of the returned results. + +```console +GET movies/_search +{ + "retriever": { + "rule": { + "match_criteria": { + "query_string": "harry potter" + }, + "ruleset_ids": [ + "my-ruleset" + ], + "retriever": { + "standard": { + "query": { + "query_string": { + "query": "harry potter" + } + } + } + } + } + } +} +``` + +## Example: Rule retriever combined with RRF [rule-retriever-example-rrf] + +This example shows how to combine the `rule` retriever with other rerank retrievers such as [rrf](rrf-retriever.md) or [text_similarity_reranker](text-similarity-reranker-retriever.md). + +::::{warning} +The `rule` retriever will apply rules to any documents returned from its defined `retriever` or any of its sub-retrievers. This means that for the best results, the `rule` retriever should be the outermost defined retriever. 
Nesting a `rule` retriever as a sub-retriever under a reranker such as `rrf` or `text_similarity_reranker` may not produce the expected results. + +:::: + + +```console +GET movies/_search +{ + "retriever": { + "rule": { <1> + "match_criteria": { + "query_string": "harry potter" + }, + "ruleset_ids": [ + "my-ruleset" + ], + "retriever": { + "rrf": { <2> + "retrievers": [ + { + "standard": { + "query": { + "query_string": { + "query": "sorcerer's stone" + } + } + } + }, + { + "standard": { + "query": { + "query_string": { + "query": "chamber of secrets" + } + } + } + } + ] + } + } + } + } +} +``` + +1. The `rule` retriever is the outermost retriever, applying rules to the search results that were previously reranked using the `rrf` retriever. +2. The `rrf` retriever returns results from all of its sub-retrievers, and the output of the `rrf` retriever is used as input to the `rule` retriever. + diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/standard-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/standard-retriever.md new file mode 100644 index 0000000000000..e4c6e4b7554da --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/standard-retriever.md @@ -0,0 +1,98 @@ +--- +applies_to: + stack: all + serverless: ga +--- +# Standard retriever [standard-retriever] + +A standard retriever returns top documents from a traditional [query](/reference/query-languages/querydsl.md). + + +### Parameters: [standard-retriever-parameters] + +`query` +: (Optional, [query object](/reference/query-languages/querydsl.md)) + + Defines a query to retrieve a set of top documents. + + +`filter` +: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Applies a [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to this retriever, where all documents must match this query but do not contribute to the score. 
+ + +`search_after` +: (Optional, [search after object](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after)) + + Defines a search after object parameter used for pagination. + + +`terminate_after` +: (Optional, integer) Maximum number of documents to collect for each shard. If a query reaches this limit, {{es}} terminates the query early. {{es}} collects documents before sorting. + + ::::{important} + Use with caution. {{es}} applies this parameter to each shard handling the request. When possible, let {{es}} perform early termination automatically. Avoid specifying this parameter for requests that target data streams with backing indices across multiple data tiers. + :::: + + +`sort` +: (Optional, [sort object](/reference/elasticsearch/rest-apis/sort-search-results.md)) A sort object that specifies the order of matching documents. + + +`min_score` +: (Optional, `float`) + + Minimum [`_score`](/reference/query-languages/query-dsl/query-filter-context.md#relevance-scores) for matching documents. Documents with a lower `_score` are not included in the top documents. + + +`collapse` +: (Optional, [collapse object](/reference/elasticsearch/rest-apis/collapse-search-results.md)) + + Collapses the top documents by a specified key into a single top document per key. + + +## Restrictions [_restrictions] + +When a retriever tree contains a compound retriever (a retriever with two or more child retrievers) the [search after](/reference/elasticsearch/rest-apis/paginate-search-results.md#search-after) parameter is not supported. + + +## Example [standard-retriever-example] + +```console +GET /restaurants/_search +{ + "retriever": { <1> + "standard": { <2> + "query": { <3> + "bool": { <4> + "should": [ <5> + { + "match": { <6> + "region": "Austria" + } + } + ], + "filter": [ <7> + { + "term": { <8> + "year": "2019" <9> + } + } + ] + } + } + } + } +} +``` + +1. Opens the `retriever` object. +2. 
The `standard` retriever is used for defining traditional {{es}} queries. +3. The entry point for defining the search query. +4. The `bool` object allows for combining multiple query clauses logically. +5. The `should` array indicates conditions under which a document will match. Documents matching these conditions will have increased relevancy scores. +6. The `match` object finds documents where the `region` field contains the word "Austria." +7. The `filter` array provides filtering conditions that must be met but do not contribute to the relevancy score. +8. The `term` object is used for exact matches, in this case, filtering documents by the `year` field. +9. The exact value to match in the `year` field. diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/text-similarity-reranker-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/text-similarity-reranker-retriever.md new file mode 100644 index 0000000000000..9abb236a45d1e --- /dev/null +++ b/docs/reference/elasticsearch/rest-apis/retrievers/text-similarity-reranker-retriever.md @@ -0,0 +1,248 @@ +--- +applies_to: + stack: all + serverless: ga +--- + +# Text similarity re-ranker retriever [text-similarity-reranker-retriever] + +The `text_similarity_reranker` retriever uses an NLP model to improve search results by reordering the top-k documents based on their semantic similarity to the query. + +::::{tip} +Refer to [*Semantic re-ranking*](docs-content://solutions/search/ranking/semantic-reranking.md) for a high level overview of semantic re-ranking. +:::: + +## Prerequisites [_prerequisites_15] + +To use `text_similarity_reranker`, you can rely on the preconfigured `.rerank-v1-elasticsearch` inference endpoint, which uses the [Elastic Rerank model](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) and serves as the default if no `inference_id` is provided. This model is optimized for reranking based on text similarity. 
If you'd like to use a different model, you can set up a custom inference endpoint for the `rerank` task using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). The endpoint should be configured with a machine learning model capable of computing text similarity. Refer to [the Elastic NLP model reference](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-model-ref.md#ml-nlp-model-ref-text-similarity) for a list of third-party text similarity models supported by {{es}}. +
+You have the following options: +
+* Use the built-in [Elastic Rerank](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) cross-encoder model via the inference API’s {{es}} service. See [this example](https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-service-elasticsearch.html#inference-example-elastic-reranker) for creating an endpoint using the Elastic Rerank model. +* Use the [Cohere Rerank inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. +* Use the [Google Vertex AI inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. +* Upload a model to {{es}} with [Eland](eland://reference/machine-learning.md#ml-nlp-pytorch) using the `text_similarity` NLP task type. +
+ * Then set up an [{{es}} service inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. + * Refer to the [example](#text-similarity-reranker-retriever-example-eland) on this page for a step-by-step guide. +
+
+::::{important} +Scores from the re-ranking process are normalized using the following formula before being returned to the user, to avoid having negative scores. 
+ +```text +score = max(score, 0) + min(exp(score), 1) +``` + +Using the above, any initially negative scores are projected to (0, 1) and positive scores to [1, infinity). To revert back if needed, one can use: + +```text +score = score - 1, if score >= 0 +score = ln(score), if score < 0 +``` + +:::: + +## Parameters [text-similarity-reranker-retriever-parameters] + +`retriever` +: (Required, `retriever`) + + The child retriever that generates the initial set of top documents to be re-ranked. + + +`field` +: (Required, `string`) + + The document field to be used for text similarity comparisons. This field should contain the text that will be evaluated against the `inference_text`. + + +`inference_id` +: (Optional, `string`) + + Unique identifier of the inference endpoint created using the {{infer}} API. If you don’t specify an inference endpoint, the `inference_id` field defaults to `.rerank-v1-elasticsearch`, a preconfigured endpoint for the elasticsearch `.rerank-v1` model. + + +`inference_text` +: (Required, `string`) + + The text snippet used as the basis for similarity comparison. + + +`rank_window_size` +: (Optional, `int`) + + The number of top documents to consider in the re-ranking process. Defaults to `10`. + + +`min_score` +: (Optional, `float`) + + Sets a minimum threshold score for including documents in the re-ranked results. Documents with similarity scores below this threshold will be excluded. Note that score calculations vary depending on the model used. + + +`filter` +: (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) + + Applies the specified [boolean query filter](/reference/query-languages/query-dsl/query-dsl-bool-query.md) to the child `retriever`. If the child retriever already specifies any filters, then this top-level filter is applied in conjunction with the filter defined in the child retriever. 
+ + + +## Example: Elastic Rerank [text-similarity-reranker-retriever-example-elastic-rerank] + +::::{tip} +Refer to this [Python notebook](https://github.com/elastic/elasticsearch-labs/blob/main/notebooks/search/12-semantic-reranking-elastic-rerank.ipynb) for an end-to-end example using Elastic Rerank. + +:::: + + +This example demonstrates how to deploy the [Elastic Rerank](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-rerank.md) model and use it to re-rank search results using the `text_similarity_reranker` retriever. + +Follow these steps: + +1. Create an inference endpoint for the `rerank` task using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). + + ```console + PUT _inference/rerank/my-elastic-rerank + { + "service": "elasticsearch", + "service_settings": { + "model_id": ".rerank-v1", + "num_threads": 1, + "adaptive_allocations": { <1> + "enabled": true, + "min_number_of_allocations": 1, + "max_number_of_allocations": 10 + } + } + } + ``` + + 1. [Adaptive allocations](docs-content://deploy-manage/autoscaling/trained-model-autoscaling.md#enabling-autoscaling-through-apis-adaptive-allocations) will be enabled with the minimum of 1 and the maximum of 10 allocations. + +2. Define a `text_similarity_rerank` retriever: + + ```console + POST _search + { + "retriever": { + "text_similarity_reranker": { + "retriever": { + "standard": { + "query": { + "match": { + "text": "How often does the moon hide the sun?" + } + } + } + }, + "field": "text", + "inference_id": "my-elastic-rerank", + "inference_text": "How often does the moon hide the sun?", + "rank_window_size": 100, + "min_score": 0.5 + } + } + } + ``` + + + +## Example: Cohere Rerank [text-similarity-reranker-retriever-example-cohere] + +This example enables out-of-the-box semantic search by re-ranking top documents using the Cohere Rerank API. 
This approach eliminates the need to generate and store embeddings for all indexed documents. This requires a [Cohere Rerank inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) that is set up for the `rerank` task type. + +```console +GET /index/_search +{ + "retriever": { + "text_similarity_reranker": { + "retriever": { + "standard": { + "query": { + "match_phrase": { + "text": "landmark in Paris" + } + } + } + }, + "field": "text", + "inference_id": "my-cohere-rerank-model", + "inference_text": "Most famous landmark in Paris", + "rank_window_size": 100, + "min_score": 0.5 + } + } +} +``` + + +## Example: Semantic re-ranking with a Hugging Face model [text-similarity-reranker-retriever-example-eland] + +The following example uses the `cross-encoder/ms-marco-MiniLM-L-6-v2` model from Hugging Face to rerank search results based on semantic similarity. The model must be uploaded to {{es}} using [Eland](eland://reference/machine-learning.md#ml-nlp-pytorch). + +::::{tip} +Refer to [the Elastic NLP model reference](docs-content://explore-analyze/machine-learning/nlp/ml-nlp-model-ref.md#ml-nlp-model-ref-text-similarity) for a list of third party text similarity models supported by {{es}}. + +:::: + + +Follow these steps to load the model and create a semantic re-ranker. + +1. Install Eland using `pip` + + ```sh + python -m pip install eland[pytorch] + ``` + +2. Upload the model to {{es}} using Eland. This example assumes you have an Elastic Cloud deployment and an API key. Refer to the [Eland documentation](eland://reference/machine-learning.md#ml-nlp-pytorch-auth) for more authentication options. + + ```sh + eland_import_hub_model \ + --cloud-id $CLOUD_ID \ + --es-api-key $ES_API_KEY \ + --hub-model-id cross-encoder/ms-marco-MiniLM-L-6-v2 \ + --task-type text_similarity \ + --clear-previous \ + --start + ``` + +3. 
Create an inference endpoint for the `rerank` task + + ```console + PUT _inference/rerank/my-msmarco-minilm-model + { + "service": "elasticsearch", + "service_settings": { + "num_allocations": 1, + "num_threads": 1, + "model_id": "cross-encoder__ms-marco-minilm-l-6-v2" + } + } + ``` + +4. Define a `text_similarity_rerank` retriever. + + ```console + POST movies/_search + { + "retriever": { + "text_similarity_reranker": { + "retriever": { + "standard": { + "query": { + "match": { + "genre": "drama" + } + } + } + }, + "field": "plot", + "inference_id": "my-msmarco-minilm-model", + "inference_text": "films that explore psychological depths" + } + } + } + ``` + + This retriever uses a standard `match` query to search the `movie` index for films tagged with the genre "drama". It then re-ranks the results based on semantic similarity to the text in the `inference_text` parameter, using the model we uploaded to {{es}}. diff --git a/docs/reference/elasticsearch/rest-apis/searching-with-query-rules.md b/docs/reference/elasticsearch/rest-apis/searching-with-query-rules.md index 37a97495d5c97..e7bdc9d1a5439 100644 --- a/docs/reference/elasticsearch/rest-apis/searching-with-query-rules.md +++ b/docs/reference/elasticsearch/rest-apis/searching-with-query-rules.md @@ -18,13 +18,13 @@ $$$query-rules$$$ * A referring site * etc. -Query rules define a metadata key that will be used to match the metadata provided in the [rule retriever](/reference/elasticsearch/rest-apis/retrievers.md#rule-retriever) with the criteria specified in the rule. +Query rules define a metadata key that will be used to match the metadata provided in the [rule retriever](/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md) with the criteria specified in the rule. When a query rule matches the rule metadata according to its defined criteria, the query rule action is applied to the underlying `organic` query. 
For example, a query rule could be defined to match a user-entered query string of `pugs` and a country `us` and promote adoptable shelter dogs if the rule query met both criteria. -Rules are defined using the [query rules API](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules) and searched using the [rule retriever](/reference/elasticsearch/rest-apis/retrievers.md#rule-retriever) or the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md). +Rules are defined using the [query rules API](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-query_rules) and searched using the [rule retriever](/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md) or the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md). ## Rule definition [query-rule-definition] @@ -148,7 +148,7 @@ You can use the [Get query ruleset](https://www.elastic.co/docs/api/doc/elastics ## Search using query rules [rule-query-search] -Once you have defined one or more query rulesets, you can search using these rulesets using the [rule retriever](/reference/elasticsearch/rest-apis/retrievers.md#rule-retriever) or the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md). Retrievers are the recommended way to use rule queries, as they will work out of the box with other reranking retrievers such as [Reciprocal rank fusion](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md). +Once you have defined one or more query rulesets, you can search using these rulesets using the [rule retriever](/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md) or the [rule query](/reference/query-languages/query-dsl/query-dsl-rule-query.md). Retrievers are the recommended way to use rule queries, as they will work out of the box with other reranking retrievers such as [Reciprocal rank fusion](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md). 
Rulesets are evaluated in order, so rules in the first ruleset you specify will be applied before any subsequent rulesets. @@ -186,7 +186,7 @@ It’s possible to have multiple rules in a ruleset match a single [rule query]( * If multiple documents are specified in a single rule, in the order they are specified * If a document is matched by both a `pinned` rule and an `exclude` rule, the `exclude` rule will take precedence -You can specify reranking retrievers such as [rrf](/reference/elasticsearch/rest-apis/retrievers.md#rrf-retriever) or [text_similarity_reranker](/reference/elasticsearch/rest-apis/retrievers.md#text-similarity-reranker-retriever) in the rule query to apply query rules on already-reranked results. Here is an example: +You can specify reranking retrievers such as [rrf](/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md) or [text_similarity_reranker](/reference/elasticsearch/rest-apis/retrievers/text-similarity-reranker-retriever.md) in the rule query to apply query rules on already-reranked results. 
Here is an example: ```console GET my-index-000001/_search diff --git a/docs/reference/elasticsearch/toc.yml b/docs/reference/elasticsearch/toc.yml index 087a2fb26e143..4e1d5bc52a68d 100644 --- a/docs/reference/elasticsearch/toc.yml +++ b/docs/reference/elasticsearch/toc.yml @@ -97,6 +97,15 @@ toc: - file: rest-apis/retrieve-selected-fields.md - file: rest-apis/retrieve-stored-fields.md - file: rest-apis/retrievers.md + children: + - file: rest-apis/retrievers/knn-retriever.md + - file: rest-apis/retrievers/linear-retriever.md + - file: rest-apis/retrievers/pinned-retriever.md + - file: rest-apis/retrievers/rescorer-retriever.md + - file: rest-apis/retrievers/rrf-retriever.md + - file: rest-apis/retrievers/rule-retriever.md + - file: rest-apis/retrievers/standard-retriever.md + - file: rest-apis/retrievers/text-similarity-reranker-retriever.md - file: rest-apis/search-multiple-data-streams-indices.md - file: rest-apis/search-profile.md - file: rest-apis/search-rank-eval.md diff --git a/docs/reference/query-languages/esql.md b/docs/reference/query-languages/esql.md index 034794af7d8e9..d66dceb1d36ff 100644 --- a/docs/reference/query-languages/esql.md +++ b/docs/reference/query-languages/esql.md @@ -20,4 +20,5 @@ This reference section provides detailed technical information about {{esql}} fe * [Advanced workflows](esql/esql-advanced.md): Learn how to handle more complex tasks with these guides, including how to extract, transform, and combine data from multiple indices * [Types and fields](esql/esql-types-and-fields.md): Learn about how {{esql}} handles different data types and special fields * [Limitations](esql/limitations.md): Learn about the current limitations of {{esql}} -* [Examples](esql/esql-examples.md): Explore some example queries \ No newline at end of file +* [Examples](esql/esql-examples.md): Explore some example queries +* [Troubleshooting](esql/esql-troubleshooting.md): Learn how to diagnose and resolve issues with {{esql}} diff --git 
a/docs/reference/query-languages/esql/_snippets/functions/description/score.md b/docs/reference/query-languages/esql/_snippets/functions/description/score.md new file mode 100644 index 0000000000000..82426283b03a0 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/score.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Scores an expression. Only full text functions will be scored. Returns scores for all the resulting docs. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/score.md b/docs/reference/query-languages/esql/_snippets/functions/examples/score.md new file mode 100644 index 0000000000000..86691e4e941a8 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/score.md @@ -0,0 +1,11 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +FROM books METADATA _score +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL first_score = score(match(title, "Return")) +``` + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md new file mode 100644 index 0000000000000..acd2064002b44 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md @@ -0,0 +1,13 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported function named parameters** + +`output_format` +: (boolean) The output format of the categories. Defaults to regex. + +`similarity_threshold` +: (boolean) The minimum percentage of token weight that must match for text to be added to the category bucket. Must be between 1 and 100. 
The larger the value the narrower the categories. Larger values will increase memory usage and create narrower categories. Defaults to 70. + +`analyzer` +: (keyword) Analyzer used to convert the field into tokens for text categorization. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md index ca23c1e2efc23..2e331187665f4 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md @@ -19,5 +19,8 @@ :::{include} ../types/categorize.md ::: +:::{include} ../functionNamedParams/categorize.md +::: + :::{include} ../examples/categorize.md ::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/match_phrase.md b/docs/reference/query-languages/esql/_snippets/functions/layout/match_phrase.md index 6eb9e17bf35f9..f658fbf1bbde2 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/layout/match_phrase.md +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/match_phrase.md @@ -2,7 +2,7 @@ ## `MATCH_PHRASE` [esql-match_phrase] ```{applies_to} -stack: unavailable 9.0, ga 9.1.0 +stack: ga 9.1.0 ``` **Syntax** diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/score.md b/docs/reference/query-languages/esql/_snippets/functions/layout/score.md new file mode 100644 index 0000000000000..b2fa5e09baeac --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/score.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +## `SCORE` [esql-score] +```{applies_to} +stack: development +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/score.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/score.md +::: + +:::{include} ../description/score.md +::: + +:::{include} ../types/score.md +::: + +:::{include} ../examples/score.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/bucket.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/bucket.md index cadd93c20be11..9bda78c46e8fe 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/bucket.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/bucket.md @@ -14,3 +14,6 @@ `to` : End of the range. Can be a number, a date or a date expressed as a string. +`emitEmptyBuckets` +: Whether or not empty buckets should be emitted. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md index 8733908754570..c013b67375a3d 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md @@ -5,3 +5,6 @@ `field` : Expression to categorize +`options` +: (Optional) Categorize additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params). + diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md new file mode 100644 index 0000000000000..59fdd3c54e1dd --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md @@ -0,0 +1,7 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Parameters** + +`query` +: Boolean expression that contains full text function(s) to be scored. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/bucket.md b/docs/reference/query-languages/esql/_snippets/functions/types/bucket.md index 658d11d6f1130..578f527efb4ce 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/types/bucket.md +++ b/docs/reference/query-languages/esql/_snippets/functions/types/bucket.md @@ -2,64 +2,64 @@ **Supported types** -| field | buckets | from | to | result | -| --- | --- | --- | --- | --- | -| date | date_period | | | date | -| date | integer | date | date | date | -| date | integer | date | keyword | date | -| date | integer | date | text | date | -| date | integer | keyword | date | date | -| date | integer | keyword | keyword | date | -| date | integer | keyword | text | date | -| date | integer | text | date | date | -| date | integer | text | keyword | date | -| date | integer | text | text | date | -| date | time_duration | | | date | -| date_nanos | date_period | | | date_nanos | -| date_nanos | integer | date | date | date_nanos | -| date_nanos | integer | date | keyword | date_nanos | -| date_nanos | integer | date | text | date_nanos | -| date_nanos | integer | keyword | date | date_nanos | -| date_nanos | integer | keyword | keyword | date_nanos | -| date_nanos | integer | keyword | text | date_nanos | -| date_nanos | integer | text | date | date_nanos | -| date_nanos | integer | text | keyword | date_nanos | -| date_nanos | integer | text | text | date_nanos | -| date_nanos | time_duration | | | date_nanos | -| double | double | | | double | -| double | integer | double | double | double | -| double | integer | double | integer | double | -| double | integer | double | long | double | -| double | integer | integer | double | double | -| double | integer | integer | integer | double | -| double | integer | integer | long | double | -| double | integer | long | double | double | 
-| double | integer | long | integer | double | -| double | integer | long | long | double | -| double | integer | | | double | -| double | long | | | double | -| integer | double | | | double | -| integer | integer | double | double | double | -| integer | integer | double | integer | double | -| integer | integer | double | long | double | -| integer | integer | integer | double | double | -| integer | integer | integer | integer | double | -| integer | integer | integer | long | double | -| integer | integer | long | double | double | -| integer | integer | long | integer | double | -| integer | integer | long | long | double | -| integer | integer | | | double | -| integer | long | | | double | -| long | double | | | double | -| long | integer | double | double | double | -| long | integer | double | integer | double | -| long | integer | double | long | double | -| long | integer | integer | double | double | -| long | integer | integer | integer | double | -| long | integer | integer | long | double | -| long | integer | long | double | double | -| long | integer | long | integer | double | -| long | integer | long | long | double | -| long | integer | | | double | -| long | long | | | double | +| field | buckets | from | to | emitEmptyBuckets | result | +| --- | --- | --- | --- | --- | --- | +| date | date_period | | | | date | +| date | integer | date | date | | date | +| date | integer | date | keyword | | date | +| date | integer | date | text | | date | +| date | integer | keyword | date | | date | +| date | integer | keyword | keyword | | date | +| date | integer | keyword | text | | date | +| date | integer | text | date | | date | +| date | integer | text | keyword | | date | +| date | integer | text | text | | date | +| date | time_duration | | | | date | +| date_nanos | date_period | | | | date_nanos | +| date_nanos | integer | date | date | | date_nanos | +| date_nanos | integer | date | keyword | | date_nanos | +| date_nanos | integer | date | 
text | | date_nanos | +| date_nanos | integer | keyword | date | | date_nanos | +| date_nanos | integer | keyword | keyword | | date_nanos | +| date_nanos | integer | keyword | text | | date_nanos | +| date_nanos | integer | text | date | | date_nanos | +| date_nanos | integer | text | keyword | | date_nanos | +| date_nanos | integer | text | text | | date_nanos | +| date_nanos | time_duration | | | | date_nanos | +| double | double | | | | double | +| double | integer | double | double | | double | +| double | integer | double | integer | | double | +| double | integer | double | long | | double | +| double | integer | integer | double | | double | +| double | integer | integer | integer | | double | +| double | integer | integer | long | | double | +| double | integer | long | double | | double | +| double | integer | long | integer | | double | +| double | integer | long | long | | double | +| double | integer | | | | double | +| double | long | | | | double | +| integer | double | | | | double | +| integer | integer | double | double | | double | +| integer | integer | double | integer | | double | +| integer | integer | double | long | | double | +| integer | integer | integer | double | | double | +| integer | integer | integer | integer | | double | +| integer | integer | integer | long | | double | +| integer | integer | long | double | | double | +| integer | integer | long | integer | | double | +| integer | integer | long | long | | double | +| integer | integer | | | | double | +| integer | long | | | | double | +| long | double | | | | double | +| long | integer | double | double | | double | +| long | integer | double | integer | | double | +| long | integer | double | long | | double | +| long | integer | integer | double | | double | +| long | integer | integer | integer | | double | +| long | integer | integer | long | | double | +| long | integer | long | double | | double | +| long | integer | long | integer | | double | +| long | integer | long 
| long | | double | +| long | integer | | | | double | +| long | long | | | | double | diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md index 6043fbe719ff8..8ebe22b61286c 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md +++ b/docs/reference/query-languages/esql/_snippets/functions/types/categorize.md @@ -2,8 +2,8 @@ **Supported types** -| field | result | -| --- | --- | -| keyword | keyword | -| text | keyword | +| field | options | result | +| --- | --- | --- | +| keyword | | keyword | +| text | | keyword | diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/score.md b/docs/reference/query-languages/esql/_snippets/functions/types/score.md new file mode 100644 index 0000000000000..ab4532fd069a5 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/score.md @@ -0,0 +1,8 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported types** + +| query | result | +| --- | --- | +| boolean | double | + diff --git a/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md b/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md index 5a7cf85a0256e..3dd13763aeb2e 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md +++ b/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md @@ -17,4 +17,17 @@ ROW message = "foo ( bar" | WHERE message RLIKE """foo \( bar""" ``` +```{applies_to} +stack: ga 9.2 +serverless: ga +``` + +Both a single pattern or a list of patterns are supported. If a list of patterns is provided, +the expression will return true if any of the patterns match. 
+ +```esql +ROW message = "foobar" +| WHERE message RLIKE ("foo.*", "bar.") +``` + diff --git a/docs/reference/query-languages/esql/esql-query-log.md b/docs/reference/query-languages/esql/esql-query-log.md new file mode 100644 index 0000000000000..05c7f41134a9c --- /dev/null +++ b/docs/reference/query-languages/esql/esql-query-log.md @@ -0,0 +1,130 @@ +--- +navigation_title: "Query log" +--- + +# {{esql}} Query log [esql-query-log] + + +The {{esql}} query log allows you to log {{esql}} queries based on their execution time. + +You can use these logs to investigate, analyze or troubleshoot your cluster’s historical {{esql}} performance. + +The {{esql}} query log reports task duration at coordinator level, but might not encompass the full task execution time observed on the client. For example, logs don’t surface HTTP network delays. + +Events that meet the specified threshold are emitted into [{{es}} server logs](docs-content://deploy-manage/monitor/logging-configuration/update-elasticsearch-logging-levels.md). + +These logs can be found in the local {{es}} service logs directory. Query log files have a suffix of `_esql_querylog.json`.
+ +## Query log format [query-log-format] + +The following is an example of a successful query event in the query log: + +```js +{ + "@timestamp": "2025-03-11T08:39:50.076Z", + "log.level": "TRACE", + "auth.type": "REALM", + "elasticsearch.querylog.planning.took": 3108666, + "elasticsearch.querylog.planning.took_millis": 3, + "elasticsearch.querylog.query": "from index | limit 100", + "elasticsearch.querylog.search_type": "ESQL", + "elasticsearch.querylog.success": true, + "elasticsearch.querylog.took": 8050416, + "elasticsearch.querylog.took_millis": 8, + "user.name": "elastic-admin", + "user.realm": "default_file", + "ecs.version": "1.2.0", + "service.name": "ES_ECS", + "event.dataset": "elasticsearch.esql_querylog", + "process.thread.name": "elasticsearch[runTask-0][esql_worker][T#12]", + "log.logger": "esql.querylog.query", + "elasticsearch.cluster.uuid": "KZo1V7TcQM-O6fnqMm1t_g", + "elasticsearch.node.id": "uPgRE2TrSfa9IvnUpNT1Uw", + "elasticsearch.node.name": "runTask-0", + "elasticsearch.cluster.name": "runTask" +} +``` + +The following is an example of a failing query event in the query log: + +```js +{ + "@timestamp": "2025-03-11T08:41:54.172Z", + "log.level": "TRACE", + "auth.type": "REALM", + "elasticsearch.querylog.error.message": "line 1:15: mismatched input 'limitxyz' expecting {DEV_CHANGE_POINT, 'enrich', 'dissect', 'eval', 'grok', 'limit', 'sort', 'stats', 'where', DEV_INLINESTATS, DEV_FORK, 'lookup', DEV_JOIN_LEFT, DEV_JOIN_RIGHT, DEV_LOOKUP, 'mv_expand', 'drop', 'keep', DEV_INSIST, 'rename'}", + "elasticsearch.querylog.error.type": "org.elasticsearch.xpack.esql.parser.ParsingException", + "elasticsearch.querylog.query": "from person | limitxyz 100", + "elasticsearch.querylog.search_type": "ESQL", + "elasticsearch.querylog.success": false, + "elasticsearch.querylog.took": 963750, + "elasticsearch.querylog.took_millis": 0, + "user.name": "elastic-admin", + "user.realm": "default_file", + "ecs.version": "1.2.0", + "service.name": "ES_ECS", + 
"event.dataset": "elasticsearch.esql_querylog", + "process.thread.name": "elasticsearch[runTask-0][search][T#16]", + "log.logger": "esql.querylog.query", + "elasticsearch.cluster.uuid": "KZo1V7TcQM-O6fnqMm1t_g", + "elasticsearch.node.id": "uPgRE2TrSfa9IvnUpNT1Uw", + "elasticsearch.node.name": "runTask-0", + "elasticsearch.cluster.name": "runTask" +} +``` + + +## Enable query logging [enable-query-log] + +You can enable query logging at cluster level. + +By default, all thresholds are set to `-1`, which results in no events being logged. + +Query log thresholds can be enabled for the four logging levels: `trace`, `debug`, `info`, and `warn`. + +To view the current query log settings, use the [get cluster settings API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-get-settings): + +```console +GET _cluster/settings?filter_path=*.esql.querylog.* +``` + +You can use the `esql.querylog.include.user` setting to append `user.*` and `auth.type` fields to slow log entries. These fields contain information about the user who triggered the request. + +The following snippet adjusts all available {{esql}} query log settings [update cluster settings API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-put-settings): + +```console +PUT /_cluster/settings +{ + "transient": { + "esql.querylog.threshold.warn": "10s", + "esql.querylog.threshold.info": "5s", + "esql.querylog.threshold.debug": "2s", + "esql.querylog.threshold.trace": "500ms", + "esql.querylog.include.user": true + } +} +``` + + + +## Best practices for query logging [troubleshoot-query-log] + +Logging slow requests can be resource intensive to your {{es}} cluster depending on the qualifying traffic’s volume. For example, emitted logs might increase the index disk usage of your [{{es}} monitoring](docs-content://deploy-manage/monitor/stack-monitoring.md) cluster. 
To reduce the impact of query logs, consider the following: + +* Set high thresholds to reduce the number of logged events. +* Enable query logs only when troubleshooting. + +If you aren’t sure how to start investigating traffic issues, consider enabling the `warn` level with a high `30s` threshold at the cluster level using the [update cluster settings API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-cluster-put-settings). + +Here is an example of how to change cluster settings to enable query logging at `warn` level, for queries taking more than 30 seconds, and include user information in the logs: + +```console +PUT /_cluster/settings +{ + "transient": { + "esql.querylog.include.user": true, + "esql.querylog.threshold.warn": "30s" + } +} +``` + diff --git a/docs/reference/query-languages/esql/esql-troubleshooting.md b/docs/reference/query-languages/esql/esql-troubleshooting.md new file mode 100644 index 0000000000000..43768a2facc99 --- /dev/null +++ b/docs/reference/query-languages/esql/esql-troubleshooting.md @@ -0,0 +1,9 @@ +--- +navigation_title: "Troubleshooting" +--- + +# Troubleshooting {{esql}} [esql-troubleshooting] + +This section provides some useful resources for troubleshooting {{esql}}. + +* [Query log](esql-query-log.md): Learn how to log {{esql}} queries diff --git a/docs/reference/query-languages/esql/images/functions/bucket.svg b/docs/reference/query-languages/esql/images/functions/bucket.svg index 78694296922ed..900db7701480f 100644 --- a/docs/reference/query-languages/esql/images/functions/bucket.svg +++ b/docs/reference/query-languages/esql/images/functions/bucket.svg @@ -1 +1 @@ -BUCKET(field,buckets,from,to) \ No newline at end of file +BUCKET(field,buckets,from,to,emitEmptyBuckets) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/images/functions/categorize.svg b/docs/reference/query-languages/esql/images/functions/categorize.svg index bbb2bda7c480b..7629b9bb978ba 100644 ---
a/docs/reference/query-languages/esql/images/functions/categorize.svg +++ b/docs/reference/query-languages/esql/images/functions/categorize.svg @@ -1 +1 @@ -CATEGORIZE(field) \ No newline at end of file +CATEGORIZE(field,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/images/functions/score.svg b/docs/reference/query-languages/esql/images/functions/score.svg new file mode 100644 index 0000000000000..9662976dd6db1 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/score.svg @@ -0,0 +1 @@ +SCORE(query) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/images/functions/v_cosine.svg b/docs/reference/query-languages/esql/images/functions/v_cosine.svg new file mode 100644 index 0000000000000..fb7a2ed91fa8d --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/v_cosine.svg @@ -0,0 +1 @@ +V_COSINE(left,right) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/score.json b/docs/reference/query-languages/esql/kibana/definition/functions/score.json new file mode 100644 index 0000000000000..4772093e349d9 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/score.json @@ -0,0 +1,25 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "score", + "description" : "Scores an expression. Only full text functions will be scored. Returns scores for all the resulting docs.", + "signatures" : [ + { + "params" : [ + { + "name" : "query", + "type" : "boolean", + "optional" : false, + "description" : "Boolean expression that contains full text function(s) to be scored." 
+ } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "FROM books METADATA _score\n| WHERE match(title, \"Return\") AND match(author, \"Tolkien\")\n| EVAL first_score = score(match(title, \"Return\"))" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json b/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json new file mode 100644 index 0000000000000..f3b3df1d88c6a --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json @@ -0,0 +1,12 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "v_cosine", + "description" : "Calculates the cosine similarity between two dense_vectors.", + "signatures" : [ ], + "examples" : [ + " from colors\n | where color != \"black\"\n | eval similarity = v_cosine(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md index c7af797488ba4..f32319b080dbb 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -1,4 +1,4 @@ -% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. ### KNN Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors. 
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/score.md b/docs/reference/query-languages/esql/kibana/docs/functions/score.md new file mode 100644 index 0000000000000..865a7b0758ba9 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/score.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### SCORE +Scores an expression. Only full text functions will be scored. Returns scores for all the resulting docs. + +```esql +FROM books METADATA _score +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL first_score = score(match(title, "Return")) +``` diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md b/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md new file mode 100644 index 0000000000000..22e4626fe38ad --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md @@ -0,0 +1,11 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### V COSINE +Calculates the cosine similarity between two dense_vectors. + +```esql + from colors + | where color != "black" + | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) + | sort similarity desc, color asc +``` diff --git a/docs/reference/query-languages/esql/limitations.md b/docs/reference/query-languages/esql/limitations.md index 83ae009a7d0ca..5f4417aa78e98 100644 --- a/docs/reference/query-languages/esql/limitations.md +++ b/docs/reference/query-languages/esql/limitations.md @@ -250,3 +250,6 @@ Work around this limitation by converting the field to single value with one of * CSV export from Discover shows no more than 10,000 rows. This limit only applies to the number of rows that are retrieved by the query and displayed in Discover. Queries and aggregations run on the full data set. 
* Querying many indices at once without any filters can cause an error in kibana which looks like `[esql] > Unexpected error from Elasticsearch: The content length (536885793) is bigger than the maximum allowed string (536870888)`. The response from {{esql}} is too long. Use [`DROP`](/reference/query-languages/esql/commands/processing-commands.md#esql-drop) or [`KEEP`](/reference/query-languages/esql/commands/processing-commands.md#esql-keep) to limit the number of fields returned. +## Known issues [esql-known-issues] + +Refer to [Known issues](/release-notes/known-issues.md) for a list of known issues for {{esql}}. diff --git a/docs/reference/query-languages/query-dsl/query-dsl-bool-query.md b/docs/reference/query-languages/query-dsl/query-dsl-bool-query.md index 5f64e7d462b0b..c75d7d66b46b0 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-bool-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-bool-query.md @@ -13,8 +13,8 @@ A query that matches documents matching boolean combinations of other queries. T | --- | --- | | `must` | The clause (query) must appear in matching documents and will contribute to the score. Each query defined under a `must` acts as a logical "AND", returning only documents that match *all* the specified queries. | | `should` | The clause (query) should appear in the matching document. Each query defined under a `should` acts as a logical "OR", returning documents that match *any* of the specified queries. | -| `filter` | The clause (query) must appear in matching documents. However unlike`must` the score of the query will be ignored. Filter clauses are executedin [filter context](/reference/query-languages/query-dsl/query-filter-context.md), meaning that scoring is ignoredand clauses are considered for caching. Each query defined under a `filter` acts as a logical "AND", returning only documents that match *all* the specified queries. 
| -| `must_not` | The clause (query) must not appear in the matchingdocuments. Clauses are executed in [filter context](/reference/query-languages/query-dsl/query-filter-context.md) meaningthat scoring is ignored and clauses are considered for caching. Because scoring isignored, a score of `0` for all documents is returned. Each query defined under a `must_not` acts as a logical "NOT", returning only documents that do not match any of the specified queries. | +| `filter` | The clause (query) must appear in matching documents. However unlike `must` the score of the query will be ignored. Filter clauses are executed in [filter context](/reference/query-languages/query-dsl/query-filter-context.md), meaning that scoring is ignored and clauses are considered for caching. Each query defined under a `filter` acts as a logical "AND", returning only documents that match *all* the specified queries. | +| `must_not` | The clause (query) must not appear in the matching documents. Clauses are executed in [filter context](/reference/query-languages/query-dsl/query-filter-context.md) meaning that scoring is ignored and clauses are considered for caching. Because scoring is ignored, a score of `0` for all documents is returned. Each query defined under a `must_not` acts as a logical "NOT", returning only documents that do not match any of the specified queries. | The `must` and `should` clauses function as logical AND, OR operators, contributing to the scoring of results. However, these results are not cached, which means repeated queries won't benefit from faster retrieval. In contrast, the `filter` and `must_not` clauses are used to include or exclude results without impacting the score, unless used within a `constant_score` query. 
diff --git a/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md b/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md index ab2dd232eaf50..d57fe8642e6c3 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md @@ -165,7 +165,6 @@ POST my-image-index/_search Knn query can be used as a part of hybrid search, where knn query is combined with other lexical queries. For example, the query below finds documents with `title` matching `mountain lake`, and combines them with the top 10 documents that have the closest image vectors to the `query_vector`. The combined documents are then scored and the top 3 top scored documents are returned. -+ ```console POST my-image-index/_search diff --git a/docs/reference/query-languages/query-dsl/query-dsl-rule-query.md b/docs/reference/query-languages/query-dsl/query-dsl-rule-query.md index c682481cdc58a..3b017804cb1f5 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-rule-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-rule-query.md @@ -14,7 +14,7 @@ mapped_pages: ::::{tip} -The rule query is not supported for use alongside reranking. If you want to use query rules in conjunction with reranking, use the [rule retriever](/reference/elasticsearch/rest-apis/retrievers.md#rule-retriever) instead. +The rule query is not supported for use alongside reranking. If you want to use query rules in conjunction with reranking, use the [rule retriever](/reference/elasticsearch/rest-apis/retrievers/rule-retriever.md) instead. 
:::: diff --git a/docs/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md b/docs/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md index 1db97a6f13967..7274da6058fe4 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md @@ -150,7 +150,7 @@ GET my-index/_search } ``` -This can also be achieved using [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md), through an [`rrf` retriever](/reference/elasticsearch/rest-apis/retrievers.md#rrf-retriever) with multiple [`standard` retrievers](/reference/elasticsearch/rest-apis/retrievers.md#standard-retriever). +This can also be achieved using [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md), through an [`rrf` retriever](/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md) with multiple [`standard` retrievers](/reference/elasticsearch/rest-apis/retrievers/standard-retriever.md). ```console GET my-index/_search diff --git a/docs/reference/query-languages/query-dsl/query-dsl-text-expansion-query.md b/docs/reference/query-languages/query-dsl/query-dsl-text-expansion-query.md index 60bf0e4776309..7c81e300a5abb 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-text-expansion-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-text-expansion-query.md @@ -134,7 +134,7 @@ GET my-index/_search } ``` -This can also be achieved using [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md), through an [`rrf` retriever](/reference/elasticsearch/rest-apis/retrievers.md#rrf-retriever) with multiple [`standard` retrievers](/reference/elasticsearch/rest-apis/retrievers.md#standard-retriever). 
+This can also be achieved using [reciprocal rank fusion (RRF)](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md), through an [`rrf` retriever](/reference/elasticsearch/rest-apis/retrievers/rrf-retriever.md) with multiple [`standard` retrievers](/reference/elasticsearch/rest-apis/retrievers/standard-retriever.md). ```console GET my-index/_search diff --git a/docs/reference/query-languages/toc.yml b/docs/reference/query-languages/toc.yml index 31bead277f05f..6ecc4d08d81b9 100644 --- a/docs/reference/query-languages/toc.yml +++ b/docs/reference/query-languages/toc.yml @@ -119,6 +119,9 @@ toc: - file: esql/limitations.md - file: esql/esql-examples.md + - file: esql/esql-troubleshooting.md + children: + - file: esql/esql-query-log.md - file: sql.md children: - file: sql/sql-spec.md diff --git a/docs/reference/search-connectors/api-tutorial.md b/docs/reference/search-connectors/api-tutorial.md index 7b691f44e9c64..dfb5fc5dee4ac 100644 --- a/docs/reference/search-connectors/api-tutorial.md +++ b/docs/reference/search-connectors/api-tutorial.md @@ -6,34 +6,31 @@ applies_to: elasticsearch: ga mapped_pages: - https://www.elastic.co/guide/en/elasticsearch/reference/current/es-connectors-tutorial-api.html +description: Use APIs to synchronize data from a PostgreSQL data source into Elasticsearch. --- # Connector API tutorial [es-connectors-tutorial-api] -Learn how to set up a self-managed connector using the [{{es}} Connector APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-connector). +Learn how to set up a self-managed connector using the [{{es}} connector APIs]({{es-apis}}group/endpoint-connector). -For this example we’ll use the connectors-postgresql,PostgreSQL connector to sync data from a PostgreSQL database to {{es}}. We’ll spin up a simple PostgreSQL instance in Docker with some example data, create a connector, and sync the data to {{es}}. You can follow the same steps to set up a connector for another data source. 
+For this example we’ll use the [PostgreSQL connector](/reference/search-connectors/es-connectors-postgresql.md) to sync data from a PostgreSQL database to {{es}}. We’ll spin up a simple PostgreSQL instance in Docker with some example data, create a connector, and sync the data to {{es}}. You can follow the same steps to set up a connector for another data source. ::::{tip} -This tutorial focuses on running a self-managed connector on your own infrastructure, and managing syncs using the Connector APIs. See connectors for an overview of how connectors work. +This tutorial focuses on running a self-managed connector on your own infrastructure, and managing syncs using the connector APIs. If you’re just getting started with {{es}}, this tutorial might be a bit advanced. Refer to [quickstart](docs-content://solutions/search/get-started.md) for a more beginner-friendly introduction to {{es}}. -If you’re just getting started with connectors, you might want to start in the UI first. Check out this tutorial that focuses on managing connectors using the UI: - -* [Self-managed connector tutorial](/reference/search-connectors/es-postgresql-connector-client-tutorial.md). Set up a self-managed PostgreSQL connector. +If you’re just getting started with connectors, you might want to start in the UI first. Check out this tutorial that focuses on managing connectors using the UI: [](/reference/search-connectors/es-postgresql-connector-client-tutorial.md). :::: - -### Prerequisites [es-connectors-tutorial-api-prerequisites] +## Prerequisites [es-connectors-tutorial-api-prerequisites] * You should be familiar with how connectors, connectors work, to understand how the API calls relate to the overall connector setup. * You need to have [Docker Desktop](https://www.docker.com/products/docker-desktop/) installed. * You need to have {{es}} running, and an API key to access it. Refer to the next section for details, if you don’t have an {{es}} deployment yet. 
- -### Set up {{es}} [es-connectors-tutorial-api-setup-es] +## Set up {{es}} [es-connectors-tutorial-api-setup-es] If you already have an {{es}} deployment on Elastic Cloud (*Hosted deployment* or *Serverless project*), you’re good to go. To spin up {{es}} in local dev mode in Docker for testing purposes, open the collapsible section below. @@ -73,7 +70,8 @@ Note: With {{es}} running locally, you will need to pass the username and passwo ::::{admonition} Running API calls -You can run API calls using the [Dev Tools Console](docs-content://explore-analyze/query-filter/tools/console.md) in Kibana, using `curl` in your terminal, or with our programming language clients. Our example widget allows you to copy code examples in both Dev Tools Console syntax and curl syntax. To use curl, you’ll need to add authentication headers to your request. +You can run API calls using the [Dev Tools Console](docs-content://explore-analyze/query-filter/tools/console.md) in Kibana, using `curl` in your terminal, or with our programming language clients. +To use curl, you’ll need to add authentication headers to your request. Here’s an example of how to do that. Note that if you want the connector ID to be auto-generated, use the `POST _connector` endpoint. @@ -88,13 +86,11 @@ curl -s -X PUT http://localhost:9200/_connector/my-connector-id \ }' ``` -Refer to connectors-tutorial-api-create-api-key for instructions on creating an API key. +Refer to [](/reference/search-connectors/es-postgresql-connector-client-tutorial.md) for instructions on creating an API key. :::: - - -### Run PostgreSQL instance in Docker (optional) [es-connectors-tutorial-api-setup-postgres] +## Run PostgreSQL instance in Docker (optional) [es-connectors-tutorial-api-setup-postgres] For this tutorial, we’ll set up a PostgreSQL instance in Docker with some example data. Of course, you can **skip this step and use your own existing PostgreSQL instance** if you have one. 
Keep in mind that using a different instance might require adjustments to the connector configuration described in the next steps. @@ -105,7 +101,7 @@ Let’s launch a PostgreSQL container with a user and password, exposed at port docker run --name postgres -e POSTGRES_USER=myuser -e POSTGRES_PASSWORD=mypassword -p 5432:5432 -d postgres ``` -**Download and import example data** +### Download and import example data Next we need to create a directory to store our example dataset for this tutorial. In your terminal, run the following command: @@ -145,10 +141,9 @@ This tutorial uses a very basic setup. To use advanced functionality such as fil Now it’s time for the real fun! We’ll set up a connector to create a searchable mirror of our PostgreSQL data in {{es}}. +## Create a connector [es-connectors-tutorial-api-create-connector] -### Create a connector [es-connectors-tutorial-api-create-connector] - -We’ll use the [Create connector API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-put) to create a PostgreSQL connector instance. +We’ll use the [create connector API]({{es-apis}}operation/operation-connector-put) to create a PostgreSQL connector instance. Run the following API call, using the [Dev Tools Console](docs-content://explore-analyze/query-filter/tools/console.md) or `curl`: @@ -171,10 +166,9 @@ Note that we specified the `my-connector-id` ID as a part of the `PUT` request. If you’d prefer to use an autogenerated ID, replace `PUT _connector/my-connector-id` with `POST _connector`. +## Set up the connector service [es-connectors-tutorial-api-deploy-connector] -### Run connector service [es-connectors-tutorial-api-deploy-connector] - -Now we’ll run the connector service so we can start syncing data from our PostgreSQL instance to {{es}}. We’ll use the steps outlined in connectors-run-from-docker. +Now we’ll run the connector service so we can start syncing data from our PostgreSQL instance to {{es}}. 
We’ll use the steps outlined in [](/reference/search-connectors/es-connectors-run-from-docker.md). When running the connectors service on your own infrastructure, you need to provide a configuration file with the following details: @@ -183,10 +177,9 @@ When running the connectors service on your own infrastructure, you need to prov * Your third-party data source type (`service_type`) * Your connector ID (`connector_id`) +### Create an API key [es-connectors-tutorial-api-create-api-key] -#### Create an API key [es-connectors-tutorial-api-create-api-key] - -If you haven’t already created an API key to access {{es}}, you can use the [_security/api_key](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-security-create-api-key) endpoint. +If you haven’t already created an API key to access {{es}}, you can use the [_security/api_key]({{es-apis}}operation/operation-security-create-api-key) endpoint. Here, we assume your target {{es}} index name is `music`. If you use a different index name, adjust the request body accordingly. @@ -225,9 +218,7 @@ You can also create an API key in the {{kib}} and Serverless UIs. :::: - - -#### Prepare the configuration file [es-connectors-tutorial-api-prepare-configuration-file] +### Prepare the configuration file [es-connectors-tutorial-api-prepare-configuration-file] Let’s create a directory and a `config.yml` file to store the connector configuration: @@ -249,8 +240,7 @@ connectors: We provide an [example configuration file](https://raw.githubusercontent.com/elastic/connectors/main/config.yml.example) in the `elastic/connectors` repository for reference. - -#### Run the connector service [es-connectors-tutorial-api-run-connector-service] +### Run the service [es-connectors-tutorial-api-run-connector-service] Now that we have the configuration file set up, we can run the connector service locally. This will point your connector instance at your {{es}} deployment. 
@@ -273,12 +263,11 @@ Verify your connector is connected by getting the connector status (should be `n GET _connector/my-connector-id ``` - -### Configure connector [es-connectors-tutorial-api-update-connector-configuration] +## Configure the connector [es-connectors-tutorial-api-update-connector-configuration] Now our connector instance is up and running, but it doesn’t yet know *where* to sync data from. The final piece of the puzzle is to configure our connector with details about our PostgreSQL instance. When setting up a connector in the Elastic Cloud or Serverless UIs, you’re prompted to add these details in the user interface. -But because this tutorial is all about working with connectors *programmatically*, we’ll use the [Update connector configuration API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-update-configuration) to add our configuration details. +But because this tutorial is all about working with connectors *programmatically*, we’ll use the [update connector configuration API]({{es-apis}}operation/operation-connector-update-configuration) to add our configuration details. ::::{tip} Before configuring the connector, ensure that the configuration schema is registered by the service. For self-managed connectors, the schema registers on service startup (once the `config.yml` is populated). @@ -310,9 +299,7 @@ Configuration details are specific to the connector type. The keys and values wi :::: - - -### Sync data [es-connectors-tutorial-api-sync] +## Sync your data [es-connectors-tutorial-api-sync] We’re now ready to sync our PostgreSQL data to {{es}}. Run the following API call to start a full sync job: @@ -327,15 +314,13 @@ POST _connector/_sync_job To store data in {{es}}, the connector needs to create an index. When we created the connector, we specified the `music` index. The connector will create and configure this {{es}} index before launching the sync job. 
::::{tip} -In the approach we’ve used here, the connector will use [dynamic mappings](docs-content://manage-data/data-store/mapping.md#mapping-dynamic) to automatically infer the data types of your fields. In a real-world scenario you would use the {{es}} [Create index API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-create) to first create the index with the desired field mappings and index settings. Defining your own mappings upfront gives you more control over how your data is indexed. +In the approach we’ve used here, the connector will use [dynamic mappings](docs-content://manage-data/data-store/mapping.md#mapping-dynamic) to automatically infer the data types of your fields. In a real-world scenario you would use the {{es}} [create index API]({{es-apis}}operation/operation-indices-create) to first create the index with the desired field mappings and index settings. Defining your own mappings upfront gives you more control over how your data is indexed. :::: +### Check sync status [es-connectors-tutorial-api-check-sync-status] - -#### Check sync status [es-connectors-tutorial-api-check-sync-status] - -Use the [Get sync job API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-connector-sync-job-get) to track the status and progress of the sync job. By default, the most recent job statuses are returned first. Run the following API call to check the status of the sync job: +Use the [get sync job API]({{es-apis}}operation/operation-connector-sync-job-get) to track the status and progress of the sync job. By default, the most recent job statuses are returned first. Run the following API call to check the status of the sync job: ```console GET _connector/_sync_job?connector_id=my-connector-id&size=1 @@ -345,6 +330,8 @@ The job document will be updated as the sync progresses, you can check it as oft Once the job completes, the status should be `completed` and `indexed_document_count` should be **622**. 
+## Query your data + Verify that data is present in the `music` index with the following API call: ```console @@ -357,8 +344,7 @@ GET music/_count GET music/_search ``` - -## Troubleshooting [es-connectors-tutorial-api-troubleshooting] +## Troubleshoot [es-connectors-tutorial-api-troubleshooting] Use the following command to inspect the latest sync job’s status: @@ -369,7 +355,7 @@ GET _connector/_sync_job?connector_id=my-connector-id&size=1 If the connector encountered any errors during the sync, you’ll find these in the `error` field. -### Cleaning up [es-connectors-tutorial-api-cleanup] +## Clean up [es-connectors-tutorial-api-cleanup] To delete the connector and its associated sync jobs run this command: @@ -397,13 +383,12 @@ docker stop docker rm ``` +## Next steps [es-connectors-tutorial-api-next-steps] -### Next steps [es-connectors-tutorial-api-next-steps] - -Congratulations! You’ve successfully set up a self-managed connector using the Connector APIs. +Congratulations! You’ve successfully set up a self-managed connector using the connector APIs. Here are some next steps to explore: -* Learn more about the [Connector APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-connector). +* Learn more about the [connector APIs]({{es-apis}}group/endpoint-connector). * Learn how to deploy {{es}}, {{kib}}, and the connectors service using Docker Compose in our [quickstart guide](https://github.com/elastic/connectors/tree/main/scripts/stack#readme). diff --git a/docs/reference/search-connectors/connectors-ui-in-kibana.md b/docs/reference/search-connectors/connectors-ui-in-kibana.md index c5b4d72bad60c..c15fbd2f67284 100644 --- a/docs/reference/search-connectors/connectors-ui-in-kibana.md +++ b/docs/reference/search-connectors/connectors-ui-in-kibana.md @@ -11,10 +11,10 @@ mapped_pages: This document describes operations available to connectors using the UI. 
-In the Kibana or Serverless UI, find Connectors using the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). Here, you can view a summary of all your connectors and sync jobs, and to create new connectors. +In the Kibana or Serverless UI, find **{{connectors-app}}** using the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). Here, you can view a summary of all your connectors and sync jobs, and create new connectors. ::::{tip} -In 8.12 we introduced a set of [Connector APIs](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-connector) to create and manage Elastic connectors and sync jobs, along with a [CLI tool](https://github.com/elastic/connectors/blob/main/docs/CLI.md). Use these tools if you’d like to work with connectors and sync jobs programmatically, without using the UI. +In 8.12 we introduced a set of [connector APIs]({{es-apis}}group/endpoint-connector) to create and manage Elastic connectors and sync jobs, along with a [CLI tool](https://github.com/elastic/connectors/blob/main/docs/CLI.md). Use these tools if you’d like to work with connectors and sync jobs programmatically, without using the UI. :::: @@ -24,13 +24,13 @@ In 8.12 we introduced a set of [Connector APIs](https://www.elastic.co/docs/api/ You connector writes data to an {{es}} index. -To create self-managed [**self-managed connector**](/reference/search-connectors/self-managed-connectors.md), use the buttons under **Search > Content > Connectors**. Once you’ve chosen the data source type you’d like to sync, you’ll be prompted to create an {{es}} index. 
## Manage connector indices [es-connectors-usage-indices] View and manage all Elasticsearch indices managed by connectors. -In the {{kib}} UI, navigate to **Search > Content > Connectors** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). Here, you can view a list of connector indices and their attributes, including connector type health and ingestion status. +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). Here, you can view a list of connector indices and their attributes, including connector type health and ingestion status. Within this interface, you can choose to view the details for each existing index or delete an index. Or, you can [create a new connector index](#es-connectors-usage-index-create). @@ -41,21 +41,21 @@ These operations require access to Kibana and additional index privileges. {{es}} stores your data as documents in an index. Each index is made up of a set of fields and each field has a type (such as `keyword`, `boolean`, or `date`). -**Mapping** is the process of defining how a document, and the fields it contains, are stored and indexed. Connectors use [dynamic mapping](docs-content://manage-data/data-store/mapping/dynamic-field-mapping.md) to automatically create mappings based on the data fetched from the source. +Mapping is the process of defining how a document, and the fields it contains, are stored and indexed. Connectors use [dynamic mapping](docs-content://manage-data/data-store/mapping/dynamic-field-mapping.md) to automatically create mappings based on the data fetched from the source. -Index **settings** are configurations that can be adjusted on a per-index basis. They control things like the index’s performance, the resources it uses, and how it should handle operations. 
+Index settings are configurations that can be adjusted on a per-index basis. They control things like the index’s performance, the resources it uses, and how it should handle operations. -When you create an index with a connector, the index is created with *default* search-optimized field template mappings and index settings. Mappings for specific fields are then dynamically created based on the data fetched from the source. +When you create an index with a connector, the index is created with default search-optimized field template mappings and index settings. Mappings for specific fields are then dynamically created based on the data fetched from the source. You can inspect your index mappings in the following ways: -* **In the {{kib}} UI**: Navigate to **Search > Content > Indices > *YOUR-INDEX* > Index Mappings** -* **By API**: Use the [Get mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-get-mapping) +* In the {{kib}} UI: Navigate to **{{es}} > Content > Indices > *YOUR-INDEX* > Index Mappings**. +* By API: Use the [get mapping API]({{es-apis}}operation/operation-indices-get-mapping). You can manually **edit** the mappings and settings via the {{es}} APIs: -* Use the [Put mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) to update index mappings. -* Use the [Update index settings API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-settings) to update index settings. +* Use the [put mapping API]({{es-apis}}operation/operation-indices-put-mapping) to update index mappings. +* Use the [update index settings API]({{es-apis}}operation/operation-indices-put-settings) to update index settings. It’s important to note that these updates are more complex when the index already contains data. @@ -69,12 +69,12 @@ Updating mappings and settings is simpler when your index has no data. 
If you cr ### Customize mappings and settings after syncing data [es-connectors-usage-index-create-configure-existing-index-have-data] -Once data has been added to {{es}} using dynamic mappings, you can’t directly update existing field mappings. If you’ve already synced data into an index and want to change the mappings, you’ll need to [reindex your data](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-reindex). +Once data has been added to {{es}} using dynamic mappings, you can’t directly update existing field mappings. If you’ve already synced data into an index and want to change the mappings, you’ll need to [reindex your data]({{es-apis}}operation/operation-reindex). The workflow for these updates is as follows: -1. [Create](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-create) a new index with the desired mappings and settings. -2. [Reindex](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-reindex) your data from the old index into this new index. +1. [Create]({{es-apis}}operation/operation-indices-create) a new index with the desired mappings and settings. +2. [Reindex]({{es-apis}}operation/operation-reindex) your data from the old index into this new index. 3. Delete the old index. 4. (Optional) Use an [alias](docs-content://manage-data/data-store/aliases.md), if you want to retain the old index name. 5. Attach your connector to the new index or alias. @@ -84,9 +84,9 @@ The workflow for these updates is as follows: After creating an index to be managed by a connector, you can configure automatic, recurring syncs. -In the {{kib}} UI, navigate to **Search > Content > Connectors** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). 
+In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). -Choose the index to configure, and then choose the **Scheduling** tab. +Choose the connector and then the **Scheduling** tab. Within this interface, you can enable or disable scheduled: @@ -107,9 +107,9 @@ After you enable recurring syncs or sync once, the first sync will begin. (There After creating the index to be managed by a connector, you can request a single sync at any time. -In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). -Then choose the index to sync. +Then choose the connector to sync. Regardless of which tab is active, the **Sync** button is always visible in the top right. Choose this button to reveal sync options: @@ -117,7 +117,7 @@ Regardless of which tab is active, the **Sync** button is always visible in the 2. Incremental content (if supported) 3. Access control (if supported) -Choose one of the options to request a sync. (There may be a short delay before the connector service begins the sync.) +Choose one of the options to request a sync. There may be a short delay before the connector service begins the sync. This operation requires access to Kibana and the `write` [indices privilege^](/reference/elasticsearch/security-privileges.md) for the `.elastic-connectors` index. @@ -126,9 +126,9 @@ This operation requires access to Kibana and the `write` [indices privilege^](/r After a sync has started, you can cancel the sync before it completes. 
-In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). -Then choose the index with the running sync. +Then choose the connector with the running sync. Regardless of which tab is active, the **Sync** button is always visible in the top right. Choose this button to reveal sync options, and choose **Cancel Syncs** to cancel active syncs. This will cancel the running job, and marks all *pending* and *suspended* jobs as canceled as well. (There may be a short delay before the connector service cancels the syncs.) @@ -139,9 +139,9 @@ This operation requires access to Kibana and the `write` [indices privilege^](/r View the index details to see a variety of information that communicate the status of the index and connector. -In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). -Then choose the index to view. +Then choose the connector to view. The **Overview** tab presents a variety of information, including: @@ -150,7 +150,7 @@ The **Overview** tab presents a variety of information, including: * The current ingestion status (see below for possible values). * The current document count. -Possible values of ingestion status: +Possible values of ingestion status include: * Incomplete - A connector that is not configured yet. 
* Configured - A connector that is configured. @@ -159,9 +159,8 @@ Possible values of ingestion status: * Connector failure - A connector that has not seen any update for more than 30 minutes. * Sync failure - A connector that failed in the last sync job. -This tab also displays the recent sync history, including sync status (see below for possible values). - -Possible values of sync status: +This tab also displays the recent sync history, including sync status. +Possible values of sync status include: * Sync pending - The initial job status, the job is pending to be picked up. * Sync in progress - The job is running. @@ -186,11 +185,9 @@ This operation requires access to Kibana and the `read` [indices privilege^](/re View the documents the connector has synced from the data. Additionally view the index mappings to determine the current document schema. -In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). - -Then choose the index to view. +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). -Choose the **Documents** tab to view the synced documents. Choose the **Index Mappings** tab to view the index mappings that were created by the connector. +Select the connector then the **Documents** tab to view the synced documents. Choose the **Mappings** tab to view the index mappings that were created by the connector. When setting up a new connector, ensure you are getting the documents and fields you were expecting from the data source. If not, see [Troubleshooting](/reference/search-connectors/es-connectors-troubleshooting.md) for help. 
@@ -203,7 +200,7 @@ See [Security](/reference/search-connectors/es-connectors-security.md) for secur Use [sync rules](/reference/search-connectors/es-sync-rules.md) to limit which documents are fetched from the data source, or limit which fetched documents are stored in Elastic. -In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). Then choose the index to manage and choose the **Sync rules** tab. @@ -212,7 +209,5 @@ Then choose the index to manage and choose the **Sync rules** tab. Use [ingest pipelines](docs-content://solutions/search/ingest-for-search.md) to transform fetched data before it is stored in Elastic. -In the {{kib}} UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). - -Then choose the index to manage and choose the **Pipelines** tab. - +In the {{kib}} UI, navigate to **{{es}} > Content > {{connectors-app}}** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). +Then choose the connector and view its **Pipelines** tab. 
diff --git a/docs/reference/search-connectors/es-postgresql-connector-client-tutorial.md b/docs/reference/search-connectors/es-postgresql-connector-client-tutorial.md index 54a275e8184fb..8cc8f9aee771b 100644 --- a/docs/reference/search-connectors/es-postgresql-connector-client-tutorial.md +++ b/docs/reference/search-connectors/es-postgresql-connector-client-tutorial.md @@ -3,28 +3,33 @@ navigation_title: "Tutorial" mapped_pages: - https://www.elastic.co/guide/en/elasticsearch/reference/current/es-postgresql-connector-client-tutorial.html - https://www.elastic.co/guide/en/starting-with-the-elasticsearch-platform-and-its-solutions/current/getting-started-appsearch.html +applies_to: + stack: ga + serverless: + elasticsearch: ga +description: Synchronize data from a PostgreSQL data source into Elasticsearch. --- -# PostgreSQL self-managed connector tutorial [es-postgresql-connector-client-tutorial] +# Set up a self-managed PostgreSQL connector - -This tutorial walks you through the process of creating a self-managed connector for a PostgreSQL data source. You’ll be using the [self-managed connector](/reference/search-connectors/self-managed-connectors.md) workflow in the Kibana UI. This means you’ll be deploying the connector on your own infrastructure. Refer to the [Elastic PostgreSQL connector reference](/reference/search-connectors/es-connectors-postgresql.md) for more information about this connector. - -In this exercise, you’ll be working in both the terminal (or your IDE) and the Kibana UI. +Elastic connectors enable you to create searchable, read-only replicas of your data sources in {{es}}. +This tutorial walks you through the process of creating a self-managed connector for a PostgreSQL data source. If you want to deploy a self-managed connector for another data source, use this tutorial as a blueprint. Refer to the list of available [connectors](/reference/search-connectors/index.md). 
::::{tip} -Want to get started quickly testing a self-managed connector using Docker Compose? Refer to this [guide](https://github.com/elastic/connectors/tree/main/scripts/stack#readme) in the `elastic/connectors` repo for more information. +Want to get started quickly testing a self-managed connector and a self-managed cluster using Docker Compose? Refer to the [readme](https://github.com/elastic/connectors/tree/main/scripts/stack#readme) in the `elastic/connectors` repo for more information. :::: ## Prerequisites [es-postgresql-connector-client-tutorial-prerequisites] - ### Elastic prerequisites [es-postgresql-connector-client-tutorial-prerequisites-elastic] -First, ensure you satisfy the [prerequisites](/reference/search-connectors/self-managed-connectors.md#es-build-connector-prerequisites) for self-managed connectors. +You must satisfy the [prerequisites](/reference/search-connectors/self-managed-connectors.md#es-build-connector-prerequisites) for self-managed connectors. + ### PostgreSQL prerequisites [es-postgresql-connector-client-tutorial-postgresql-prerequisites] @@ -47,142 +52,119 @@ Then restart the PostgreSQL server. :::: +## Set up the connector +:::::{stepper} +::::{step} Create an Elasticsearch index +To store data in {{es}}, the connector needs to create an index. +By default, connectors use [dynamic mappings](docs-content://manage-data/data-store/mapping.md#mapping-dynamic) to automatically infer the data types of your fields. +If you [use APIs](/reference/search-connectors/api-tutorial.md) or {{es-serverless}}, you can create an index with the desired field mappings and index settings before you create the connector. +Defining your own mappings upfront gives you more control over how your data is indexed. -## Steps [es-postgresql-connector-client-tutorial-steps] - -To complete this tutorial, you’ll need to complete the following steps: - -1. [Create an Elasticsearch index](#es-postgresql-connector-client-tutorial-create-index) -2. 
[Set up the connector](#es-postgresql-connector-client-tutorial-setup-connector) -3. [Run the `connectors` connector service](#es-postgresql-connector-client-tutorial-run-connector-service) -4. [Sync your PostgreSQL data source](#es-postgresql-connector-client-tutorial-sync-data-source) - - -## Create an Elasticsearch index [es-postgresql-connector-client-tutorial-create-index] - -Elastic connectors enable you to create searchable, read-only replicas of your data sources in Elasticsearch. The first step in setting up your self-managed connector is to create an index. - -In the [Kibana^](docs-content://get-started/the-stack.md) UI, navigate to **Search > Content > Elasticsearch indices** from the main menu, or use the [global search field](docs-content://explore-analyze/query-filter/filtering.md#_finding_your_apps_and_objects). - -Create a new connector index: - -1. Under **Select an ingestion method** choose **Connector**. -2. Choose **PostgreSQL** from the list of connectors. -3. Name your index and optionally change the language analyzer to match the human language of your data source. (The index name you provide is automatically prefixed with `search-`.) -4. Save your changes. - -The index is created and ready to configure. - -::::{admonition} Gather Elastic details -:name: es-postgresql-connector-client-tutorial-gather-elastic-details - -Before you can configure the connector, you need to gather some details about your Elastic deployment: - -* **Elasticsearch endpoint**. - - * If you’re an Elastic Cloud user, find your deployment’s Elasticsearch endpoint in the Cloud UI under **Cloud > Deployments > > Elasticsearch**. - * If you’re running your Elastic deployment and the connector service in Docker, the default Elasticsearch endpoint is `http://host.docker.internal:9200`. - -* **API key.** You’ll need this key to configure the connector. Use an existing key or create a new one. -* **Connector ID**. 
Your unique connector ID is automatically generated when you create the connector. Find this in the Kibana UI. - +Navigate to **{{index-manage-app}}** or use the [global search field](docs-content://explore-analyze/find-and-organize/find-apps-and-objects.md). +Follow the index creation workflow then optionally define field mappings. +For example, to add semantic search capabilities, you could add an extra field that stores your vectors for semantic search. :::: +::::{step} Create the connector +Navigate to **{{connectors-app}}** or use the global search field. +Follow the connector creation process in the UI. For example: + +1. Select **PostgreSQL** from the list of connectors. +1. Edit the name and description for the connector. This will help your team identify the connector. +1. Gather configuration details. + Before you can proceed to the next step, you need to gather some details about your Elastic deployment: + + * Elasticsearch endpoint: + * If you’re an Elastic Cloud user, find your deployment’s Elasticsearch endpoint in the Cloud UI under **Cloud > Deployments > > Elasticsearch**. + * If you’re running your Elastic deployment and the connector service in Docker, the default Elasticsearch endpoint is `http://host.docker.internal:9200`. + * API key: You’ll need this key to configure the connector. Use an existing key or create a new one. + * Connector ID: Your unique connector ID is automatically generated when you create the connector. +:::: +::::{step} Run the connector service +You must run the connector code on your own infrastructure and link it to {{es}}. +You have two options: [Run with Docker](/reference/search-connectors/es-connectors-run-from-docker.md) and [Run from source](/reference/search-connectors/es-connectors-run-from-source.md). +For this example, we’ll use the latter method: + +1. Clone or fork the repository locally with the following command: `git clone https://github.com/elastic/connectors`. +1. 
Open the `config.yml` configuration file in your editor of choice. +1. Replace the values for `host`, `api_key`, and `connector_id` with the values you gathered earlier. Use the `service_type` value `postgresql` for this connector. + + :::{dropdown} Expand to see an example config.yml file + Replace the values for `host`, `api_key`, and `connector_id` with your own values. + Use the `service_type` value `postgresql` for this connector. + + ```yaml + elasticsearch: + host: ">" # Your Elasticsearch endpoint + api_key: "" # Your top-level Elasticsearch API key + ... + connectors: + - + connector_id: "" + api_key: "" # Your scoped connector index API key (optional). If not provided, the top-level API key is used. + service_type: "postgresql" + + sources: + # mongodb: connectors.sources.mongo:MongoDataSource + # s3: connectors.sources.s3:S3DataSource + # dir: connectors.sources.directory:DirectoryDataSource + # mysql: connectors.sources.mysql:MySqlDataSource + # network_drive: connectors.sources.network_drive:NASDataSource + # google_cloud_storage: connectors.sources.google_cloud_storage:GoogleCloudStorageDataSource + # azure_blob_storage: connectors.sources.azure_blob_storage:AzureBlobStorageDataSource + postgresql: connectors.sources.postgresql:PostgreSQLDataSource + # oracle: connectors.sources.oracle:OracleDataSource + # sharepoint: connectors.sources.sharepoint:SharepointDataSource + # mssql: connectors.sources.mssql:MSSQLDataSource + # jira: connectors.sources.jira:JiraDataSource + ``` + +1. Now that you’ve configured the connector code, you can run the connector service. In your terminal or IDE: + + 1. `cd` into the root of your `connectors` clone/fork. + 1. Run the following command: `make run`. + +The connector service should now be running. +The UI will let you know that the connector has successfully connected to {{es}}. + +:::{tip} +Here we’re working locally. In production setups, you’ll deploy the connector service to your own infrastructure. 
+::: +:::: +::::{step} Add your data source details + +Now your connector instance is up and running, but it doesn’t yet know where to sync data from. +The final piece of the puzzle is to configure your connector with details about the PostgreSQL instance. + +Return to **{{connectors-app}}** to complete the connector creation process in the UI. +Enter the following PostgreSQL instance details: + +* **Host**: The server host address for your PostgreSQL instance. +* **Port**: The port number for your PostgreSQL instance. +* **Username**: The username of the PostgreSQL account. +* **Password**: The password for that user. +* **Database**: The name of the PostgreSQL database. +* **Schema**: The schema of the PostgreSQL database. +* **Comma-separated list of tables**: `*` will fetch data from all tables in the configured database. + +:::{note} +Configuration details are specific to the connector type. +The keys and values will differ depending on which third-party data source you’re connecting to. +Refer to the [](/reference/search-connectors/es-connectors-postgresql.md) for more details. +::: +:::: +::::{step} Link your index +If you [use APIs](/reference/search-connectors/api-tutorial.md) or {{es-serverless}}, you can create an index or choose an existing index for use by the connector. +Otherwise, the index is created for you and uses dynamic mappings for the fields. +:::: +::::: +## Sync your data [es-postgresql-connector-client-tutorial-sync-data-source] +In the **{{connectors-app}}** page, you can launch a sync on-demand or on a schedule. +The connector will traverse the database and synchronize documents to your index. -## Set up the connector [es-postgresql-connector-client-tutorial-setup-connector] - -Once you’ve created an index, you can set up the connector. You will be guided through this process in the UI. - -1. **Edit the name and description for the connector.** This will help your team identify the connector. -2. 
**Clone and edit the connector service code.** For this example, we’ll use the [Python framework](https://github.com/elastic/connectors/tree/main). Follow these steps: - - * Clone or fork that repository locally with the following command: `git clone https://github.com/elastic/connectors`. - * Open the `config.yml` configuration file in your editor of choice. - * Replace the values for `host`, `api_key`, and `connector_id` with the values you gathered [earlier](#es-postgresql-connector-client-tutorial-gather-elastic-details). Use the `service_type` value `postgresql` for this connector. - - ::::{dropdown} Expand to see an example config.yml file - Replace the values for `host`, `api_key`, and `connector_id` with your own values. Use the `service_type` value `postgresql` for this connector. - - ```yaml - elasticsearch: - host: > # Your Elasticsearch endpoint - api_key: '' # Your top-level Elasticsearch API key - ... - connectors: - - - connector_id: "" - api_key: "'" # Your scoped connector index API key (optional). If not provided, the top-level API key is used. 
- service_type: "postgresql" - - - - # Self-managed connector settings - connector_id: '' # Your connector ID - service_type: 'postgresql' # The service type for your connector - - sources: - # mongodb: connectors.sources.mongo:MongoDataSource - # s3: connectors.sources.s3:S3DataSource - # dir: connectors.sources.directory:DirectoryDataSource - # mysql: connectors.sources.mysql:MySqlDataSource - # network_drive: connectors.sources.network_drive:NASDataSource - # google_cloud_storage: connectors.sources.google_cloud_storage:GoogleCloudStorageDataSource - # azure_blob_storage: connectors.sources.azure_blob_storage:AzureBlobStorageDataSource - postgresql: connectors.sources.postgresql:PostgreSQLDataSource - # oracle: connectors.sources.oracle:OracleDataSource - # sharepoint: connectors.sources.sharepoint:SharepointDataSource - # mssql: connectors.sources.mssql:MSSQLDataSource - # jira: connectors.sources.jira:JiraDataSource - ``` - - :::: - - - -## Run the connector service [es-postgresql-connector-client-tutorial-run-connector-service] - -Now that you’ve configured the connector code, you can run the connector service. - -In your terminal or IDE: - -1. `cd` into the root of your `connectors` clone/fork. -2. Run the following command: `make run`. - -The connector service should now be running. The UI will let you know that the connector has successfully connected to Elasticsearch. - -Here we’re working locally. In production setups, you’ll deploy the connector service to your own infrastructure. If you prefer to use Docker, refer to the [repo docs](https://github.com/elastic/connectors/tree/main/docs/DOCKER.md) for instructions. - - -## Sync your PostgreSQL data source [es-postgresql-connector-client-tutorial-sync-data-source] - - -### Enter your PostgreSQL data source details [es-postgresql-connector-client-tutorial-sync-data-source-details] - -Once you’ve configured the connector, you can use it to index your data source. 
- -You can now enter your PostgreSQL instance details in the Kibana UI. - -Enter the following information: - -* **Host**. Server host address for your PostgreSQL instance. -* **Port**. Port number for your PostgreSQL instance. -* **Username**. Username of the PostgreSQL account. -* **Password**. Password for that user. -* **Database**. Name of the PostgreSQL database. -* **Comma-separated list of tables**. `*` will fetch data from all tables in the configured database. - -Once you’ve entered all these details, select **Save configuration**. - - -### Launch a sync [es-postgresql-connector-client-tutorial-sync-data-source-launch-sync] - -If you navigate to the **Overview** tab in the Kibana UI, you can see the connector’s *ingestion status*. This should now have changed to **Configured**. - -It’s time to launch a sync by selecting the **Sync** button. - -If you navigate to the terminal window where you’re running the connector service, you should see output like the following: +If you navigate to the terminal window where you’re running the connector service, after a sync occurs you should see output like the following: ```shell [FMWK][13:22:26][INFO] Fetcher @@ -193,14 +175,20 @@ If you navigate to the terminal window where you’re running the connector serv (27 seconds) ``` -This confirms the connector has fetched records from your PostgreSQL table(s) and transformed them into documents in your Elasticsearch index. +This confirms the connector has fetched records from your PostgreSQL tables and transformed them into documents in your {{es}} index. + +If you verify your {{es}} documents and you’re happy with the results, set a recurring sync schedule. +This will ensure your searchable data in {{es}} is always up to date with changes to your PostgreSQL data source. -Verify your Elasticsearch documents in the **Documents** tab in the Kibana UI. +In **{{connectors-app}}**, click on the connector, and then click **Scheduling**. 
+For example, you can schedule your content to be synchronized at the top of every hour, as long as the connector is up and running. -If you’re happy with the results, set a recurring sync schedule in the **Scheduling** tab. This will ensure your *searchable* data in Elasticsearch is always up to date with changes to your PostgreSQL data source. +## Next steps +You just learned how to synchronize data from an external database to {{es}}. +For an overview of how to start searching and analyzing your data in Kibana, go to [Explore and analyze](docs-content://explore-analyze/index.md). -## Learn more [es-postgresql-connector-client-tutorial-learn-more] +Learn more: * [Overview of self-managed connectors and frameworks](/reference/search-connectors/self-managed-connectors.md) * [Elastic connector framework repository](https://github.com/elastic/connectors/tree/main) diff --git a/docs/reference/search-connectors/self-managed-connectors.md b/docs/reference/search-connectors/self-managed-connectors.md index 7892569d01714..a863a6a9790e3 100644 --- a/docs/reference/search-connectors/self-managed-connectors.md +++ b/docs/reference/search-connectors/self-managed-connectors.md @@ -18,7 +18,8 @@ Self-managed [Elastic connectors](/reference/search-connectors/index.md) are run ## Availability and Elastic prerequisites [es-build-connector-prerequisites] ::::{note} -Self-managed connectors currently don’t support Windows. Use this [compatibility matrix](https://www.elastic.co/support/matrix#matrix_os) to check which operating systems are supported by self-managed connectors. Find this information under **self-managed connectors** on that page. +Self-managed connectors currently don’t support Windows. Use this [compatibility matrix](https://www.elastic.co/support/matrix#matrix_os) to check which operating systems are supported by self-managed connectors. +% Find this information under **self-managed connectors** on that page. 
:::: @@ -28,7 +29,7 @@ Your Elastic deployment must include the following Elastic services: * **Elasticsearch** * **Kibana** -(A new Elastic Cloud deployment includes these services by default.) +A new {{ech}} deployment or {{es-serverless}} project includes these services by default. To run self-managed connectors, your self-deployed connector service version must match your Elasticsearch version. For example, if you’re running Elasticsearch 8.10.1, your connector service should be version 8.10.1.x. Elastic does not support deployments running mismatched versions (except during upgrades). diff --git a/docs/reference/text-analysis/analysis-decimal-digit-tokenfilter.md b/docs/reference/text-analysis/analysis-decimal-digit-tokenfilter.md index 9cdead1bf23dc..1d417361bfa29 100644 --- a/docs/reference/text-analysis/analysis-decimal-digit-tokenfilter.md +++ b/docs/reference/text-analysis/analysis-decimal-digit-tokenfilter.md @@ -9,7 +9,7 @@ mapped_pages: Converts all digits in the Unicode `Decimal_Number` General Category to `0-9`. For example, the filter changes the Bengali numeral `৩` to `3`. -This filter uses Lucene’s [DecimalDigitFilter](https://lucene.apache.org/core/10_0_0/analysis/common/org/apache/lucene/analysis/core/DecimalDigitFilter.md). +This filter uses Lucene’s [DecimalDigitFilter](https://lucene.apache.org/core/10_0_0/analysis/common/org/apache/lucene/analysis/core/DecimalDigitFilter.html). 
## Example [analysis-decimal-digit-tokenfilter-analyze-ex] diff --git a/docs/release-notes/breaking-changes.md b/docs/release-notes/breaking-changes.md index 1fe3c4c0d8fc6..a3addb91199b7 100644 --- a/docs/release-notes/breaking-changes.md +++ b/docs/release-notes/breaking-changes.md @@ -12,6 +12,13 @@ If you are migrating from a version prior to version 9.0, you must first upgrade % ## Next version [elasticsearch-nextversion-breaking-changes] +```{applies_to} +stack: coming 9.0.4 +``` +## 9.0.4 [elasticsearch-9.0.4-breaking-changes] + +No breaking changes in this version. + ## 9.0.3 [elasticsearch-9.0.3-breaking-changes] No breaking changes in this version. @@ -21,6 +28,8 @@ No breaking changes in this version. Snapshot/Restore: * Make S3 custom query parameter optional [#128043](https://github.com/elastic/elasticsearch/pull/128043) + + ## 9.0.1 [elasticsearch-9.0.1-breaking-changes] No breaking changes in this version. diff --git a/docs/release-notes/changelog-bundles/9.0.4.yml b/docs/release-notes/changelog-bundles/9.0.4.yml new file mode 100644 index 0000000000000..3ed97d59be0e3 --- /dev/null +++ b/docs/release-notes/changelog-bundles/9.0.4.yml @@ -0,0 +1,97 @@ +version: 9.0.4 +released: false +generated: 2025-07-14T17:07:39.875346517Z +changelogs: + - pr: 129223 + summary: Fix text similarity reranker does not propagate min score correctly + area: Search + type: bug + issues: [] + - pr: 129325 + summary: Check for model deployment in inference endpoints before stopping + area: Machine Learning + type: bug + issues: + - 128549 + - pr: 129370 + summary: Avoid dropping aggregate groupings in local plans + area: ES|QL + type: bug + issues: + - 129811 + - 128054 + - pr: 129600 + summary: Make flattened synthetic source concatenate object keys on scalar/object mismatch + area: Mapping + type: bug + issues: + - 122936 + - pr: 129725 + summary: Throw a 400 when sorting for all types of range fields + area: Search + type: bug + issues: [] + - pr: 129904 + summary: 
Reverse disordered-version warning message + area: Infra/Core + type: bug + issues: [] + - pr: 130083 + summary: Fix timeout bug in DBQ deletion of unused and orphan ML data + area: Machine Learning + type: bug + issues: [] + - pr: 130303 + summary: Drain responses on completion for `TransportNodesAction` + area: Distributed + type: bug + issues: [] + - pr: 130448 + summary: Fix wildcard drop after lookup join + area: ES|QL + type: bug + issues: + - 129561 + - pr: 130452 + summary: "Aggs: Add cancellation checks to `FilterByFilter` aggregator" + area: Aggregations + type: bug + issues: [] + - pr: 130521 + summary: Trim to size lists created in source fetchers + area: Search + type: bug + issues: [] + - pr: 130576 + summary: Avoid O(N^2) in VALUES with ordinals grouping + area: ES|QL + type: bug + issues: [] + - pr: 130705 + summary: Fix `BytesRef2BlockHash` + area: ES|QL + type: bug + issues: [] + - pr: 130776 + summary: Fix msearch request parsing when index expression is null + area: Search + type: bug + issues: + - 129631 + - pr: 130924 + summary: Check field data type before casting when applying geo distance sort + area: Search + type: bug + issues: + - 129500 + - pr: 131032 + summary: "Fix: `GET _synonyms` returns synonyms with empty rules" + area: Relevance + type: bug + issues: [] + - pr: 131081 + summary: Fix knn search error when dimensions are not set + area: Vector Search + type: bug + issues: + - 129550 diff --git a/docs/release-notes/deprecations.md b/docs/release-notes/deprecations.md index be1029c187cca..6082a91ffb964 100644 --- a/docs/release-notes/deprecations.md +++ b/docs/release-notes/deprecations.md @@ -16,11 +16,20 @@ To give you insight into what deprecated features you’re using, {{es}}: % ## Next version [elasticsearch-nextversion-deprecations] +```{applies_to} +stack: coming 9.0.4 +``` +## 9.0.4 [elasticsearch-9.0.4-deprecations] + +No deprecations in this version. 
+ ## 9.0.3 [elasticsearch-9.0.3-deprecations] Engine: * Deprecate `indices.merge.scheduler.use_thread_pool` setting [#129464](https://github.com/elastic/elasticsearch/pull/129464) + + ## 9.0.2 [elasticsearch-9.0.2-deprecations] No deprecations in this version. diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 477f840abc088..3cd8e2ed5737b 100644 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -20,6 +20,49 @@ To check for security updates, go to [Security announcements for the Elastic sta % ### Fixes [elasticsearch-next-fixes] % * +## 9.0.4 [elasticsearch-9.0.4-release-notes] +```{applies_to} +stack: coming 9.0.4 +``` + +### Fixes [elasticsearch-9.0.4-fixes] + +Aggregations: +* Aggs: Add cancellation checks to `FilterByFilter` aggregator [#130452](https://github.com/elastic/elasticsearch/pull/130452) + +Distributed: +* Drain responses on completion for `TransportNodesAction` [#130303](https://github.com/elastic/elasticsearch/pull/130303) + +ES|QL: +* Avoid O(N^2) in VALUES with ordinals grouping [#130576](https://github.com/elastic/elasticsearch/pull/130576) +* Avoid dropping aggregate groupings in local plans [#129370](https://github.com/elastic/elasticsearch/pull/129370) (issues: [#129811](https://github.com/elastic/elasticsearch/issues/129811), [#128054](https://github.com/elastic/elasticsearch/issues/128054)) +* Fix `BytesRef2BlockHash` [#130705](https://github.com/elastic/elasticsearch/pull/130705) +* Fix wildcard drop after lookup join [#130448](https://github.com/elastic/elasticsearch/pull/130448) (issue: [#129561](https://github.com/elastic/elasticsearch/issues/129561)) + +Infra/Core: +* Reverse disordered-version warning message [#129904](https://github.com/elastic/elasticsearch/pull/129904) + +Machine Learning: +* Check for model deployment in inference endpoints before stopping [#129325](https://github.com/elastic/elasticsearch/pull/129325) (issue: 
[#128549](https://github.com/elastic/elasticsearch/issues/128549)) +* Fix timeout bug in DBQ deletion of unused and orphan ML data [#130083](https://github.com/elastic/elasticsearch/pull/130083) + +Mapping: +* Make flattened synthetic source concatenate object keys on scalar/object mismatch [#129600](https://github.com/elastic/elasticsearch/pull/129600) (issue: [#122936](https://github.com/elastic/elasticsearch/issues/122936)) + +Relevance: +* Fix: `GET _synonyms` returns synonyms with empty rules [#131032](https://github.com/elastic/elasticsearch/pull/131032) + +Search: +* Check field data type before casting when applying geo distance sort [#130924](https://github.com/elastic/elasticsearch/pull/130924) (issue: [#129500](https://github.com/elastic/elasticsearch/issues/129500)) +* Fix msearch request parsing when index expression is null [#130776](https://github.com/elastic/elasticsearch/pull/130776) (issue: [#129631](https://github.com/elastic/elasticsearch/issues/129631)) +* Fix text similarity reranker does not propagate min score correctly [#129223](https://github.com/elastic/elasticsearch/pull/129223) +* Throw a 400 when sorting for all types of range fields [#129725](https://github.com/elastic/elasticsearch/pull/129725) +* Trim to size lists created in source fetchers [#130521](https://github.com/elastic/elasticsearch/pull/130521) + +Vector Search: +* Fix knn search error when dimensions are not set [#131081](https://github.com/elastic/elasticsearch/pull/131081) (issue: [#129550](https://github.com/elastic/elasticsearch/issues/129550)) + + ## 9.0.3 [elasticsearch-9.0.3-release-notes] ### Features and enhancements [elasticsearch-9.0.3-features-enhancements] @@ -89,6 +132,7 @@ Searchable Snapshots: Security: * Fix error message when changing the password for a user in the file realm [#127621](https://github.com/elastic/elasticsearch/pull/127621) + ## 9.0.2 [elasticsearch-9.0.2-release-notes] ### Features and enhancements 
[elasticsearch-9.0.2-features-enhancements] diff --git a/docs/release-notes/known-issues.md b/docs/release-notes/known-issues.md index a20cff68c225c..ed3733b0f2a82 100644 --- a/docs/release-notes/known-issues.md +++ b/docs/release-notes/known-issues.md @@ -8,10 +8,19 @@ mapped_pages: Known issues are significant defects or limitations that may impact your implementation. These issues are actively being worked on and will be addressed in a future release. Review the Elasticsearch known issues to help you make informed decisions, such as upgrading to a new version. ## 9.0.3 [elasticsearch-9.0.3-known-issues] -A bug in the merge scheduler in Elasticsearch 9.0.3 may prevent shards from closing when there isn’t enough disk space to complete a merge. As a result, operations such as closing or relocating an index may hang until sufficient disk space becomes available. +* A bug in the merge scheduler in Elasticsearch 9.0.3 may prevent shards from closing when there isn’t enough disk space to complete a merge. As a result, operations such as closing or relocating an index may hang until sufficient disk space becomes available. To mitigate this issue, the disk space checker is disabled by default in 9.0.3 by setting `indices.merge.disk.check_interval` to `0` seconds. Manually enabling this setting is not recommended. -This issue is planned to be fixed in future patch release [#129613](https://github.com/elastic/elasticsearch/pull/129613) + This issue is planned to be fixed in a future patch release [#129613](https://github.com/elastic/elasticsearch/pull/129613) + +* A bug in the ES|QL STATS command may yield incorrect results. The bug only happens in very specific cases that follow this pattern: `STATS ... BY keyword1, keyword2`, i.e. the command must have exactly two grouping fields, both keywords, where the first field has high cardinality (more than 65k distinct values).
+ + The bug is described in detail in [this issue](https://github.com/elastic/elasticsearch/issues/130644). + The problem was introduced in 8.16.0 and [fixed](https://github.com/elastic/elasticsearch/pull/130705) in 8.17.9, 8.18.7, 9.0.4. + + Possible workarounds include: + * switching the order of the grouping keys (e.g. `STATS ... BY keyword2, keyword1`, if the `keyword2` has a lower cardinality) + * reducing the grouping key cardinality, by filtering out values before STATS ## 9.0.0 [elasticsearch-9.0.0-known-issues] * Elasticsearch on Windows might fail to start, or might forbid some file-related operations, when referencing paths with a case different from the one stored by the filesystem. Windows treats paths as case-insensitive, but the filesystem stores them with case. Entitlements, the new security system used by Elasticsearch, treat all paths as case-sensitive, and can therefore prevent access to a path that should be accessible. @@ -40,3 +49,12 @@ This issue will be fixed in a future patch release (see [PR #126990](https://git DELETE _index_template/.watches POST /_watcher/_start ``` + +* A bug in the ES|QL STATS command may yield incorrect results. The bug only happens in very specific cases that follow this pattern: `STATS ... BY keyword1, keyword2`, i.e. the command must have exactly two grouping fields, both keywords, where the first field has high cardinality (more than 65k distinct values). + + The bug is described in detail in [this issue](https://github.com/elastic/elasticsearch/issues/130644). + The problem was introduced in 8.16.0 and [fixed](https://github.com/elastic/elasticsearch/pull/130705) in 8.17.9, 8.18.7, 9.0.4. + + Possible workarounds include: + * switching the order of the grouping keys (e.g. `STATS ...
BY keyword2, keyword1`, if the `keyword2` has a lower cardinality) + * reducing the grouping key cardinality, by filtering out values before STATS diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 7e147eff76dbd..bb4ae5da279fb 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -71,9 +71,9 @@ - - - + + + @@ -4954,174 +4954,174 @@ - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + diff --git a/libs/build.gradle b/libs/build.gradle index b39ddaab98c2d..79806b0dc45b3 100644 --- a/libs/build.gradle +++ b/libs/build.gradle @@ -50,7 +50,7 @@ configure(childProjects.values()) { // Omit oddball libraries that aren't in server. 
def nonServerLibs = ['plugin-scanner'] if (false == nonServerLibs.contains(project.name)) { - project.getTasks().withType(Test.class).matching(test -> ['test'].contains(test.name)).configureEach(test -> { + project.getTasks().withType(Test.class).matching(test -> ['test', 'internalClusterTest'].contains(test.name)).configureEach(test -> { test.systemProperty('es.entitlement.enableForTests', 'true') }) } diff --git a/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsTestRule.java b/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsTestRule.java index 9cad8b710ae11..7aa31054af97f 100644 --- a/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsTestRule.java +++ b/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsTestRule.java @@ -83,7 +83,6 @@ protected void before() throws Throwable { cluster = ElasticsearchCluster.local() .module("entitled", spec -> buildEntitlements(spec, "org.elasticsearch.entitlement.qa.entitled", ENTITLED_POLICY)) .module(ENTITLEMENT_TEST_PLUGIN_NAME, spec -> setupEntitlements(spec, modular, policyBuilder)) - .systemProperty("es.entitlements.enabled", "true") .systemProperty("es.entitlements.verify_bytecode", "true") .systemProperty("es.entitlements.testdir", () -> testDir.getRoot().getAbsolutePath()) .systemProperties(spec -> tempDirSystemPropertyProvider.get(testDir.getRoot().toPath())) diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java index 928b953c295fe..3d364f4b53cec 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java @@ -161,7 +161,7 @@ private static PolicyManager 
createPolicyManager( PathLookup pathLookup, Policy serverPolicyPatch, Function, PolicyManager.PolicyScope> scopeResolver, - Map> pluginSourcePaths + Map> pluginSourcePathsResolver ) { FilesEntitlementsValidation.validate(pluginPolicies, pathLookup); @@ -170,7 +170,7 @@ private static PolicyManager createPolicyManager( HardcodedEntitlements.agentEntitlements(), pluginPolicies, scopeResolver, - pluginSourcePaths, + pluginSourcePathsResolver::get, pathLookup ); } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookup.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookup.java index 361d77ff83477..3d2daf5e407c6 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookup.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookup.java @@ -9,6 +9,8 @@ package org.elasticsearch.entitlement.runtime.policy; +import org.elasticsearch.core.PathUtils; + import java.nio.file.Path; import java.util.stream.Stream; @@ -16,6 +18,8 @@ * Resolves paths for known directories checked by entitlements. */ public interface PathLookup { + Class DEFAULT_FILESYSTEM_CLASS = PathUtils.getDefaultFileSystem().getClass(); + enum BaseDir { USER_HOME, CONFIG, @@ -37,4 +41,6 @@ enum BaseDir { * paths of the given {@code baseDir}. 
*/ Stream resolveSettingPaths(BaseDir baseDir, String settingName); + + boolean isPathOnDefaultFilesystem(Path path); } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookupImpl.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookupImpl.java index e3474250d43f0..df259254025bb 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookupImpl.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PathLookupImpl.java @@ -75,4 +75,9 @@ public Stream resolveSettingPaths(BaseDir baseDir, String settingName) { .toList(); return getBaseDirPaths(baseDir).flatMap(path -> relativePaths.stream().map(path::resolve)); } + + @Override + public boolean isPathOnDefaultFilesystem(Path path) { + return path.getFileSystem().getClass() == DEFAULT_FILESYSTEM_CLASS; + } } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyCheckerImpl.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyCheckerImpl.java index 2c3374f594847..acfdbb0caded7 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyCheckerImpl.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyCheckerImpl.java @@ -9,7 +9,6 @@ package org.elasticsearch.entitlement.runtime.policy; -import org.elasticsearch.core.PathUtils; import org.elasticsearch.core.Strings; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; @@ -58,7 +57,7 @@ */ @SuppressForbidden(reason = "Explicitly checking APIs that are forbidden") public class PolicyCheckerImpl implements PolicyChecker { - static final Class DEFAULT_FILESYSTEM_CLASS = PathUtils.getDefaultFileSystem().getClass(); + protected final Set suppressFailureLogPackages; /** * Frames originating from this module are 
ignored in the permission logic. @@ -81,15 +80,14 @@ public PolicyCheckerImpl( this.pathLookup = pathLookup; } - private static boolean isPathOnDefaultFilesystem(Path path) { - var pathFileSystemClass = path.getFileSystem().getClass(); - if (path.getFileSystem().getClass() != DEFAULT_FILESYSTEM_CLASS) { + private boolean isPathOnDefaultFilesystem(Path path) { + if (pathLookup.isPathOnDefaultFilesystem(path) == false) { PolicyManager.generalLogger.trace( () -> Strings.format( "File entitlement trivially allowed: path [%s] is for a different FileSystem class [%s], default is [%s]", path.toString(), - pathFileSystemClass.getName(), - DEFAULT_FILESYSTEM_CLASS.getName() + path.getFileSystem().getClass().getName(), + PathLookup.DEFAULT_FILESYSTEM_CLASS.getName() ) ); return false; @@ -139,7 +137,7 @@ private void neverEntitled(Class callerClass, Supplier operationDescr requestingClass, operationDescription.get() ), - callerClass, + requestingClass, entitlements ); } @@ -217,7 +215,7 @@ public void checkFileRead(Class callerClass, Path path) { @Override public void checkFileRead(Class callerClass, Path path, boolean followLinks) throws NoSuchFileException { - if (PolicyCheckerImpl.isPathOnDefaultFilesystem(path) == false) { + if (isPathOnDefaultFilesystem(path) == false) { return; } var requestingClass = requestingClass(callerClass); @@ -251,7 +249,7 @@ public void checkFileRead(Class callerClass, Path path, boolean followLinks) requestingClass, realPath == null ? 
path : Strings.format("%s -> %s", path, realPath) ), - callerClass, + requestingClass, entitlements ); } @@ -265,7 +263,7 @@ public void checkFileWrite(Class callerClass, File file) { @Override public void checkFileWrite(Class callerClass, Path path) { - if (PolicyCheckerImpl.isPathOnDefaultFilesystem(path) == false) { + if (isPathOnDefaultFilesystem(path) == false) { return; } var requestingClass = requestingClass(callerClass); @@ -283,7 +281,7 @@ public void checkFileWrite(Class callerClass, Path path) { requestingClass, path ), - callerClass, + requestingClass, entitlements ); } @@ -360,8 +358,8 @@ public void checkAllNetworkAccess(Class callerClass) { } var classEntitlements = policyManager.getEntitlements(requestingClass); - checkFlagEntitlement(classEntitlements, InboundNetworkEntitlement.class, requestingClass, callerClass); - checkFlagEntitlement(classEntitlements, OutboundNetworkEntitlement.class, requestingClass, callerClass); + checkFlagEntitlement(classEntitlements, InboundNetworkEntitlement.class, requestingClass); + checkFlagEntitlement(classEntitlements, OutboundNetworkEntitlement.class, requestingClass); } @Override @@ -378,16 +376,15 @@ public void checkWriteProperty(Class callerClass, String property) { ModuleEntitlements entitlements = policyManager.getEntitlements(requestingClass); if (entitlements.getEntitlements(WriteSystemPropertiesEntitlement.class).anyMatch(e -> e.properties().contains(property))) { - entitlements.logger() - .debug( - () -> Strings.format( - "Entitled: component [%s], module [%s], class [%s], entitlement [write_system_properties], property [%s]", - entitlements.componentName(), - entitlements.moduleName(), - requestingClass, - property - ) - ); + PolicyManager.generalLogger.debug( + () -> Strings.format( + "Entitled: component [%s], module [%s], class [%s], entitlement [write_system_properties], property [%s]", + entitlements.componentName(), + entitlements.moduleName(), + requestingClass, + property + ) + ); return; } 
notEntitled( @@ -398,7 +395,7 @@ public void checkWriteProperty(Class callerClass, String property) { requestingClass, property ), - callerClass, + requestingClass, entitlements ); } @@ -439,8 +436,7 @@ Optional findRequestingFrame(Stream entitlementClass, - Class requestingClass, - Class callerClass + Class requestingClass ) { if (classEntitlements.hasEntitlement(entitlementClass) == false) { notEntitled( @@ -451,27 +447,26 @@ private void checkFlagEntitlement( requestingClass, PolicyParser.buildEntitlementNameFromClass(entitlementClass) ), - callerClass, + requestingClass, classEntitlements ); } - classEntitlements.logger() - .debug( - () -> Strings.format( - "Entitled: component [%s], module [%s], class [%s], entitlement [%s]", - classEntitlements.componentName(), - classEntitlements.moduleName(), - requestingClass, - PolicyParser.buildEntitlementNameFromClass(entitlementClass) - ) - ); + PolicyManager.generalLogger.debug( + () -> Strings.format( + "Entitled: component [%s], module [%s], class [%s], entitlement [%s]", + classEntitlements.componentName(), + classEntitlements.moduleName(), + requestingClass, + PolicyParser.buildEntitlementNameFromClass(entitlementClass) + ) + ); } - private void notEntitled(String message, Class callerClass, ModuleEntitlements entitlements) { + private void notEntitled(String message, Class requestingClass, ModuleEntitlements entitlements) { var exception = new NotEntitledException(message); // Don't emit a log for suppressed packages, e.g. 
packages containing self tests - if (suppressFailureLogPackages.contains(callerClass.getPackage()) == false) { - entitlements.logger().warn("Not entitled: {}", message, exception); + if (suppressFailureLogPackages.contains(requestingClass.getPackage()) == false) { + entitlements.logger(requestingClass).warn("Not entitled: {}", message, exception); } throw exception; } @@ -482,7 +477,7 @@ public void checkEntitlementPresent(Class callerClass, Class, List> entitlementsByType, - FileAccessTree fileAccess, - Logger logger + FileAccessTree fileAccess ) { public ModuleEntitlements { @@ -141,6 +142,12 @@ public Stream getEntitlements(Class entitlementCla } return entitlements.stream().map(entitlementClass::cast); } + + Logger logger(Class requestingClass) { + var packageName = requestingClass.getPackageName(); + var loggerSuffix = "." + componentName + "." + ((moduleName == null) ? ALL_UNNAMED : moduleName) + "." + packageName; + return LogManager.getLogger(PolicyManager.class.getName() + loggerSuffix); + } } private FileAccessTree getDefaultFileAccess(Collection componentPaths) { @@ -149,13 +156,7 @@ private FileAccessTree getDefaultFileAccess(Collection componentPaths) { // pkg private for testing ModuleEntitlements defaultEntitlements(String componentName, Collection componentPaths, String moduleName) { - return new ModuleEntitlements( - componentName, - moduleName, - Map.of(), - getDefaultFileAccess(componentPaths), - getLogger(componentName, moduleName) - ); + return new ModuleEntitlements(componentName, moduleName, Map.of(), getDefaultFileAccess(componentPaths)); } // pkg private for testing @@ -175,8 +176,7 @@ ModuleEntitlements policyEntitlements( componentName, moduleName, entitlements.stream().collect(groupingBy(Entitlement::getClass)), - FileAccessTree.of(componentName, moduleName, filesEntitlement, pathLookup, componentPaths, exclusivePaths), - getLogger(componentName, moduleName) + FileAccessTree.of(componentName, moduleName, filesEntitlement, pathLookup, 
componentPaths, exclusivePaths) ); } @@ -217,7 +217,7 @@ private static Set findSystemLayerModules() { .filter(m -> SYSTEM_LAYER_MODULES.contains(m) == false) .collect(Collectors.toUnmodifiableSet()); - private final Map> pluginSourcePaths; + private final Function> pluginSourcePathsResolver; /** * Paths that are only allowed for a single module. Used to generate @@ -231,7 +231,7 @@ public PolicyManager( List apmAgentEntitlements, Map pluginPolicies, Function, PolicyScope> scopeResolver, - Map> pluginSourcePaths, + Function> pluginSourcePathsResolver, PathLookup pathLookup ) { this.serverEntitlements = buildScopeEntitlementsMap(requireNonNull(serverPolicy)); @@ -240,7 +240,7 @@ public PolicyManager( .stream() .collect(toUnmodifiableMap(Map.Entry::getKey, e -> buildScopeEntitlementsMap(e.getValue()))); this.scopeResolver = scopeResolver; - this.pluginSourcePaths = pluginSourcePaths; + this.pluginSourcePathsResolver = pluginSourcePathsResolver; this.pathLookup = requireNonNull(pathLookup); List exclusiveFileEntitlements = new ArrayList<>(); @@ -286,21 +286,6 @@ private static void validateEntitlementsPerModule( } } - private static Logger getLogger(String componentName, String moduleName) { - var loggerSuffix = "." + componentName + "." + ((moduleName == null) ? ALL_UNNAMED : moduleName); - return MODULE_LOGGERS.computeIfAbsent(PolicyManager.class.getName() + loggerSuffix, LogManager::getLogger); - } - - /** - * We want to use the same {@link Logger} object for a given name, because we want {@link ModuleEntitlements} - * {@code equals} and {@code hashCode} to work. - *

- * This would not be required if LogManager - * memoized the loggers, - * but here we are. - */ - private static final ConcurrentHashMap MODULE_LOGGERS = new ConcurrentHashMap<>(); - protected ModuleEntitlements getEntitlements(Class requestingClass) { return moduleEntitlementsMap.computeIfAbsent(requestingClass.getModule(), m -> computeEntitlements(requestingClass)); } @@ -334,7 +319,10 @@ protected final ModuleEntitlements computeEntitlements(Class requestingClass) default -> { assert policyScope.kind() == PLUGIN; var pluginEntitlements = pluginsEntitlements.get(componentName); - Collection componentPaths = pluginSourcePaths.getOrDefault(componentName, List.of()); + Collection componentPaths = Objects.requireNonNullElse( + pluginSourcePathsResolver.apply(componentName), + Collections.emptyList() + ); if (pluginEntitlements == null) { return defaultEntitlements(componentName, componentPaths, moduleName); } else { diff --git a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java index 21197fc6bd942..e1f20a0eae990 100644 --- a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java +++ b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java @@ -33,6 +33,7 @@ import java.net.URLClassLoader; import java.nio.file.Path; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -95,7 +96,7 @@ public void testGetEntitlements() { List.of(), Map.of("plugin1", new Policy("plugin1", List.of(new Scope("plugin.module1", List.of(new ExitVMEntitlement()))))), c -> policyScope.get(), - Map.of("plugin1", plugin1SourcePaths), + Map.of("plugin1", plugin1SourcePaths)::get, TEST_PATH_LOOKUP ); Collection thisSourcePaths = policyManager.getComponentPathsFromClass(getClass()); @@ 
-170,7 +171,7 @@ public void testAgentsEntitlements() throws IOException, ClassNotFoundException c -> c.getPackageName().startsWith(TEST_AGENTS_PACKAGE_NAME) ? PolicyScope.apmAgent("test.agent.module") : PolicyScope.plugin("test", "test.plugin.module"), - Map.of(), + name -> Collections.emptyList(), TEST_PATH_LOOKUP ); ModuleEntitlements agentsEntitlements = policyManager.getEntitlements(TestAgent.class); @@ -197,7 +198,7 @@ public void testDuplicateEntitlements() { List.of(), Map.of(), c -> PolicyScope.plugin("test", moduleName(c)), - Map.of(), + name -> Collections.emptyList(), TEST_PATH_LOOKUP ) ); @@ -213,7 +214,7 @@ public void testDuplicateEntitlements() { List.of(new CreateClassLoaderEntitlement(), new CreateClassLoaderEntitlement()), Map.of(), c -> PolicyScope.plugin("test", moduleName(c)), - Map.of(), + name -> Collections.emptyList(), TEST_PATH_LOOKUP ) ); @@ -249,7 +250,7 @@ public void testDuplicateEntitlements() { ) ), c -> PolicyScope.plugin("plugin1", moduleName(c)), - Map.of("plugin1", List.of(Path.of("modules", "plugin1"))), + Map.of("plugin1", List.of(Path.of("modules", "plugin1")))::get, TEST_PATH_LOOKUP ) ); @@ -299,7 +300,7 @@ public void testFilesEntitlementsWithExclusive() { ) ), c -> PolicyScope.plugin("", moduleName(c)), - Map.of("plugin1", List.of(Path.of("modules", "plugin1")), "plugin2", List.of(Path.of("modules", "plugin2"))), + Map.of("plugin1", List.of(Path.of("modules", "plugin1")), "plugin2", List.of(Path.of("modules", "plugin2")))::get, TEST_PATH_LOOKUP ) ); @@ -350,7 +351,7 @@ public void testFilesEntitlementsWithExclusive() { ) ), c -> PolicyScope.plugin("", moduleName(c)), - Map.of(), + name -> Collections.emptyList(), TEST_PATH_LOOKUP ) ); diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java index 803bdd523a6b6..95415cee2b090 100644 --- 
a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java @@ -8,10 +8,15 @@ */ package org.elasticsearch.simdvec; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.VectorUtil; import java.io.IOException; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + /** Scorer for quantized vectors stored as an {@link IndexInput}. *

* Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but @@ -20,11 +25,19 @@ * */ public class ES91Int4VectorsScorer { + public static final int BULK_SIZE = 16; + protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + /** The wrapper {@link IndexInput}. */ protected final IndexInput in; protected final int dimensions; protected byte[] scratch; + protected final float[] lowerIntervals = new float[BULK_SIZE]; + protected final float[] upperIntervals = new float[BULK_SIZE]; + protected final int[] targetComponentSums = new int[BULK_SIZE]; + protected final float[] additionalCorrections = new float[BULK_SIZE]; + /** Sole constructor, called by sub-classes. */ public ES91Int4VectorsScorer(IndexInput in, int dimensions) { this.in = in; @@ -32,6 +45,10 @@ public ES91Int4VectorsScorer(IndexInput in, int dimensions) { scratch = new byte[dimensions]; } + /** + * compute the quantize distance between the provided quantized query and the quantized vector + * that is read from the wrapped {@link IndexInput}. + */ public long int4DotProduct(byte[] b) throws IOException { in.readBytes(scratch, 0, dimensions); int total = 0; @@ -40,4 +57,129 @@ public long int4DotProduct(byte[] b) throws IOException { } return total; } + + /** + * compute the quantize distance between the provided quantized query and the quantized vectors + * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is + * determined by {code count} and the results are stored in the provided {@code scores} array. + */ + public void int4DotProductBulk(byte[] b, int count, float[] scores) throws IOException { + for (int i = 0; i < count; i++) { + scores[i] = int4DotProduct(b); + } + } + + /** + * Computes the score by applying the necessary corrections to the provided quantized distance. 
+ */ + public float score( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp + ) throws IOException { + float score = int4DotProduct(q); + in.readFloats(lowerIntervals, 0, 3); + int addition = Short.toUnsignedInt(in.readShort()); + return applyCorrections( + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + lowerIntervals[0], + lowerIntervals[1], + addition, + lowerIntervals[2], + score + ); + } + + /** + * compute the distance between the provided quantized query and the quantized vectors that are + * read from the wrapped {@link IndexInput}. + * + *

The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the + * input is as follows: First the quantized vectors are read from the input,then all the lower + * intervals as floats, then all the upper intervals as floats, then all the target component sums + * as shorts, and finally all the additional corrections as floats. + * + *

The results are stored in the provided scores array. + */ + public void scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + int4DotProductBulk(q, BULK_SIZE, scores); + in.readFloats(lowerIntervals, 0, BULK_SIZE); + in.readFloats(upperIntervals, 0, BULK_SIZE); + for (int i = 0; i < BULK_SIZE; i++) { + targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); + } + in.readFloats(additionalCorrections, 0, BULK_SIZE); + for (int i = 0; i < BULK_SIZE; i++) { + scores[i] = applyCorrections( + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + lowerIntervals[i], + upperIntervals[i], + targetComponentSums[i], + additionalCorrections[i], + scores[i] + ); + } + } + + /** + * Computes the score by applying the necessary corrections to the provided quantized distance. + */ + public float applyCorrections( + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float lowerInterval, + float upperInterval, + int targetComponentSum, + float additionalCorrection, + float qcDist + ) { + float ax = lowerInterval; + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = upperInterval - ax; + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
+ if (similarityFunction == EUCLIDEAN) { + score = queryAdditionalCorrection + additionalCorrection - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryAdditionalCorrection + additionalCorrection - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java index 9a314fc4c18ec..7aaacae89be74 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java @@ -9,22 +9,30 @@ package org.elasticsearch.simdvec.internal.vectorization; import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorSpecies; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; import static java.nio.ByteOrder.LITTLE_ENDIAN; import static jdk.incubator.vector.VectorOperators.ADD; import static jdk.incubator.vector.VectorOperators.B2I; import static 
jdk.incubator.vector.VectorOperators.B2S; import static jdk.incubator.vector.VectorOperators.S2I; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; /** Panamized scorer for quantized vectors stored as an {@link IndexInput}. *

@@ -43,6 +51,15 @@ public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsSco private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512; + private static final VectorSpecies FLOAT_SPECIES; + private static final VectorSpecies SHORT_SPECIES; + + static { + // default to platform supported bitsize + FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + SHORT_SPECIES = VectorSpecies.of(short.class, VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)); + } + private final MemorySegment memorySegment; public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { @@ -99,12 +116,11 @@ private int int4DotProductBody128(byte[] q, int limit) throws IOException { } private long dotProduct(byte[] q) throws IOException { - int i = 0; - int res = 0; - // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + int i = 0; + int res = 0; // compute vectorized dot product consistent with VPDPBUSD instruction if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) { i += BYTE_SPECIES_128.loopBound(dimensions); @@ -113,16 +129,15 @@ private long dotProduct(byte[] q) throws IOException { i += BYTE_SPECIES_64.loopBound(dimensions); res += dotProductBody256(q, i); } else { - // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); - res += dotProductBody128(q, i); + throw new IllegalArgumentException("Unreacheable statement"); } + // scalar tail + for (; i < q.length; i++) { + res += in.readByte() * q[i]; + } + return res; } - // scalar tail - for (; i < q.length; i++) { - res += in.readByte() * 
q[i]; - } - return res; + return super.int4DotProduct(q); } /** vectorized dot product body (512 bit vectors) */ @@ -166,26 +181,222 @@ private int dotProductBody256(byte[] q, int limit) throws IOException { return acc.reduceLanes(ADD); } - /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128(byte[] q, int limit) throws IOException { - IntVector acc = IntVector.zero(INT_SPECIES_128); - long offset = in.getFilePointer(); - // 4 bytes at a time (re-loading half the vector each time!) - for (int i = 0; i < limit; i += BYTE_SPECIES_64.length() >> 1) { - // load 8 bytes - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + @Override + public void int4DotProductBulk(byte[] q, int count, float[] scores) throws IOException { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) { + dotProductBulk(q, count, scores); + return; + } + if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + int4DotProductBody128Bulk(q, count, scores); + return; + } + super.int4DotProductBulk(q, count, scores); + } - // process first "half" only: 16-bit multiply - Vector va16 = va8.convert(B2S, 0); - Vector vb16 = vb8.convert(B2S, 0); - Vector prod16 = va16.mul(vb16); + private void int4DotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_128.loopBound(dimensions); + for (int iter = 0; iter < count; iter++) { + int sum = 0; + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128); + ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128); - // 32-bit add - acc = acc.add(prod16.convertShape(S2I, INT_SPECIES_128, 0)); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) { + ByteVector 
va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, LITTLE_ENDIAN); + + ByteVector prod8 = va8.mul(vb8); + ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc0 = acc0.add(prod16.and((short) 255)); + + va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + 8); + vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + 8, LITTLE_ENDIAN); + + prod8 = va8.mul(vb8); + prod16 = prod8.convertShape(B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 255)); + } + + IntVector intAcc0 = acc0.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + in.seek(offset + limit); + in.readBytes(scratch, limit, dimensions - limit); + for (int j = limit; j < dimensions; j++) { + sum += scratch[j] * q[j]; + } + scores[iter] = sum; } - in.seek(offset + limit); - // reduce - return acc.reduceLanes(ADD); + } + + private void dotProductBulk(byte[] q, int count, float[] scores) throws IOException { + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) { + dotProductBody512Bulk(q, count, scores); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) { + dotProductBody256Bulk(q, count, scores); + } else { + throw new 
IllegalArgumentException("Unreacheable statement"); + } + return; + } + super.int4DotProductBulk(q, count, scores); + } + + /** vectorized dot product body (512 bit vectors) */ + private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_128.loopBound(dimensions); + for (int iter = 0; iter < count; iter++) { + IntVector acc = IntVector.zero(INT_SPECIES_512); + long offset = in.getFilePointer(); + int i = 0; + for (; i < limit; i += BYTE_SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); + + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm + Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); + acc = acc.add(prod32); + } + + in.seek(offset + limit); // advance the input stream + // reduce + long res = acc.reduceLanes(ADD); + for (; i < q.length; i++) { + res += in.readByte() * q[i]; + } + scores[iter] = res; + } + } + + /** vectorized dot product body (256 bit vectors) */ + private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_128.loopBound(dimensions); + for (int iter = 0; iter < count; iter++) { + IntVector acc = IntVector.zero(INT_SPECIES_256); + long offset = in.getFilePointer(); + int i = 0; + for (; i < limit; i += BYTE_SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // 32-bit multiply and add into accumulator + Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); + Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); + acc = 
acc.add(va32.mul(vb32)); + } + in.seek(offset + limit); + // reduce + long res = acc.reduceLanes(ADD); + for (; i < q.length; i++) { + res += in.readByte() * q[i]; + } + scores[iter] = res; + } + } + + @Override + public void scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + int4DotProductBulk(q, BULK_SIZE, scores); + applyCorrectionsBulk( + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); + } + + private void applyCorrectionsBulk( + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + for (; i < limit; i += FLOAT_SPECIES.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES, 
scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 14L * BULK_SIZE); } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java index c19211585a765..34ae512b0765e 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java @@ -9,13 +9,18 @@ package org.elasticsearch.simdvec.internal.vectorization; +import org.apache.lucene.index.VectorSimilarityFunction; import 
org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; + +import static org.hamcrest.Matchers.lessThan; public class ES91Int4VectorScorerTests extends BaseVectorizationTests { @@ -57,4 +62,153 @@ public void testInt4DotProduct() throws Exception { } } } + + public void testInt4Score() throws Exception { + // only even dimensions are supported + final int dimensions = random().nextInt(1, 1000) * 2; + final int numVectors = random().nextInt(1, 100); + final byte[] vector = new byte[dimensions]; + final byte[] corrections = new byte[14]; + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + for (int i = 0; i < numVectors; i++) { + for (int j = 0; j < dimensions; j++) { + vector[j] = (byte) random().nextInt(16); // 4-bit quantization + } + out.writeBytes(vector, 0, dimensions); + random().nextBytes(corrections); + out.writeBytes(corrections, 0, corrections.length); + } + } + final byte[] query = new byte[dimensions]; + for (int j = 0; j < dimensions; j++) { + query[j] = (byte) random().nextInt(16); // 4-bit quantization + } + OptimizedScalarQuantizer.QuantizationResult queryCorrections = new OptimizedScalarQuantizer.QuantizationResult( + random().nextFloat(), + random().nextFloat(), + random().nextFloat(), + Short.toUnsignedInt((short) random().nextInt()) + ); + float centroidDp = random().nextFloat(); + VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values()); + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { + // Work on a slice that has just the right 
number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. + final IndexInput slice = in.slice("test", 0, (long) (dimensions + 14) * numVectors); + final ES91Int4VectorsScorer defaultScorer = defaultProvider().newES91Int4VectorsScorer(in, dimensions); + final ES91Int4VectorsScorer panamaScorer = maybePanamaProvider().newES91Int4VectorsScorer(slice, dimensions); + for (int i = 0; i < numVectors; i++) { + float scoreDefault = defaultScorer.score( + query, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp + ); + float scorePanama = panamaScorer.score( + query, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp + ); + assertEquals(scoreDefault, scorePanama, 0.001f); + assertEquals(in.getFilePointer(), slice.getFilePointer()); + } + assertEquals((long) (dimensions + 14) * numVectors, in.getFilePointer()); + } + } + } + + public void testInt4ScoreBulk() throws Exception { + // only even dimensions are supported + final int dimensions = random().nextInt(1, 1000) * 2; + final int numVectors = random().nextInt(1, 10) * ES91Int4VectorsScorer.BULK_SIZE; + final byte[] vector = new byte[ES91Int4VectorsScorer.BULK_SIZE * dimensions]; + final byte[] corrections = new byte[ES91Int4VectorsScorer.BULK_SIZE * 14]; + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) { + for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE * dimensions; j++) { + vector[j] = (byte) random().nextInt(16); // 4-bit quantization + } + out.writeBytes(vector, 0, vector.length); 
+ random().nextBytes(corrections); + out.writeBytes(corrections, 0, corrections.length); + } + } + final byte[] query = new byte[dimensions]; + for (int j = 0; j < dimensions; j++) { + query[j] = (byte) random().nextInt(16); // 4-bit quantization + } + OptimizedScalarQuantizer.QuantizationResult queryCorrections = new OptimizedScalarQuantizer.QuantizationResult( + random().nextFloat(), + random().nextFloat(), + random().nextFloat(), + Short.toUnsignedInt((short) random().nextInt()) + ); + float centroidDp = random().nextFloat(); + VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values()); + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { + // Work on a slice that has just the right number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. + final IndexInput slice = in.slice("test", 0, (long) (dimensions + 14) * numVectors); + final ES91Int4VectorsScorer defaultScorer = defaultProvider().newES91Int4VectorsScorer(in, dimensions); + final ES91Int4VectorsScorer panamaScorer = maybePanamaProvider().newES91Int4VectorsScorer(slice, dimensions); + float[] scoresDefault = new float[ES91Int4VectorsScorer.BULK_SIZE]; + float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE]; + for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) { + defaultScorer.scoreBulk( + query, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp, + scoresDefault + ); + panamaScorer.scoreBulk( + query, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp, + scoresPanama + ); + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + if 
(scoresDefault[j] == scoresPanama[j]) { + continue; + } + if (scoresDefault[j] > (1000 * Byte.MAX_VALUE)) { + float diff = Math.abs(scoresDefault[j] - scoresPanama[j]); + assertThat( + "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j], + diff / scoresDefault[j], + lessThan(1e-5f) + ); + assertThat( + "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j], + diff / scoresPanama[j], + lessThan(1e-5f) + ); + } else { + assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f); + } + } + assertEquals(in.getFilePointer(), slice.getFilePointer()); + } + assertEquals((long) (dimensions + 14) * numVectors, in.getFilePointer()); + } + } + } } diff --git a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfiguration.java b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfiguration.java index 30d846da46156..8f22063d3f27f 100644 --- a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfiguration.java +++ b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfiguration.java @@ -40,7 +40,8 @@ public record SslConfiguration( SslVerificationMode verificationMode, SslClientAuthenticationMode clientAuth, List ciphers, - List supportedProtocols + List supportedProtocols, + long handshakeTimeoutMillis ) { /** @@ -71,7 +72,8 @@ public SslConfiguration( SslVerificationMode verificationMode, SslClientAuthenticationMode clientAuth, List ciphers, - List supportedProtocols + List supportedProtocols, + long handshakeTimeoutMillis ) { this.settingPrefix = settingPrefix; this.explicitlyConfigured = explicitlyConfigured; @@ -85,6 +87,10 @@ public SslConfiguration( this.keyConfig = Objects.requireNonNull(keyConfig, "key config cannot be null"); this.verificationMode = Objects.requireNonNull(verificationMode, "verification mode cannot be null"); this.clientAuth = Objects.requireNonNull(clientAuth, "client authentication cannot be null"); + if (handshakeTimeoutMillis < 1L) { + throw new 
SslConfigException("handshake timeout must be at least 1ms"); + } + this.handshakeTimeoutMillis = handshakeTimeoutMillis; this.ciphers = Collections.unmodifiableList(ciphers); this.supportedProtocols = Collections.unmodifiableList(supportedProtocols); } @@ -164,11 +170,21 @@ public boolean equals(Object o) { && this.verificationMode == that.verificationMode && this.clientAuth == that.clientAuth && Objects.equals(this.ciphers, that.ciphers) - && Objects.equals(this.supportedProtocols, that.supportedProtocols); + && Objects.equals(this.supportedProtocols, that.supportedProtocols) + && this.handshakeTimeoutMillis == that.handshakeTimeoutMillis; } @Override public int hashCode() { - return Objects.hash(settingPrefix, trustConfig, keyConfig, verificationMode, clientAuth, ciphers, supportedProtocols); + return Objects.hash( + settingPrefix, + trustConfig, + keyConfig, + verificationMode, + clientAuth, + ciphers, + supportedProtocols, + handshakeTimeoutMillis + ); } } diff --git a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationKeys.java b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationKeys.java index 1c782a2fa5f31..777f68c518fbb 100644 --- a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationKeys.java +++ b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationKeys.java @@ -132,6 +132,10 @@ public class SslConfigurationKeys { * The use of this setting {@link #isDeprecated(String) is deprecated}. */ public static final String KEY_LEGACY_PASSPHRASE = "key_passphrase"; + /** + * The timeout for TLS handshakes in this context. 
+ */ + public static final String HANDSHAKE_TIMEOUT = "handshake_timeout"; private static final Set DEPRECATED_KEYS = new HashSet<>( Arrays.asList(TRUSTSTORE_LEGACY_PASSWORD, KEYSTORE_LEGACY_PASSWORD, KEYSTORE_LEGACY_KEY_PASSWORD, KEY_LEGACY_PASSPHRASE) diff --git a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationLoader.java b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationLoader.java index 9d455807953e7..e2e9d92726b6a 100644 --- a/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationLoader.java +++ b/libs/ssl-config/src/main/java/org/elasticsearch/common/ssl/SslConfigurationLoader.java @@ -10,6 +10,7 @@ package org.elasticsearch.common.ssl; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import java.nio.file.Path; import java.security.KeyStore; @@ -27,6 +28,7 @@ import static org.elasticsearch.common.ssl.SslConfigurationKeys.CERTIFICATE_AUTHORITIES; import static org.elasticsearch.common.ssl.SslConfigurationKeys.CIPHERS; import static org.elasticsearch.common.ssl.SslConfigurationKeys.CLIENT_AUTH; +import static org.elasticsearch.common.ssl.SslConfigurationKeys.HANDSHAKE_TIMEOUT; import static org.elasticsearch.common.ssl.SslConfigurationKeys.KEY; import static org.elasticsearch.common.ssl.SslConfigurationKeys.KEYSTORE_ALGORITHM; import static org.elasticsearch.common.ssl.SslConfigurationKeys.KEYSTORE_LEGACY_KEY_PASSWORD; @@ -152,6 +154,8 @@ public abstract class SslConfigurationLoader { private static final char[] EMPTY_PASSWORD = new char[0]; public static final List GLOBAL_DEFAULT_RESTRICTED_TRUST_FIELDS = List.of(X509Field.SAN_OTHERNAME_COMMONNAME); + public static final TimeValue DEFAULT_HANDSHAKE_TIMEOUT = TimeValue.timeValueSeconds(10); + private final String settingPrefix; private SslTrustConfig defaultTrustConfig; @@ -302,6 +306,11 @@ public SslConfiguration load(Path basePath) { X509Field::parseForRestrictedTrust, defaultRestrictedTrustFields ); 
+ final long handshakeTimeoutMillis = resolveSetting( + HANDSHAKE_TIMEOUT, + s -> TimeValue.parseTimeValue(s, HANDSHAKE_TIMEOUT), + DEFAULT_HANDSHAKE_TIMEOUT + ).millis(); final SslKeyConfig keyConfig = buildKeyConfig(basePath); final SslTrustConfig trustConfig = buildTrustConfig(basePath, verificationMode, keyConfig, Set.copyOf(trustRestrictionsX509Fields)); @@ -321,7 +330,8 @@ public SslConfiguration load(Path basePath) { verificationMode, clientAuth, ciphers, - protocols + protocols, + handshakeTimeoutMillis ); } diff --git a/libs/ssl-config/src/test/java/org/elasticsearch/common/ssl/SslConfigurationTests.java b/libs/ssl-config/src/test/java/org/elasticsearch/common/ssl/SslConfigurationTests.java index 735edddd284bd..d93266e50ebd1 100644 --- a/libs/ssl-config/src/test/java/org/elasticsearch/common/ssl/SslConfigurationTests.java +++ b/libs/ssl-config/src/test/java/org/elasticsearch/common/ssl/SslConfigurationTests.java @@ -39,6 +39,7 @@ public void testBasicConstruction() { final SslClientAuthenticationMode clientAuth = randomFrom(SslClientAuthenticationMode.values()); final List ciphers = randomSubsetOf(randomIntBetween(1, DEFAULT_CIPHERS.size()), DEFAULT_CIPHERS); final List protocols = randomSubsetOf(randomIntBetween(1, 4), VALID_PROTOCOLS); + final long handshakeTimeoutMillis = randomHandshakeTimeoutMillis(); final SslConfiguration configuration = new SslConfiguration( "test.ssl", true, @@ -47,7 +48,8 @@ public void testBasicConstruction() { verificationMode, clientAuth, ciphers, - protocols + protocols, + handshakeTimeoutMillis ); assertThat(configuration.trustConfig(), is(trustConfig)); @@ -56,6 +58,7 @@ public void testBasicConstruction() { assertThat(configuration.clientAuth(), is(clientAuth)); assertThat(configuration.getCipherSuites(), is(ciphers)); assertThat(configuration.supportedProtocols(), is(protocols)); + assertThat(configuration.handshakeTimeoutMillis(), is(handshakeTimeoutMillis)); assertThat(configuration.toString(), 
containsString("TEST-TRUST")); assertThat(configuration.toString(), containsString("TEST-KEY")); @@ -63,6 +66,7 @@ public void testBasicConstruction() { assertThat(configuration.toString(), containsString(clientAuth.toString())); assertThat(configuration.toString(), containsString(randomFrom(ciphers))); assertThat(configuration.toString(), containsString(randomFrom(protocols))); + assertThat(configuration.toString(), containsString("handshakeTimeoutMillis=" + handshakeTimeoutMillis)); } public void testEqualsAndHashCode() { @@ -72,6 +76,7 @@ public void testEqualsAndHashCode() { final SslClientAuthenticationMode clientAuth = randomFrom(SslClientAuthenticationMode.values()); final List ciphers = randomSubsetOf(randomIntBetween(1, DEFAULT_CIPHERS.size() - 1), DEFAULT_CIPHERS); final List protocols = randomSubsetOf(randomIntBetween(1, VALID_PROTOCOLS.length - 1), VALID_PROTOCOLS); + final long handshakeTimeoutMillis = randomHandshakeTimeoutMillis(); final SslConfiguration configuration = new SslConfiguration( "test.ssl", true, @@ -80,7 +85,8 @@ public void testEqualsAndHashCode() { verificationMode, clientAuth, ciphers, - protocols + protocols, + handshakeTimeoutMillis ); EqualsHashCodeTestUtils.checkEqualsAndHashCode( @@ -93,14 +99,15 @@ public void testEqualsAndHashCode() { orig.verificationMode(), orig.clientAuth(), orig.getCipherSuites(), - orig.supportedProtocols() + orig.supportedProtocols(), + orig.handshakeTimeoutMillis() ), this::mutateSslConfiguration ); } private SslConfiguration mutateSslConfiguration(SslConfiguration orig) { - return switch (randomIntBetween(1, 4)) { + return switch (randomIntBetween(1, 5)) { case 1 -> new SslConfiguration( "test.ssl", true, @@ -109,7 +116,8 @@ private SslConfiguration mutateSslConfiguration(SslConfiguration orig) { randomValueOtherThan(orig.verificationMode(), () -> randomFrom(SslVerificationMode.values())), orig.clientAuth(), orig.getCipherSuites(), - orig.supportedProtocols() + orig.supportedProtocols(), + 
orig.handshakeTimeoutMillis() ); case 2 -> new SslConfiguration( "test.ssl", @@ -119,7 +127,8 @@ private SslConfiguration mutateSslConfiguration(SslConfiguration orig) { orig.verificationMode(), randomValueOtherThan(orig.clientAuth(), () -> randomFrom(SslClientAuthenticationMode.values())), orig.getCipherSuites(), - orig.supportedProtocols() + orig.supportedProtocols(), + orig.handshakeTimeoutMillis() ); case 3 -> new SslConfiguration( "test.ssl", @@ -129,7 +138,19 @@ private SslConfiguration mutateSslConfiguration(SslConfiguration orig) { orig.verificationMode(), orig.clientAuth(), DEFAULT_CIPHERS, - orig.supportedProtocols() + orig.supportedProtocols(), + orig.handshakeTimeoutMillis() + ); + case 4 -> new SslConfiguration( + "test.ssl", + true, + orig.trustConfig(), + orig.keyConfig(), + orig.verificationMode(), + orig.clientAuth(), + orig.getCipherSuites(), + Arrays.asList(VALID_PROTOCOLS), + orig.handshakeTimeoutMillis() ); default -> new SslConfiguration( "test.ssl", @@ -139,11 +160,16 @@ private SslConfiguration mutateSslConfiguration(SslConfiguration orig) { orig.verificationMode(), orig.clientAuth(), orig.getCipherSuites(), - Arrays.asList(VALID_PROTOCOLS) + orig.supportedProtocols(), + randomValueOtherThan(orig.handshakeTimeoutMillis(), SslConfigurationTests::randomHandshakeTimeoutMillis) ); }; } + private static long randomHandshakeTimeoutMillis() { + return randomLongBetween(1, 100000); + } + public void testDependentFiles() { final SslTrustConfig trustConfig = Mockito.mock(SslTrustConfig.class); final SslKeyConfig keyConfig = Mockito.mock(SslKeyConfig.class); @@ -155,7 +181,8 @@ public void testDependentFiles() { randomFrom(SslVerificationMode.values()), randomFrom(SslClientAuthenticationMode.values()), DEFAULT_CIPHERS, - SslConfigurationLoader.DEFAULT_PROTOCOLS + SslConfigurationLoader.DEFAULT_PROTOCOLS, + randomHandshakeTimeoutMillis() ); final Path dir = createTempDir(); @@ -182,7 +209,8 @@ public void testBuildSslContext() { 
randomFrom(SslVerificationMode.values()), randomFrom(SslClientAuthenticationMode.values()), DEFAULT_CIPHERS, - Collections.singletonList(protocol) + Collections.singletonList(protocol), + randomHandshakeTimeoutMillis() ); Mockito.when(trustConfig.createTrustManager()).thenReturn(null); diff --git a/libs/x-content/src/main/java/org/elasticsearch/xcontent/AbstractObjectParser.java b/libs/x-content/src/main/java/org/elasticsearch/xcontent/AbstractObjectParser.java index 244e1270fe530..140ee9357dd08 100644 --- a/libs/x-content/src/main/java/org/elasticsearch/xcontent/AbstractObjectParser.java +++ b/libs/x-content/src/main/java/org/elasticsearch/xcontent/AbstractObjectParser.java @@ -10,6 +10,7 @@ package org.elasticsearch.xcontent; import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.xcontent.ObjectParser.NamedObjectParser; import org.elasticsearch.xcontent.ObjectParser.ValueType; @@ -230,11 +231,13 @@ public void declareDoubleOrNull(BiConsumer consumer, double nullV ); } + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareLong(BiConsumer consumer, ParseField field) { // Using a method reference here angers some compilers declareField(consumer, p -> p.longValue(), field, ValueType.LONG); } + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareLongOrNull(BiConsumer consumer, long nullValue, ParseField field) { // Using a method reference here angers some compilers declareField( @@ -245,6 +248,7 @@ public void declareLongOrNull(BiConsumer consumer, long nullValue, ); } + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareInt(BiConsumer consumer, ParseField field) { // Using a method reference here angers some compilers declareField(consumer, p -> p.intValue(), field, 
ValueType.INT); @@ -253,6 +257,7 @@ public void declareInt(BiConsumer consumer, ParseField field) { /** * Declare an integer field that parses explicit {@code null}s in the json to a default value. */ + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareIntOrNull(BiConsumer consumer, int nullValue, ParseField field) { declareField( consumer, @@ -320,10 +325,12 @@ public void declareFloatArray(BiConsumer> consumer, ParseFiel declareFieldArray(consumer, (p, c) -> p.floatValue(), field, ValueType.FLOAT_ARRAY); } + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareLongArray(BiConsumer> consumer, ParseField field) { declareFieldArray(consumer, (p, c) -> p.longValue(), field, ValueType.LONG_ARRAY); } + @UpdateForV10(owner = UpdateForV10.Owner.CORE_INFRA) // https://github.com/elastic/elasticsearch/issues/130797 public void declareIntArray(BiConsumer> consumer, ParseField field) { declareFieldArray(consumer, (p, c) -> p.intValue(), field, ValueType.INT_ARRAY); } diff --git a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java index be13207702627..c3b322db0e3a5 100644 --- a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java +++ b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java @@ -220,6 +220,27 @@ public void declareField(BiConsumer consumer, ContextParser void declareObjectArrayOrNull( + BiConsumer> consumer, + ContextParser objectParser, + ParseField field + ) { + declareField( + consumer, + (p, c) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? 
null : parseArray(p, c, objectParser), + field, + ValueType.OBJECT_ARRAY_OR_NULL + ); + } + @Override public void declareNamedObject( BiConsumer consumer, diff --git a/modules/apm/build.gradle b/modules/apm/build.gradle index 86d06258bcbca..37d42e4b3fb0c 100644 --- a/modules/apm/build.gradle +++ b/modules/apm/build.gradle @@ -20,7 +20,7 @@ dependencies { implementation "io.opentelemetry:opentelemetry-api:${otelVersion}" implementation "io.opentelemetry:opentelemetry-context:${otelVersion}" implementation "io.opentelemetry:opentelemetry-semconv:${otelSemconvVersion}" - runtimeOnly "co.elastic.apm:elastic-apm-agent:1.52.2" + runtimeOnly "co.elastic.apm:elastic-apm-agent:1.55.0" javaRestTestImplementation project(':modules:apm') javaRestTestImplementation project(':test:framework') diff --git a/modules/apm/src/javaRestTest/java/org/elasticsearch/telemetry/apm/ApmAgentSettingsIT.java b/modules/apm/src/javaRestTest/java/org/elasticsearch/telemetry/apm/ApmAgentSettingsIT.java index ee26178723608..16f6fb1bf8ace 100644 --- a/modules/apm/src/javaRestTest/java/org/elasticsearch/telemetry/apm/ApmAgentSettingsIT.java +++ b/modules/apm/src/javaRestTest/java/org/elasticsearch/telemetry/apm/ApmAgentSettingsIT.java @@ -18,10 +18,7 @@ public class ApmAgentSettingsIT extends ESRestTestCase { @ClassRule - public static ElasticsearchCluster cluster = ElasticsearchCluster.local() - .module("apm") - .systemProperty("es.entitlements.enabled", "true") - .build(); + public static ElasticsearchCluster cluster = ElasticsearchCluster.local().module("apm").build(); @Override protected String getTestRestCluster() { diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderIT.java index 9b2f5400d4f2d..4fdee727b5755 100644 --- 
a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.Settings; @@ -116,7 +117,7 @@ public void testEnterpriseDownloaderTask() throws Exception { final String sourceField = "ip"; final String targetField = "ip-result"; - startEnterpriseGeoIpDownloaderTask(); + startEnterpriseGeoIpDownloaderTask(ProjectId.DEFAULT); configureMaxmindDatabase(MAXMIND_DATABASE_TYPE); configureIpinfoDatabase(IPINFO_DATABASE_TYPE); waitAround(); @@ -171,9 +172,10 @@ private void deleteDatabaseConfiguration(String configurationName, ActionListene ); } - private void startEnterpriseGeoIpDownloaderTask() { + private void startEnterpriseGeoIpDownloaderTask(ProjectId projectId) { PersistentTasksService persistentTasksService = internalCluster().getInstance(PersistentTasksService.class); - persistentTasksService.sendStartRequest( + persistentTasksService.sendProjectStartRequest( + projectId, ENTERPRISE_GEOIP_DOWNLOADER, ENTERPRISE_GEOIP_DOWNLOADER, new EnterpriseGeoIpTask.EnterpriseGeoIpTaskParams(), diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/ConfigDatabases.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/ConfigDatabases.java index cf677558785f7..35ca8a8e2a273 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/ConfigDatabases.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/ConfigDatabases.java @@ -11,7 
+11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.cluster.metadata.ProjectId; -import org.elasticsearch.core.FixForMultiProject; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.env.Environment; import org.elasticsearch.watcher.FileChangesListener; import org.elasticsearch.watcher.FileWatcher; @@ -34,6 +34,9 @@ * Keeps track of user provided databases in the ES_HOME/config/ingest-geoip directory. * This directory is monitored and files updates are picked up and may cause databases being loaded or removed at runtime. */ +@NotMultiProjectCapable( + description = "Custom databases not available in serverless, we should review this class for MP again after serverless is enabled" +) final class ConfigDatabases implements Closeable { private static final Logger logger = LogManager.getLogger(ConfigDatabases.class); @@ -71,7 +74,7 @@ Map getConfigDatabases() { return configDatabases; } - @FixForMultiProject(description = "Replace DEFAULT project") + @NotMultiProjectCapable(description = "Replace DEFAULT project after serverless is enabled") void updateDatabase(Path file, boolean update) { String databaseFileName = file.getFileName().toString(); try { @@ -93,7 +96,7 @@ void updateDatabase(Path file, boolean update) { } } - @FixForMultiProject(description = "Replace DEFAULT project") + @NotMultiProjectCapable(description = "Replace DEFAULT project after serverless is enabled") Map initConfigDatabases() throws IOException { Map databases = new HashMap<>(); diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloader.java index 29f246e58e844..06f672a3719ef 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloader.java +++ 
b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloader.java @@ -19,10 +19,12 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.Strings; import org.elasticsearch.common.hash.MessageDigests; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.BoolQueryBuilder; @@ -67,6 +69,9 @@ * Downloads are verified against MD5 checksum provided by the server * Current state of all stored databases is stored in cluster state in persistent task state */ +@NotMultiProjectCapable( + description = "Enterprise GeoIP not available in serverless, we should review this class for MP again after serverless is enabled" +) public class EnterpriseGeoIpDownloader extends AllocatedPersistentTask { private static final Logger logger = LogManager.getLogger(EnterpriseGeoIpDownloader.class); @@ -142,22 +147,27 @@ void setState(EnterpriseGeoIpTaskState state) { // visible for testing void updateDatabases() throws IOException { + @NotMultiProjectCapable(description = "Enterprise GeoIP not available in serverless") + ProjectId projectId = ProjectId.DEFAULT; var clusterState = clusterService.state(); - var geoipIndex = clusterState.getMetadata().getProject().getIndicesLookup().get(EnterpriseGeoIpDownloader.DATABASES_INDEX); + var geoipIndex = clusterState.getMetadata().getProject(projectId).getIndicesLookup().get(EnterpriseGeoIpDownloader.DATABASES_INDEX); if (geoipIndex != null) { logger.trace("the geoip index [{}] exists", EnterpriseGeoIpDownloader.DATABASES_INDEX); - if 
(clusterState.getRoutingTable().index(geoipIndex.getWriteIndex()).allPrimaryShardsActive() == false) { + if (clusterState.routingTable(projectId).index(geoipIndex.getWriteIndex()).allPrimaryShardsActive() == false) { logger.debug("not updating databases because not all primary shards of [{}] index are active yet", DATABASES_INDEX); return; } - var blockException = clusterState.blocks().indexBlockedException(ClusterBlockLevel.WRITE, geoipIndex.getWriteIndex().getName()); + var blockException = clusterState.blocks() + .indexBlockedException(projectId, ClusterBlockLevel.WRITE, geoipIndex.getWriteIndex().getName()); if (blockException != null) { throw blockException; } } logger.trace("Updating databases"); - IngestGeoIpMetadata geoIpMeta = clusterState.metadata().getProject().custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); + IngestGeoIpMetadata geoIpMeta = clusterState.metadata() + .getProject(projectId) + .custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); // if there are entries in the cs that aren't in the persistent task state, // then download those (only) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTaskExecutor.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTaskExecutor.java index 6d1e7cea5efc9..1eb42493d2716 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTaskExecutor.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTaskExecutor.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.ingest.EnterpriseGeoIpTask.EnterpriseGeoIpTaskParams; import org.elasticsearch.ingest.IngestService; 
@@ -47,6 +48,9 @@ import static org.elasticsearch.ingest.geoip.GeoIpDownloaderTaskExecutor.ENABLED_SETTING; import static org.elasticsearch.ingest.geoip.GeoIpDownloaderTaskExecutor.POLL_INTERVAL_SETTING; +@NotMultiProjectCapable( + description = "Enterprise GeoIP not available in serverless, we should review this class for MP again after serverless is enabled" +) public class EnterpriseGeoIpDownloaderTaskExecutor extends PersistentTasksExecutor implements ClusterStateListener { diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportDeleteDatabaseConfigurationAction.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportDeleteDatabaseConfigurationAction.java index 5b1be42b1dcfa..bb77da9de731a 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportDeleteDatabaseConfigurationAction.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportDeleteDatabaseConfigurationAction.java @@ -21,7 +21,9 @@ import org.elasticsearch.cluster.SimpleBatchedExecutor; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.common.Priority; @@ -56,13 +58,15 @@ public void taskSucceeded(DeleteDatabaseConfigurationTask task, Void unused) { }; private final MasterServiceTaskQueue deleteDatabaseConfigurationTaskQueue; + private final ProjectResolver projectResolver; @Inject public TransportDeleteDatabaseConfigurationAction( TransportService transportService, ClusterService clusterService, ThreadPool threadPool, - ActionFilters actionFilters + ActionFilters 
actionFilters, + ProjectResolver projectResolver ) { super( DeleteDatabaseConfigurationAction.NAME, @@ -79,13 +83,17 @@ public TransportDeleteDatabaseConfigurationAction( Priority.NORMAL, DELETE_TASK_EXECUTOR ); + this.projectResolver = projectResolver; } @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { final String id = request.getDatabaseId(); - final IngestGeoIpMetadata geoIpMeta = state.metadata().getProject().custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); + final ProjectId projectId = projectResolver.getProjectId(); + final IngestGeoIpMetadata geoIpMeta = state.metadata() + .getProject(projectId) + .custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); if (geoIpMeta.getDatabases().containsKey(id) == false) { throw new ResourceNotFoundException("Database configuration not found: {}", id); } else if (geoIpMeta.getDatabases().get(id).database().isReadOnly()) { @@ -93,17 +101,17 @@ protected void masterOperation(Task task, Request request, ClusterState state, A } deleteDatabaseConfigurationTaskQueue.submitTask( Strings.format("delete-geoip-database-configuration-[%s]", id), - new DeleteDatabaseConfigurationTask(listener, id), + new DeleteDatabaseConfigurationTask(projectId, listener, id), null ); } - private record DeleteDatabaseConfigurationTask(ActionListener listener, String databaseId) + private record DeleteDatabaseConfigurationTask(ProjectId projectId, ActionListener listener, String databaseId) implements ClusterStateTaskListener { ClusterState execute(ClusterState currentState) throws Exception { - final var project = currentState.metadata().getProject(); + final var project = currentState.metadata().getProject(projectId); final IngestGeoIpMetadata geoIpMeta = project.custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); logger.debug("deleting database configuration [{}]", databaseId); diff --git 
a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportGetDatabaseConfigurationAction.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportGetDatabaseConfigurationAction.java index 7233765bfeda5..88200af2d2e18 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportGetDatabaseConfigurationAction.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/TransportGetDatabaseConfigurationAction.java @@ -13,11 +13,14 @@ import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.nodes.TransportNodesAction; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.ingest.geoip.DatabaseNodeService; +import org.elasticsearch.ingest.geoip.GeoIpDownloaderTaskExecutor; import org.elasticsearch.ingest.geoip.GeoIpTaskState; import org.elasticsearch.ingest.geoip.IngestGeoIpMetadata; import org.elasticsearch.injection.guice.Inject; @@ -48,6 +51,7 @@ public class TransportGetDatabaseConfigurationAction extends TransportNodesActio List> { private final DatabaseNodeService databaseNodeService; + private final ProjectResolver projectResolver; @Inject public TransportGetDatabaseConfigurationAction( @@ -55,7 +59,8 @@ public TransportGetDatabaseConfigurationAction( ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, - DatabaseNodeService databaseNodeService + DatabaseNodeService databaseNodeService, + ProjectResolver projectResolver ) { super( GetDatabaseConfigurationAction.NAME, @@ -66,6 +71,7 @@ public TransportGetDatabaseConfigurationAction( 
threadPool.executor(ThreadPool.Names.MANAGEMENT) ); this.databaseNodeService = databaseNodeService; + this.projectResolver = projectResolver; } protected List createActionContext(Task task, GetDatabaseConfigurationAction.Request request) { @@ -82,14 +88,14 @@ protected List createActionContext(Task task, Get "wildcard only supports a single value, please use comma-separated values or a single wildcard value" ); } - List results = new ArrayList<>(); - PersistentTasksCustomMetadata tasksMetadata = PersistentTasksCustomMetadata.getPersistentTasksCustomMetadata( - clusterService.state() - ); + ProjectMetadata projectMetadata = projectResolver.getProjectMetadata(clusterService.state()); + PersistentTasksCustomMetadata tasksMetadata = PersistentTasksCustomMetadata.get(projectMetadata); + String geoIpTaskId = GeoIpDownloaderTaskExecutor.getTaskId(projectMetadata.id(), projectResolver.supportsMultipleProjects()); + for (String id : ids) { - results.addAll(getWebDatabases(tasksMetadata, id)); - results.addAll(getMaxmindDatabases(clusterService, id)); + results.addAll(getWebDatabases(geoIpTaskId, tasksMetadata, id)); + results.addAll(getMaxmindDatabases(projectMetadata, id)); } return results; } @@ -97,10 +103,14 @@ protected List createActionContext(Task task, Get /* * This returns read-only database information about the databases managed by the standard downloader */ - private static Collection getWebDatabases(PersistentTasksCustomMetadata tasksMetadata, String id) { + private static Collection getWebDatabases( + String geoIpTaskId, + PersistentTasksCustomMetadata tasksMetadata, + String id + ) { List webDatabases = new ArrayList<>(); if (tasksMetadata != null) { - PersistentTasksCustomMetadata.PersistentTask maybeGeoIpTask = tasksMetadata.getTask("geoip-downloader"); + PersistentTasksCustomMetadata.PersistentTask maybeGeoIpTask = tasksMetadata.getTask(geoIpTaskId); if (maybeGeoIpTask != null) { GeoIpTaskState geoIpTaskState = (GeoIpTaskState) maybeGeoIpTask.getState(); 
if (geoIpTaskState != null) { @@ -137,12 +147,9 @@ private static String getDatabaseNameForFileName(String databaseFileName) { /* * This returns information about databases that are downloaded from maxmind. */ - private static Collection getMaxmindDatabases(ClusterService clusterService, String id) { + private static Collection getMaxmindDatabases(ProjectMetadata projectMetadata, String id) { List maxmindDatabases = new ArrayList<>(); - final IngestGeoIpMetadata geoIpMeta = clusterService.state() - .metadata() - .getProject() - .custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); + final IngestGeoIpMetadata geoIpMeta = projectMetadata.custom(IngestGeoIpMetadata.TYPE, IngestGeoIpMetadata.EMPTY); if (Regex.isSimpleMatchPattern(id)) { for (Map.Entry entry : geoIpMeta.getDatabases().entrySet()) { if (Regex.simpleMatch(id, entry.getKey())) { diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsTests.java similarity index 99% rename from modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java rename to modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsTests.java index c65d9a2dc2009..7f298038141df 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsTests.java @@ -57,7 +57,7 @@ // 'WindowsFS.checkDeleteAccess(...)'). 
} ) -public class ReloadingDatabasesWhilePerformingGeoLookupsIT extends ESTestCase { +public class ReloadingDatabasesWhilePerformingGeoLookupsTests extends ESTestCase { /** * This tests essentially verifies that a Maxmind database reader doesn't fail with: diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java index 04ace5ccc4157..faea13dac4e31 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java @@ -47,6 +47,7 @@ import org.elasticsearch.index.mapper.BlockStoredFieldsReader; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.SourceValueFetcher; @@ -251,6 +252,17 @@ private IOFunction, IOExcepti if (searchExecutionContext.isSourceSynthetic() && withinMultiField) { String parentField = searchExecutionContext.parentPath(name()); var parent = searchExecutionContext.lookup().fieldType(parentField); + + if (parent instanceof KeywordFieldMapper.KeywordFieldType keywordParent + && keywordParent.ignoreAbove() != Integer.MAX_VALUE) { + if (parent.isStored()) { + return storedFieldFetcher(parentField, keywordParent.originalName()); + } else if (parent.hasDocValues()) { + var ifd = searchExecutionContext.getForField(parent, MappedFieldType.FielddataOperation.SEARCH); + return combineFieldFetchers(docValuesFieldFetcher(ifd), storedFieldFetcher(keywordParent.originalName())); + } + } + if (parent.isStored()) { return storedFieldFetcher(parentField); } else if 
(parent.hasDocValues()) { @@ -262,8 +274,19 @@ private IOFunction, IOExcepti } else if (searchExecutionContext.isSourceSynthetic() && hasCompatibleMultiFields) { var mapper = (MatchOnlyTextFieldMapper) searchExecutionContext.getMappingLookup().getMapper(name()); var kwd = TextFieldMapper.SyntheticSourceHelper.getKeywordFieldMapperForSyntheticSource(mapper); + if (kwd != null) { var fieldType = kwd.fieldType(); + + if (fieldType.ignoreAbove() != Integer.MAX_VALUE) { + if (fieldType.isStored()) { + return storedFieldFetcher(fieldType.name(), fieldType.originalName()); + } else if (fieldType.hasDocValues()) { + var ifd = searchExecutionContext.getForField(fieldType, MappedFieldType.FielddataOperation.SEARCH); + return combineFieldFetchers(docValuesFieldFetcher(ifd), storedFieldFetcher(fieldType.originalName())); + } + } + if (fieldType.isStored()) { return storedFieldFetcher(fieldType.name()); } else if (fieldType.hasDocValues()) { @@ -312,13 +335,52 @@ private static IOFunction, IO }; } - private static IOFunction, IOException>> storedFieldFetcher(String name) { - var loader = StoredFieldLoader.create(false, Set.of(name)); + private static IOFunction, IOException>> storedFieldFetcher(String... 
names) { + var loader = StoredFieldLoader.create(false, Set.of(names)); return context -> { var leafLoader = loader.getLoader(context, null); return docId -> { leafLoader.advanceTo(docId); - return leafLoader.storedFields().get(name); + var storedFields = leafLoader.storedFields(); + if (names.length == 1) { + return storedFields.get(names[0]); + } + + List values = new ArrayList<>(); + for (var name : names) { + var currValues = storedFields.get(name); + if (currValues != null) { + values.addAll(currValues); + } + } + + return values; + }; + }; + } + + private static IOFunction, IOException>> combineFieldFetchers( + IOFunction, IOException>> primaryFetcher, + IOFunction, IOException>> secondaryFetcher + ) { + return context -> { + var primaryGetter = primaryFetcher.apply(context); + var secondaryGetter = secondaryFetcher.apply(context); + return docId -> { + List values = new ArrayList<>(); + var primary = primaryGetter.apply(docId); + if (primary != null) { + values.addAll(primary); + } + + var secondary = secondaryGetter.apply(docId); + if (secondary != null) { + values.addAll(secondary); + } + + assert primary != null || secondary != null; + + return values; }; }; } diff --git a/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/match_only_text/10_basic.yml b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/match_only_text/10_basic.yml index 1d52038e29d45..48a596ef14c72 100644 --- a/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/match_only_text/10_basic.yml +++ b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/match_only_text/10_basic.yml @@ -435,6 +435,50 @@ synthetic_source match_only_text as multi-field: - match: hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" +--- +synthetic_source match_only_text as multi-field with ignored keyword as parent: + - requires: + cluster_features: [ "mapper.source.mode_from_index_setting" ] + reason: "Source mode configured through index 
setting" + + - do: + indices.create: + index: synthetic_source_test + body: + settings: + index: + mapping.source.mode: synthetic + mappings: + properties: + foo: + type: keyword + store: false + doc_values: true + ignore_above: 10 + fields: + text: + type: match_only_text + + - do: + index: + index: synthetic_source_test + id: "1" + refresh: true + body: + foo: [ "Apache Lucene powers Elasticsearch", "Apache" ] + + - do: + search: + index: synthetic_source_test + body: + query: + match_phrase: + foo.text: apache lucene + + - match: { "hits.total.value": 1 } + - match: + hits.hits.0._source.foo: [ "Apache", "Apache Lucene powers Elasticsearch" ] + --- synthetic_source match_only_text as multi-field with stored keyword as parent: - requires: @@ -479,6 +523,49 @@ synthetic_source match_only_text as multi-field with stored keyword as parent: hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" --- +synthetic_source match_only_text as multi-field with ignored stored keyword as parent: + - requires: + cluster_features: [ "mapper.source.mode_from_index_setting" ] + reason: "Source mode configured through index setting" + + - do: + indices.create: + index: synthetic_source_test + body: + settings: + index: + mapping.source.mode: synthetic + mappings: + properties: + foo: + type: keyword + store: true + doc_values: false + ignore_above: 10 + fields: + text: + type: match_only_text + + - do: + index: + index: synthetic_source_test + id: "1" + refresh: true + body: + foo: "Apache Lucene powers Elasticsearch" + + - do: + search: + index: synthetic_source_test + body: + query: + match_phrase: + foo.text: apache lucene + + - match: { "hits.total.value": 1 } + - match: + hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" +--- synthetic_source match_only_text with multi-field: - requires: cluster_features: [ "mapper.source.mode_from_index_setting" ] @@ -519,6 +606,50 @@ synthetic_source match_only_text with multi-field: - match: hits.hits.0._source.foo: 
"Apache Lucene powers Elasticsearch" +--- +synthetic_source match_only_text with ignored multi-field: + - requires: + cluster_features: [ "mapper.source.mode_from_index_setting" ] + reason: "Source mode configured through index setting" + + - do: + indices.create: + index: synthetic_source_test + body: + settings: + index: + mapping.source.mode: synthetic + mappings: + properties: + foo: + type: match_only_text + fields: + raw: + type: keyword + store: false + doc_values: true + ignore_above: 10 + + - do: + index: + index: synthetic_source_test + id: "1" + refresh: true + body: + foo: "Apache Lucene powers Elasticsearch" + + - do: + search: + index: synthetic_source_test + body: + query: + match_phrase: + foo: apache lucene + + - match: { "hits.total.value": 1 } + - match: + hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" + --- synthetic_source match_only_text with stored multi-field: - requires: @@ -561,3 +692,47 @@ synthetic_source match_only_text with stored multi-field: - match: { "hits.total.value": 1 } - match: hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" + +--- +synthetic_source match_only_text with ignored stored multi-field: + - requires: + cluster_features: [ "mapper.source.mode_from_index_setting" ] + reason: "Source mode configured through index setting" + + - do: + indices.create: + index: synthetic_source_test + body: + settings: + index: + mapping.source.mode: synthetic + mappings: + properties: + foo: + type: match_only_text + fields: + raw: + type: keyword + store: true + doc_values: false + ignore_above: 10 + + - do: + index: + index: synthetic_source_test + id: "1" + refresh: true + body: + foo: "Apache Lucene powers Elasticsearch" + + - do: + search: + index: synthetic_source_test + body: + query: + match_phrase: + foo: apache lucene + + - match: { "hits.total.value": 1 } + - match: + hits.hits.0._source.foo: "Apache Lucene powers Elasticsearch" diff --git 
a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexBasicTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexBasicTests.java index 86bcda284babb..96c7ef49f6956 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexBasicTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexBasicTests.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.index.IndexSettings.SYNTHETIC_VECTORS; import static org.elasticsearch.index.query.QueryBuilders.termQuery; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; @@ -181,6 +182,7 @@ public void testReindexFromComplexDateMathIndexName() throws Exception { } public void testReindexIncludeVectors() throws Exception { + assumeTrue("This test requires synthetic vectors to be enabled", SYNTHETIC_VECTORS); var resp1 = prepareCreate("test").setSettings( Settings.builder().put(IndexSettings.INDEX_MAPPING_SOURCE_SYNTHETIC_VECTORS_SETTING.getKey(), true).build() ).setMapping("foo", "type=dense_vector,similarity=l2_norm", "bar", "type=sparse_vector").get(); diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/UpdateByQueryBasicTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/UpdateByQueryBasicTests.java index ac7d05610ad3c..33c80e9138d28 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/UpdateByQueryBasicTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/UpdateByQueryBasicTests.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.index.IndexSettings.SYNTHETIC_VECTORS; import static org.elasticsearch.index.query.QueryBuilders.termQuery; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static 
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; @@ -157,6 +158,7 @@ public void testMissingSources() { } public void testUpdateByQueryIncludeVectors() throws Exception { + assumeTrue("This test requires synthetic vectors to be enabled", SYNTHETIC_VECTORS); var resp1 = prepareCreate("test").setSettings( Settings.builder().put(IndexSettings.INDEX_MAPPING_SOURCE_SYNTHETIC_VECTORS_SETTING.getKey(), true).build() ).setMapping("foo", "type=dense_vector,similarity=l2_norm", "bar", "type=sparse_vector").get(); diff --git a/modules/repository-s3/src/main/resources/org/elasticsearch/repositories/s3/regions_by_endpoint.txt b/modules/repository-s3/src/main/resources/org/elasticsearch/repositories/s3/regions_by_endpoint.txt index 3fae5c314c10b..5ca027a5f4a13 100644 --- a/modules/repository-s3/src/main/resources/org/elasticsearch/repositories/s3/regions_by_endpoint.txt +++ b/modules/repository-s3/src/main/resources/org/elasticsearch/repositories/s3/regions_by_endpoint.txt @@ -6,6 +6,10 @@ ap-east-1 s3-fips.ap-east-1.amazonaws.com ap-east-1 s3-fips.dualstack.ap-east-1.amazonaws.com ap-east-1 s3.ap-east-1.amazonaws.com ap-east-1 s3.dualstack.ap-east-1.amazonaws.com +ap-east-2 s3-fips.ap-east-2.amazonaws.com +ap-east-2 s3-fips.dualstack.ap-east-2.amazonaws.com +ap-east-2 s3.ap-east-2.amazonaws.com +ap-east-2 s3.dualstack.ap-east-2.amazonaws.com ap-northeast-1 s3-fips.ap-northeast-1.amazonaws.com ap-northeast-1 s3-fips.dualstack.ap-northeast-1.amazonaws.com ap-northeast-1 s3.ap-northeast-1.amazonaws.com @@ -56,6 +60,14 @@ aws-iso-b-global s3-fips.aws-iso-b-global.sc2s.sgov.gov aws-iso-b-global s3-fips.dualstack.aws-iso-b-global.sc2s.sgov.gov aws-iso-b-global s3.aws-iso-b-global.sc2s.sgov.gov aws-iso-b-global s3.dualstack.aws-iso-b-global.sc2s.sgov.gov +aws-iso-e-global s3-fips.aws-iso-e-global.cloud.adc-e.uk +aws-iso-e-global s3-fips.dualstack.aws-iso-e-global.cloud.adc-e.uk +aws-iso-e-global s3.aws-iso-e-global.cloud.adc-e.uk +aws-iso-e-global 
s3.dualstack.aws-iso-e-global.cloud.adc-e.uk +aws-iso-f-global s3-fips.aws-iso-f-global.csp.hci.ic.gov +aws-iso-f-global s3-fips.dualstack.aws-iso-f-global.csp.hci.ic.gov +aws-iso-f-global s3.aws-iso-f-global.csp.hci.ic.gov +aws-iso-f-global s3.dualstack.aws-iso-f-global.csp.hci.ic.gov aws-iso-global s3-fips.aws-iso-global.c2s.ic.gov aws-iso-global s3-fips.dualstack.aws-iso-global.c2s.ic.gov aws-iso-global s3.aws-iso-global.c2s.ic.gov @@ -76,6 +88,10 @@ cn-north-1 s3.cn-north-1.amazonaws.com.cn cn-north-1 s3.dualstack.cn-north-1.amazonaws.com.cn cn-northwest-1 s3.cn-northwest-1.amazonaws.com.cn cn-northwest-1 s3.dualstack.cn-northwest-1.amazonaws.com.cn +eusc-de-east-1 s3-fips.eusc-de-east-1.amazonaws.eu +eusc-de-east-1 s3-fips.dualstack.eusc-de-east-1.amazonaws.eu +eusc-de-east-1 s3.eusc-de-east-1.amazonaws.eu +eusc-de-east-1 s3.dualstack.eusc-de-east-1.amazonaws.eu eu-central-1 s3-fips.dualstack.eu-central-1.amazonaws.com eu-central-1 s3-fips.eu-central-1.amazonaws.com eu-central-1 s3.dualstack.eu-central-1.amazonaws.com diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/RegionFromEndpointGuesserTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/RegionFromEndpointGuesserTests.java index 9fe0c40c83979..402181878b600 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/RegionFromEndpointGuesserTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/RegionFromEndpointGuesserTests.java @@ -9,6 +9,11 @@ package org.elasticsearch.repositories.s3; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; +import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider; + import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; @@ -23,6 +28,14 @@ public void testRegionGuessing() { 
assertRegionGuess("random.endpoint.internal.net", null); } + public void testHasEntryForEachRegion() { + final var defaultS3EndpointProvider = new DefaultS3EndpointProvider(); + for (var region : Region.regions()) { + final Endpoint endpoint = safeGet(defaultS3EndpointProvider.resolveEndpoint(S3EndpointParams.builder().region(region).build())); + assertNotNull(region.id(), RegionFromEndpointGuesser.guessRegion(endpoint.url().toString())); + } + } + private static void assertRegionGuess(String endpoint, @Nullable String expectedRegion) { assertEquals(endpoint, expectedRegion, RegionFromEndpointGuesser.guessRegion(endpoint)); } diff --git a/modules/streams/src/main/java/org/elasticsearch/rest/streams/logs/StreamsStatusAction.java b/modules/streams/src/main/java/org/elasticsearch/rest/streams/logs/StreamsStatusAction.java index 95a1783ac0452..4b60c088980e9 100644 --- a/modules/streams/src/main/java/org/elasticsearch/rest/streams/logs/StreamsStatusAction.java +++ b/modules/streams/src/main/java/org/elasticsearch/rest/streams/logs/StreamsStatusAction.java @@ -26,7 +26,7 @@ public class StreamsStatusAction { - public static ActionType INSTANCE = new ActionType<>("cluster:admin/streams/status"); + public static ActionType INSTANCE = new ActionType<>("cluster:monitor/streams/status"); public static class Request extends LocalClusterStateRequest { protected Request(TimeValue masterTimeout) { diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java index b88907b10c45f..5219e2ab8c10e 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java @@ -51,6 +51,8 @@ import org.elasticsearch.http.HttpServerTransport; import 
org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.plugins.TelemetryPlugin; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestController; @@ -60,6 +62,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.action.EmptyResponseListener; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.telemetry.Measurement; +import org.elasticsearch.telemetry.TestTelemetryPlugin; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.transport.netty4.NettyAllocator; import org.elasticsearch.xcontent.ToXContentObject; @@ -91,7 +95,7 @@ public class Netty4PipeliningIT extends ESNetty4IntegTestCase { @Override protected Collection> nodePlugins() { return CollectionUtils.concatLists( - List.of(CountDown3Plugin.class, ChunkAndFailPlugin.class, KeepPipeliningPlugin.class), + List.of(CountDown3Plugin.class, ChunkAndFailPlugin.class, KeepPipeliningPlugin.class, TestTelemetryPlugin.class), super.nodePlugins() ); } @@ -281,6 +285,90 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } + public void testConnectionStatsExposedToTelemetryPlugin() throws Exception { + final var targetNode = internalCluster().startNode(); + + final var telemetryPlugin = asInstanceOf( + TestTelemetryPlugin.class, + internalCluster().getInstance(PluginsService.class, targetNode).filterPlugins(TelemetryPlugin.class).findAny().orElseThrow() + ); + + assertHttpMetrics(telemetryPlugin, 0L, 0L); + + final var releasables = new ArrayList(3); + try { + final var keepPipeliningRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, KeepPipeliningPlugin.ROUTE); + releasables.add(keepPipeliningRequest::release); + + final var responseReceivedLatch = new CountDownLatch(1); + + final EventLoopGroup eventLoopGroup 
= new NioEventLoopGroup(1); + releasables.add(() -> eventLoopGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).awaitUninterruptibly()); + final var clientBootstrap = new Bootstrap().channel(NettyAllocator.getChannelType()) + .option(ChannelOption.ALLOCATOR, NettyAllocator.getAllocator()) + .group(eventLoopGroup) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ch.pipeline().addLast(new HttpClientCodec()); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpResponse msg) { + responseReceivedLatch.countDown(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ExceptionsHelper.maybeDieOnAnotherThread(new AssertionError(cause)); + } + }); + } + }); + + final var httpServerTransport = internalCluster().getInstance(HttpServerTransport.class, targetNode); + final var httpServerAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses()).address(); + + // Open a channel on which we will pipeline the requests to KeepPipeliningPlugin.ROUTE + final var pipeliningChannel = clientBootstrap.connect(httpServerAddress).syncUninterruptibly().channel(); + releasables.add(() -> pipeliningChannel.close().syncUninterruptibly()); + + if (randomBoolean()) { + // assertBusy because client-side connect may complete before server-side + assertBusy(() -> assertHttpMetrics(telemetryPlugin, 1L, 1L)); + } else { + // Send two pipelined requests so that we start to receive responses + pipeliningChannel.writeAndFlush(keepPipeliningRequest.retain()); + pipeliningChannel.writeAndFlush(keepPipeliningRequest.retain()); + + // wait until we've started to receive responses (but we won't have received them all) - server side is definitely open now + safeAwait(responseReceivedLatch); + assertHttpMetrics(telemetryPlugin, 1L, 1L); + } + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); 
+ } + + // assertBusy because client-side close may complete before server-side + assertBusy(() -> assertHttpMetrics(telemetryPlugin, 1L, 0L)); + } + + private static void assertHttpMetrics(TestTelemetryPlugin telemetryPlugin, long expectedTotal, long expectedCurrent) { + try { + telemetryPlugin.collect(); + assertMeasurement(telemetryPlugin.getLongAsyncCounterMeasurement("es.http.connections.total"), expectedTotal); + assertMeasurement(telemetryPlugin.getLongGaugeMeasurement("es.http.connections.current"), expectedCurrent); + } finally { + telemetryPlugin.resetMeter(); + } + } + + private static void assertMeasurement(List measurements, long expectedValue) { + assertThat(measurements, hasSize(1)); + assertThat(measurements.get(0).getLong(), equalTo(expectedValue)); + } + private void assertOpaqueIdsInOrder(Collection opaqueIds) { // check if opaque ids are monotonically increasing int i = 0; diff --git a/muted-tests.yml b/muted-tests.yml index d09e417c6da49..e8283b260b8d2 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -64,8 +64,6 @@ tests: - class: org.elasticsearch.xpack.shutdown.NodeShutdownIT method: testAllocationPreventedForRemoval issue: https://github.com/elastic/elasticsearch/issues/116363 -- class: org.elasticsearch.xpack.security.authc.ldap.ActiveDirectoryGroupsResolverTests - issue: https://github.com/elastic/elasticsearch/issues/116182 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=snapshot/20_operator_privileges_disabled/Operator only settings can be set and restored by non-operator user when operator privileges is disabled} issue: https://github.com/elastic/elasticsearch/issues/116775 @@ -114,8 +112,6 @@ tests: - class: org.elasticsearch.xpack.ml.integration.ForecastIT method: testOverflowToDisk issue: https://github.com/elastic/elasticsearch/issues/117740 -- class: org.elasticsearch.xpack.security.authc.ldap.MultiGroupMappingIT - issue: https://github.com/elastic/elasticsearch/issues/119599 - class: 
org.elasticsearch.multi_cluster.MultiClusterYamlTestSuiteIT issue: https://github.com/elastic/elasticsearch/issues/119983 - class: org.elasticsearch.xpack.test.rest.XPackRestIT @@ -138,36 +134,17 @@ tests: - class: org.elasticsearch.xpack.inference.DefaultEndPointsIT method: testMultipleInferencesTriggeringDownloadAndDeploy issue: https://github.com/elastic/elasticsearch/issues/120668 -- class: org.elasticsearch.xpack.security.authc.ldap.ADLdapUserSearchSessionFactoryTests - issue: https://github.com/elastic/elasticsearch/issues/119882 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=ml/3rd_party_deployment/Test start deployment fails while model download in progress} issue: https://github.com/elastic/elasticsearch/issues/120810 - class: org.elasticsearch.xpack.security.authc.service.ServiceAccountIT method: testAuthenticateShouldNotFallThroughInCaseOfFailure issue: https://github.com/elastic/elasticsearch/issues/120902 -- class: org.elasticsearch.packaging.test.DockerTests - method: test050BasicApiTests - issue: https://github.com/elastic/elasticsearch/issues/120911 -- class: org.elasticsearch.packaging.test.DockerTests - method: test140CgroupOsStatsAreAvailable - issue: https://github.com/elastic/elasticsearch/issues/120914 -- class: org.elasticsearch.packaging.test.DockerTests - method: test070BindMountCustomPathConfAndJvmOptions - issue: https://github.com/elastic/elasticsearch/issues/120910 -- class: org.elasticsearch.packaging.test.DockerTests - method: test071BindMountCustomPathWithDifferentUID - issue: https://github.com/elastic/elasticsearch/issues/120918 -- class: org.elasticsearch.packaging.test.DockerTests - method: test171AdditionalCliOptionsAreForwarded - issue: https://github.com/elastic/elasticsearch/issues/120925 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT method: test {p0=nodes.stats/11_indices_metrics/indices mappings exact count test for indices level} issue: 
https://github.com/elastic/elasticsearch/issues/120950 - class: org.elasticsearch.xpack.ml.integration.PyTorchModelIT issue: https://github.com/elastic/elasticsearch/issues/121165 -- class: org.elasticsearch.xpack.security.authc.ldap.ActiveDirectorySessionFactoryTests - issue: https://github.com/elastic/elasticsearch/issues/121285 - class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT issue: https://github.com/elastic/elasticsearch/issues/121407 - class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT @@ -176,15 +153,9 @@ tests: - class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT method: test {yaml=snapshot.delete/10_basic/Delete a snapshot asynchronously} issue: https://github.com/elastic/elasticsearch/issues/122102 -- class: org.elasticsearch.smoketest.SmokeTestMonitoringWithSecurityIT - method: testHTTPExporterWithSSL - issue: https://github.com/elastic/elasticsearch/issues/122220 - class: org.elasticsearch.blocks.SimpleBlocksIT method: testConcurrentAddBlock issue: https://github.com/elastic/elasticsearch/issues/122324 -- class: org.elasticsearch.packaging.test.DockerTests - method: test151MachineDependentHeapWithSizeOverride - issue: https://github.com/elastic/elasticsearch/issues/123437 - class: org.elasticsearch.action.admin.cluster.node.tasks.CancellableTasksIT method: testChildrenTasksCancelledOnTimeout issue: https://github.com/elastic/elasticsearch/issues/123568 @@ -212,21 +183,12 @@ tests: - class: org.elasticsearch.xpack.restart.MLModelDeploymentFullClusterRestartIT method: testDeploymentSurvivesRestart {cluster=OLD} issue: https://github.com/elastic/elasticsearch/issues/124160 -- class: org.elasticsearch.multiproject.test.CoreWithMultipleProjectsClientYamlTestSuiteIT - method: test {yaml=search.vectors/41_knn_search_byte_quantized/kNN search plus query} - issue: https://github.com/elastic/elasticsearch/issues/124687 - class: org.elasticsearch.packaging.test.BootstrapCheckTests method: 
test20RunWithBootstrapChecks issue: https://github.com/elastic/elasticsearch/issues/124940 - class: org.elasticsearch.packaging.test.BootstrapCheckTests method: test10Install issue: https://github.com/elastic/elasticsearch/issues/124957 -- class: org.elasticsearch.packaging.test.DockerTests - method: test011SecurityEnabledStatus - issue: https://github.com/elastic/elasticsearch/issues/124990 -- class: org.elasticsearch.packaging.test.DockerTests - method: test012SecurityCanBeDisabled - issue: https://github.com/elastic/elasticsearch/issues/116636 - class: org.elasticsearch.smoketest.MlWithSecurityIT method: test {yaml=ml/data_frame_analytics_crud/Test get stats on newly created config} issue: https://github.com/elastic/elasticsearch/issues/121726 @@ -236,9 +198,6 @@ tests: - class: org.elasticsearch.smoketest.MlWithSecurityIT method: test {yaml=ml/data_frame_analytics_cat_apis/Test cat data frame analytics single job with header} issue: https://github.com/elastic/elasticsearch/issues/125642 -- class: org.elasticsearch.packaging.test.DockerTests - method: test010Install - issue: https://github.com/elastic/elasticsearch/issues/125680 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=transform/transforms_start_stop/Test schedule_now on an already started transform} issue: https://github.com/elastic/elasticsearch/issues/120720 @@ -254,9 +213,6 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=transform/transforms_stats/Test get transform stats with timeout} issue: https://github.com/elastic/elasticsearch/issues/125975 -- class: org.elasticsearch.packaging.test.DockerTests - method: test021InstallPlugin - issue: https://github.com/elastic/elasticsearch/issues/116147 - class: org.elasticsearch.action.RejectionActionIT method: testSimulatedSearchRejectionLoad issue: https://github.com/elastic/elasticsearch/issues/125901 @@ -266,9 +222,6 @@ tests: - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT 
method: testSearchWithRandomDisconnects issue: https://github.com/elastic/elasticsearch/issues/122707 -- class: org.elasticsearch.packaging.test.DockerTests - method: test020PluginsListWithNoPlugins - issue: https://github.com/elastic/elasticsearch/issues/126232 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=transform/transforms_reset/Test force reseting a running transform} issue: https://github.com/elastic/elasticsearch/issues/126240 @@ -278,15 +231,9 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=ml/start_data_frame_analytics/Test start classification analysis when the dependent variable cardinality is too low} issue: https://github.com/elastic/elasticsearch/issues/126299 -- class: org.elasticsearch.packaging.test.DockerTests - method: test023InstallPluginUsingConfigFile - issue: https://github.com/elastic/elasticsearch/issues/126145 - class: org.elasticsearch.smoketest.MlWithSecurityIT method: test {yaml=ml/start_data_frame_analytics/Test start classification analysis when the dependent variable cardinality is too low} issue: https://github.com/elastic/elasticsearch/issues/123200 -- class: org.elasticsearch.packaging.test.DockerTests - method: test022InstallPluginsFromLocalArchive - issue: https://github.com/elastic/elasticsearch/issues/116866 - class: org.elasticsearch.smoketest.MlWithSecurityIT method: test {yaml=ml/trained_model_cat_apis/Test cat trained models} issue: https://github.com/elastic/elasticsearch/issues/125750 @@ -329,24 +276,12 @@ tests: - class: org.elasticsearch.cli.keystore.AddStringKeyStoreCommandTests method: testStdinWithMultipleValues issue: https://github.com/elastic/elasticsearch/issues/126882 -- class: org.elasticsearch.packaging.test.DockerTests - method: test024InstallPluginFromArchiveUsingConfigFile - issue: https://github.com/elastic/elasticsearch/issues/126936 -- class: org.elasticsearch.packaging.test.DockerTests - method: test026InstallBundledRepositoryPlugins - issue: 
https://github.com/elastic/elasticsearch/issues/127081 -- class: org.elasticsearch.packaging.test.DockerTests - method: test026InstallBundledRepositoryPluginsViaConfigFile - issue: https://github.com/elastic/elasticsearch/issues/127158 - class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS2EnrichUnavailableRemotesIT method: testEsqlEnrichWithSkipUnavailable issue: https://github.com/elastic/elasticsearch/issues/127368 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=ml/data_frame_analytics_cat_apis/Test cat data frame analytics all jobs with header} issue: https://github.com/elastic/elasticsearch/issues/127625 -- class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT - method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse - issue: https://github.com/elastic/elasticsearch/issues/127096 - class: org.elasticsearch.xpack.ccr.action.ShardFollowTaskReplicationTests method: testChangeFollowerHistoryUUID issue: https://github.com/elastic/elasticsearch/issues/127680 @@ -356,48 +291,18 @@ tests: - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT method: test {p0=search/350_point_in_time/point-in-time with index filter} issue: https://github.com/elastic/elasticsearch/issues/127741 -- class: org.elasticsearch.packaging.test.DockerTests - method: test025SyncPluginsUsingProxy - issue: https://github.com/elastic/elasticsearch/issues/127138 - class: org.elasticsearch.xpack.esql.action.CrossClusterQueryWithPartialResultsIT method: testOneRemoteClusterPartial issue: https://github.com/elastic/elasticsearch/issues/124055 - class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT method: test {lookup-join.MvJoinKeyOnTheLookupIndex ASYNC} issue: https://github.com/elastic/elasticsearch/issues/128030 -- class: org.elasticsearch.packaging.test.DockerTests - method: test042KeystorePermissionsAreCorrect - issue: https://github.com/elastic/elasticsearch/issues/128018 -- class: 
org.elasticsearch.packaging.test.DockerTests - method: test072RunEsAsDifferentUserAndGroup - issue: https://github.com/elastic/elasticsearch/issues/128031 -- class: org.elasticsearch.packaging.test.DockerTests - method: test122CanUseDockerLoggingConfig - issue: https://github.com/elastic/elasticsearch/issues/128110 -- class: org.elasticsearch.packaging.test.DockerTests - method: test041AmazonCaCertsAreInTheKeystore - issue: https://github.com/elastic/elasticsearch/issues/128006 -- class: org.elasticsearch.packaging.test.DockerTests - method: test130JavaHasCorrectOwnership - issue: https://github.com/elastic/elasticsearch/issues/128174 -- class: org.elasticsearch.packaging.test.DockerTests - method: test600Interrupt - issue: https://github.com/elastic/elasticsearch/issues/128144 - class: org.elasticsearch.packaging.test.EnrollmentProcessTests method: test20DockerAutoFormCluster issue: https://github.com/elastic/elasticsearch/issues/128113 -- class: org.elasticsearch.packaging.test.DockerTests - method: test121CanUseStackLoggingConfig - issue: https://github.com/elastic/elasticsearch/issues/128165 -- class: org.elasticsearch.packaging.test.DockerTests - method: test080ConfigurePasswordThroughEnvironmentVariableFile - issue: https://github.com/elastic/elasticsearch/issues/128075 - class: org.elasticsearch.ingest.geoip.GeoIpDownloaderCliIT method: testInvalidTimestamp issue: https://github.com/elastic/elasticsearch/issues/128284 -- class: org.elasticsearch.packaging.test.DockerTests - method: test120DockerLogsIncludeElasticsearchLogs - issue: https://github.com/elastic/elasticsearch/issues/128117 - class: org.elasticsearch.packaging.test.TemporaryDirectoryConfigTests method: test21AcceptsCustomPathInDocker issue: https://github.com/elastic/elasticsearch/issues/128114 @@ -413,27 +318,9 @@ tests: - class: org.elasticsearch.xpack.esql.action.CrossClusterQueryWithPartialResultsIT method: testFailToStartRequestOnRemoteCluster issue: 
https://github.com/elastic/elasticsearch/issues/128545 -- class: org.elasticsearch.packaging.test.DockerTests - method: test124CanRestartContainerWithStackLoggingConfig - issue: https://github.com/elastic/elasticsearch/issues/128121 -- class: org.elasticsearch.packaging.test.DockerTests - method: test085EnvironmentVariablesAreRespectedUnderDockerExec - issue: https://github.com/elastic/elasticsearch/issues/128115 - class: org.elasticsearch.compute.operator.LimitOperatorTests method: testEarlyTermination issue: https://github.com/elastic/elasticsearch/issues/128721 -- class: org.elasticsearch.packaging.test.DockerTests - method: test040JavaUsesTheOsProvidedKeystore - issue: https://github.com/elastic/elasticsearch/issues/128230 -- class: org.elasticsearch.packaging.test.DockerTests - method: test150MachineDependentHeap - issue: https://github.com/elastic/elasticsearch/issues/128120 -- class: org.elasticsearch.packaging.test.DockerTests - method: test073RunEsAsDifferentUserAndGroupWithoutBindMounting - issue: https://github.com/elastic/elasticsearch/issues/128996 -- class: org.elasticsearch.packaging.test.DockerTests - method: test081SymlinksAreFollowedWithEnvironmentVariableFiles - issue: https://github.com/elastic/elasticsearch/issues/128867 - class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT method: test {lookup-join.EnrichLookupStatsBug ASYNC} issue: https://github.com/elastic/elasticsearch/issues/129228 @@ -455,15 +342,9 @@ tests: - class: org.elasticsearch.xpack.profiling.action.GetStatusActionIT method: testWaitsUntilResourcesAreCreated issue: https://github.com/elastic/elasticsearch/issues/129486 -- class: org.elasticsearch.xpack.security.PermissionsIT - method: testWhenUserLimitedByOnlyAliasOfIndexCanWriteToIndexWhichWasRolledoverByILMPolicy - issue: https://github.com/elastic/elasticsearch/issues/129481 - class: org.elasticsearch.upgrades.MlJobSnapshotUpgradeIT method: testSnapshotUpgrader issue: 
https://github.com/elastic/elasticsearch/issues/98560 -- class: org.elasticsearch.upgrades.QueryableBuiltInRolesUpgradeIT - method: testBuiltInRolesSyncedOnClusterUpgrade - issue: https://github.com/elastic/elasticsearch/issues/129534 - class: org.elasticsearch.search.query.VectorIT method: testFilteredQueryStrategy issue: https://github.com/elastic/elasticsearch/issues/129517 @@ -498,39 +379,9 @@ tests: - class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT method: test {p0=msearch/20_typed_keys/Multisearch test with typed_keys parameter for sampler and significant terms} issue: https://github.com/elastic/elasticsearch/issues/130472 -- class: org.elasticsearch.xpack.ssl.SSLErrorMessageFileTests - method: testMessageForKeyStoreOutsideConfigDir - issue: https://github.com/elastic/elasticsearch/issues/127192 -- class: org.elasticsearch.xpack.ssl.SSLErrorMessageFileTests - method: testMessageForPemKeyOutsideConfigDir - issue: https://github.com/elastic/elasticsearch/issues/127192 -- class: org.elasticsearch.xpack.ssl.SSLErrorMessageFileTests - method: testMessageForPemCertificateOutsideConfigDir - issue: https://github.com/elastic/elasticsearch/issues/127192 -- class: org.elasticsearch.xpack.ssl.SSLErrorMessageFileTests - method: testMessageForTrustStoreOutsideConfigDir - issue: https://github.com/elastic/elasticsearch/issues/127192 -- class: org.elasticsearch.xpack.ssl.SSLErrorMessageFileTests - method: testMessageForCertificateAuthoritiesOutsideConfigDir - issue: https://github.com/elastic/elasticsearch/issues/127192 - class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests method: testHKmeans issue: https://github.com/elastic/elasticsearch/issues/130497 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testProjectWhere - issue: https://github.com/elastic/elasticsearch/issues/130504 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testTopNPushedToLucene - issue: 
https://github.com/elastic/elasticsearch/issues/130505 -- class: org.elasticsearch.xpack.monitoring.exporter.http.HttpExporterTests - method: testExporterWithHostOnly - issue: https://github.com/elastic/elasticsearch/issues/130599 -- class: org.elasticsearch.xpack.monitoring.exporter.http.HttpExporterTests - method: testCreateRestClient - issue: https://github.com/elastic/elasticsearch/issues/130600 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {match-operator.MatchWithMoreComplexDisjunctionAndConjunction SYNC} - issue: https://github.com/elastic/elasticsearch/issues/130640 - class: org.elasticsearch.gradle.LoggedExecFuncTest method: failed tasks output logged to console when spooling true issue: https://github.com/elastic/elasticsearch/issues/119509 @@ -543,39 +394,102 @@ tests: - class: org.elasticsearch.indices.stats.IndexStatsIT method: testFilterCacheStats issue: https://github.com/elastic/elasticsearch/issues/124447 -- class: org.elasticsearch.xpack.esql.action.CrossClusterAsyncQueryStopIT - method: testStopQueryLocal - issue: https://github.com/elastic/elasticsearch/issues/121672 -- class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT - method: test - issue: https://github.com/elastic/elasticsearch/issues/130067 -- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT - method: test - issue: https://github.com/elastic/elasticsearch/issues/130067 -- class: org.elasticsearch.xpack.esql.action.EsqlRemoteErrorWrapIT - method: testThatRemoteErrorsAreWrapped - issue: https://github.com/elastic/elasticsearch/issues/130794 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT method: test {p0=mtermvectors/10_basic/Tests catching other exceptions per item} issue: https://github.com/elastic/elasticsearch/issues/122414 -- class: org.elasticsearch.xpack.slm.SLMFileSettingsIT - method: testSettingsApplied - issue: https://github.com/elastic/elasticsearch/issues/130853 -- class: 
org.elasticsearch.cluster.ClusterStateSerializationTests - method: testSerializationPreMultiProject - issue: https://github.com/elastic/elasticsearch/issues/130872 -- class: org.elasticsearch.cluster.coordination.votingonly.VotingOnlyNodePluginTests - method: testPreferFullMasterOverVotingOnlyNodes - issue: https://github.com/elastic/elasticsearch/issues/130883 - class: org.elasticsearch.search.SearchWithRejectionsIT method: testOpenContextsAfterRejections issue: https://github.com/elastic/elasticsearch/issues/130821 -- class: org.elasticsearch.cluster.coordination.votingonly.VotingOnlyNodePluginTests - method: testVotingOnlyNodesCannotBeMasterWithoutFullMasterNodes - issue: https://github.com/elastic/elasticsearch/issues/130979 +- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT + method: test {lookup-join.MvJoinKeyOnFromAfterStats ASYNC} + issue: https://github.com/elastic/elasticsearch/issues/131148 +- class: org.elasticsearch.xpack.esql.ccq.MultiClustersIT + method: testLookupJoinAliases + issue: https://github.com/elastic/elasticsearch/issues/131166 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=field_caps/40_time_series/Get simple time series field caps} + issue: https://github.com/elastic/elasticsearch/issues/131225 +- class: org.elasticsearch.packaging.test.DockerTests + method: test090SecurityCliPackaging + issue: https://github.com/elastic/elasticsearch/issues/131107 +- class: org.elasticsearch.xpack.esql.expression.function.fulltext.ScoreTests + method: testSerializationOfSimple {TestCase=} + issue: https://github.com/elastic/elasticsearch/issues/131334 +- class: org.elasticsearch.xpack.esql.analysis.VerifierTests + method: testMatchInsideEval + issue: https://github.com/elastic/elasticsearch/issues/131336 +- class: org.elasticsearch.packaging.test.DockerTests + method: test022InstallPluginsFromLocalArchive + issue: https://github.com/elastic/elasticsearch/issues/116866 +- class: 
org.elasticsearch.packaging.test.DockerTests + method: test071BindMountCustomPathWithDifferentUID + issue: https://github.com/elastic/elasticsearch/issues/120917 +- class: org.elasticsearch.packaging.test.DockerTests + method: test171AdditionalCliOptionsAreForwarded + issue: https://github.com/elastic/elasticsearch/issues/120925 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=search/110_field_collapsing/field collapsing, inner_hits and maxConcurrentGroupRequests} + issue: https://github.com/elastic/elasticsearch/issues/131348 +- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT + method: testSimilarityBetweenConstantVectors {functionName=v_cosine similarityFunction=COSINE} + issue: https://github.com/elastic/elasticsearch/issues/131361 +- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT + method: testDifferentDimensions {functionName=v_cosine similarityFunction=COSINE} + issue: https://github.com/elastic/elasticsearch/issues/131362 +- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT + method: testSimilarityBetweenConstantVectorAndField {functionName=v_cosine similarityFunction=COSINE} + issue: https://github.com/elastic/elasticsearch/issues/131363 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=ml/delete_expired_data/Test delete expired data with body parameters} + issue: https://github.com/elastic/elasticsearch/issues/131364 - class: org.elasticsearch.packaging.test.DockerTests - method: test082CannotUseEnvVarsAndFiles - issue: https://github.com/elastic/elasticsearch/issues/129808 + method: test070BindMountCustomPathConfAndJvmOptions + issue: https://github.com/elastic/elasticsearch/issues/131366 +- class: org.elasticsearch.packaging.test.DockerTests + method: test140CgroupOsStatsAreAvailable + issue: https://github.com/elastic/elasticsearch/issues/131372 +- class: org.elasticsearch.packaging.test.DockerTests + method: 
test130JavaHasCorrectOwnership + issue: https://github.com/elastic/elasticsearch/issues/131369 +- class: org.elasticsearch.xpack.downsample.DataStreamLifecycleDownsampleDisruptionIT + method: testDataStreamLifecycleDownsampleRollingRestart + issue: https://github.com/elastic/elasticsearch/issues/131394 +- class: org.elasticsearch.packaging.test.DockerTests + method: test072RunEsAsDifferentUserAndGroup + issue: https://github.com/elastic/elasticsearch/issues/131412 +- class: org.elasticsearch.xpack.esql.heap_attack.HeapAttackIT + method: testLookupExplosionNoFetch + issue: https://github.com/elastic/elasticsearch/issues/128720 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=vector-tile/20_aggregations/stats agg} + issue: https://github.com/elastic/elasticsearch/issues/131484 +- class: org.elasticsearch.packaging.test.DockerTests + method: test050BasicApiTests + issue: https://github.com/elastic/elasticsearch/issues/120911 +- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT + method: testFromEvalStats + issue: https://github.com/elastic/elasticsearch/issues/131503 +- class: org.elasticsearch.xpack.downsample.DownsampleWithBasicRestIT + method: test {p0=downsample-with-security/10_basic/Downsample index} + issue: https://github.com/elastic/elasticsearch/issues/131513 +- class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT + method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse + issue: https://github.com/elastic/elasticsearch/issues/131248 +- class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT + method: test + issue: https://github.com/elastic/elasticsearch/issues/131508 +- class: org.elasticsearch.action.admin.cluster.node.tasks.CancellableTasksIT + method: testRemoveBanParentsOnDisconnect + issue: https://github.com/elastic/elasticsearch/issues/131562 +- class: org.elasticsearch.xpack.esql.action.CrossClusterQueryWithPartialResultsIT + method: testPartialResults + issue: 
https://github.com/elastic/elasticsearch/issues/131481 +- class: org.elasticsearch.packaging.test.DockerTests + method: test010Install + issue: https://github.com/elastic/elasticsearch/issues/131376 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=search/40_indices_boost/Indices boost with alias} + issue: https://github.com/elastic/elasticsearch/issues/131598 # Examples: # diff --git a/plugins/discovery-ec2/build.gradle b/plugins/discovery-ec2/build.gradle index 454508d0298f9..f4eb1d3a90f01 100644 --- a/plugins/discovery-ec2/build.gradle +++ b/plugins/discovery-ec2/build.gradle @@ -15,31 +15,6 @@ esplugin { classname ='org.elasticsearch.discovery.ec2.Ec2DiscoveryPlugin' } -def patched = Attribute.of('patched', Boolean) - -configurations { - compileClasspath { - attributes { - attribute(patched, true) - } - } - runtimeClasspath { - attributes { - attribute(patched, true) - } - } - testCompileClasspath { - attributes { - attribute(patched, true) - } - } - testRuntimeClasspath { - attributes { - attribute(patched, true) - } - } -} - dependencies { implementation "software.amazon.awssdk:annotations:${versions.awsv2sdk}" @@ -90,17 +65,6 @@ dependencies { testImplementation project(':test:fixtures:ec2-imds-fixture') internalClusterTestImplementation project(':test:fixtures:ec2-imds-fixture') - - attributesSchema { - attribute(patched) - } - artifactTypes.getByName("jar") { - attributes.attribute(patched, false) - } - registerTransform(org.elasticsearch.gradle.internal.dependencies.patches.awsv2sdk.Awsv2ClassPatcher) { - from.attribute(patched, false) - to.attribute(patched, true) - } } tasks.named("dependencyLicenses").configure { diff --git a/qa/packaging/src/test/java/org/elasticsearch/packaging/test/PackagingTestCase.java b/qa/packaging/src/test/java/org/elasticsearch/packaging/test/PackagingTestCase.java index a157cc84e624e..31cd1f3a36879 100644 --- 
a/qa/packaging/src/test/java/org/elasticsearch/packaging/test/PackagingTestCase.java +++ b/qa/packaging/src/test/java/org/elasticsearch/packaging/test/PackagingTestCase.java @@ -145,6 +145,10 @@ public abstract class PackagingTestCase extends Assert { @Override protected void failed(Throwable e, Description description) { failed = true; + if (installation != null && installation.distribution.isDocker()) { + logger.warn("Test {} failed. Printing logs for failed test...", description.getMethodName()); + FileUtils.logAllLogs(installation.logs, logger); + } } }; diff --git a/qa/packaging/src/test/java/org/elasticsearch/packaging/util/ServerUtils.java b/qa/packaging/src/test/java/org/elasticsearch/packaging/util/ServerUtils.java index 2b1c9ed140ee3..1f7e984791b51 100644 --- a/qa/packaging/src/test/java/org/elasticsearch/packaging/util/ServerUtils.java +++ b/qa/packaging/src/test/java/org/elasticsearch/packaging/util/ServerUtils.java @@ -151,7 +151,19 @@ private static HttpResponse execute(Request request, String username, String pas executor.auth(username, password); executor.authPreemptive(new HttpHost("localhost", 9200)); } - return executor.execute(request).returnResponse(); + try { + return executor.execute(request).returnResponse(); + } catch (Exception e) { + logger.warn( + "Failed to execute request [{}] with username/password [{}/{}] and caCert [{}]", + request.toString(), + username, + password, + caCert, + e + ); + throw e; + } } // polls every two seconds for Elasticsearch to be running on 9200 @@ -239,14 +251,13 @@ public static void waitForElasticsearch( long timeElapsed = 0; boolean started = false; Throwable thrownException = null; - if (caCert == null) { - caCert = getCaCert(installation); - } while (started == false && timeElapsed < waitTime) { if (System.currentTimeMillis() - lastRequest > requestInterval) { + if (caCert == null) { + caCert = getCaCert(installation); + } try { - final HttpResponse response = execute( Request.Get((caCert != null ? 
"https" : "http") + "://localhost:9200/_cluster/health") .connectTimeout((int) timeoutLength) @@ -277,7 +288,7 @@ public static void waitForElasticsearch( } started = true; - } catch (IOException e) { + } catch (Exception e) { if (thrownException == null) { thrownException = e; } else { diff --git a/qa/packaging/src/test/java/org/elasticsearch/packaging/util/docker/Docker.java b/qa/packaging/src/test/java/org/elasticsearch/packaging/util/docker/Docker.java index ab167d7663be1..a17ae7781db48 100644 --- a/qa/packaging/src/test/java/org/elasticsearch/packaging/util/docker/Docker.java +++ b/qa/packaging/src/test/java/org/elasticsearch/packaging/util/docker/Docker.java @@ -73,7 +73,7 @@ public class Docker { public static final Shell sh = new Shell(); public static final DockerShell dockerShell = new DockerShell(); public static final int STARTUP_SLEEP_INTERVAL_MILLISECONDS = 1000; - public static final int STARTUP_ATTEMPTS_MAX = 30; + public static final int STARTUP_ATTEMPTS_MAX = 45; /** * The length of the command exceeds what we can use for COLUMNS so we use diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 59cd581f9dc0e..f51c550e5292e 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -29,7 +29,7 @@ * This class encapsulates all the parameters required to run the KNN index tests. 
*/ record CmdLineArgs( - Path docVectors, + List docVectors, Path queryVectors, int numDocs, int numQueries, @@ -88,7 +88,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { static final ObjectParser PARSER = new ObjectParser<>("cmd_line_args", true, Builder::new); static { - PARSER.declareString(Builder::setDocVectors, DOC_VECTORS_FIELD); + PARSER.declareStringArray(Builder::setDocVectors, DOC_VECTORS_FIELD); PARSER.declareString(Builder::setQueryVectors, QUERY_VECTORS_FIELD); PARSER.declareInt(Builder::setNumDocs, NUM_DOCS_FIELD); PARSER.declareInt(Builder::setNumQueries, NUM_QUERIES_FIELD); @@ -118,7 +118,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); if (docVectors != null) { - builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectors.toString()); + List docVectorsStrings = docVectors.stream().map(Path::toString).toList(); + builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectorsStrings); } if (queryVectors != null) { builder.field(QUERY_VECTORS_FIELD.getPreferredName(), queryVectors.toString()); @@ -154,7 +155,7 @@ public String toString() { } static class Builder { - private Path docVectors; + private List docVectors; private Path queryVectors; private int numDocs = 1000; private int numQueries = 100; @@ -179,8 +180,12 @@ static class Builder { private float filterSelectivity = 1f; private long seed = 1751900822751L; - public Builder setDocVectors(String docVectors) { - this.docVectors = PathUtils.get(docVectors); + public Builder setDocVectors(List docVectors) { + if (docVectors == null || docVectors.isEmpty()) { + throw new IllegalArgumentException("Document vectors path must be provided"); + } + // Convert list of strings to list of Paths + this.docVectors = docVectors.stream().map(PathUtils::get).toList(); return this; } diff --git 
a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index fe20f895d3ea9..c4b0ccdfe35e3 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -83,7 +83,7 @@ private static String formatIndexPath(CmdLineArgs args) { suffix.add(Integer.toString(args.quantizeBits())); } } - return INDEX_DIR + "/" + args.docVectors().getFileName() + "-" + String.join("-", suffix) + ".index"; + return INDEX_DIR + "/" + args.docVectors().get(0).getFileName() + "-" + String.join("-", suffix) + ".index"; } static Codec createCodec(CmdLineArgs args) { @@ -137,7 +137,7 @@ public static void main(String[] args) throws Exception { System.out.println( Strings.toString( new CmdLineArgs.Builder().setDimensions(64) - .setDocVectors("/doc/vectors/path") + .setDocVectors(List.of("/doc/vectors/path")) .setQueryVectors("/query/vectors/path") .build(), true, @@ -179,7 +179,7 @@ public static void main(String[] args) throws Exception { : new int[] { 0 }; String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT); Results indexResults = new Results( - cmdLineArgs.docVectors().getFileName().toString(), + cmdLineArgs.docVectors().get(0).getFileName().toString(), indexType, cmdLineArgs.numDocs(), cmdLineArgs.filterSelectivity() @@ -187,7 +187,7 @@ public static void main(String[] args) throws Exception { Results[] results = new Results[nProbes.length]; for (int i = 0; i < nProbes.length; i++) { results[i] = new Results( - cmdLineArgs.docVectors().getFileName().toString(), + cmdLineArgs.docVectors().get(0).getFileName().toString(), indexType, cmdLineArgs.numDocs(), cmdLineArgs.filterSelectivity() diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java index 40eb8424aeb1b..f7d00c9806c8d 100644 --- 
a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java @@ -61,7 +61,7 @@ class KnnIndexer { static final String ID_FIELD = "id"; static final String VECTOR_FIELD = "vector"; - private final Path docsPath; + private final List docsPath; private final Path indexPath; private final VectorEncoding vectorEncoding; private int dim; @@ -71,7 +71,7 @@ class KnnIndexer { private final int numIndexThreads; KnnIndexer( - Path docsPath, + List docsPath, Path indexPath, Codec codec, int numIndexThreads, @@ -127,57 +127,70 @@ public boolean isEnabled(String component) { } long start = System.nanoTime(); - try ( - FSDirectory dir = FSDirectory.open(indexPath); - IndexWriter iw = new IndexWriter(dir, iwc); - FileChannel in = FileChannel.open(docsPath) - ) { - long docsPathSizeInBytes = in.size(); - int offsetByteSize = 0; - if (dim == -1) { - offsetByteSize = 4; - ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); - int bytesRead = Channels.readFromFileChannel(in, 0, preamble); - if (bytesRead < 4) { - throw new IllegalArgumentException( - "docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes + AtomicInteger numDocsIndexed = new AtomicInteger(); + try (FSDirectory dir = FSDirectory.open(indexPath); IndexWriter iw = new IndexWriter(dir, iwc);) { + for (Path docsPath : this.docsPath) { + int dim = this.dim; + try (FileChannel in = FileChannel.open(docsPath)) { + long docsPathSizeInBytes = in.size(); + int offsetByteSize = 0; + if (dim == -1) { + offsetByteSize = 4; + ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + int bytesRead = Channels.readFromFileChannel(in, 0, preamble); + if (bytesRead < 4) { + throw new IllegalArgumentException( + "docsPath \"" + docsPath + "\" does not contain a valid dims? 
size=" + docsPathSizeInBytes + ); + } + dim = preamble.getInt(0); + if (dim <= 0) { + throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim); + } + } + FieldType fieldType = switch (vectorEncoding) { + case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction); + case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction); + }; + if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) { + throw new IllegalArgumentException( + "docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes + ); + } + int numDocs = (int) (docsPathSizeInBytes / ((long) dim * vectorEncoding.byteSize + offsetByteSize)); + numDocs = Math.min(this.numDocs - numDocsIndexed.get(), numDocs); + if (numDocs <= 0) { + break; + } + logger.info( + "path={}, docsPathSizeInBytes={}, numDocs={}, dim={}, vectorEncoding={}, byteSize={}", + docsPath, + docsPathSizeInBytes, + numDocs, + dim, + vectorEncoding, + vectorEncoding.byteSize ); - } - dim = preamble.getInt(0); - if (dim <= 0) { - throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim); - } - } - FieldType fieldType = switch (vectorEncoding) { - case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction); - case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction); - }; - if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) { - throw new IllegalArgumentException( - "docsPath \"" + docsPath + "\" does not contain a whole number of vectors? 
size=" + docsPathSizeInBytes - ); - } - logger.info( - "docsPathSizeInBytes={}, dim={}, vectorEncoding={}, byteSize={}", - docsPathSizeInBytes, - dim, - vectorEncoding, - vectorEncoding.byteSize - ); + // adjust numDocs to account for the number of documents already indexed + // numDocsIndexed tracks the total docs read in order and is used for docIds + // numDocs is the total number of docs to index from this file + numDocs += numDocsIndexed.get(); - VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize); - try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) { - AtomicInteger numDocsIndexed = new AtomicInteger(); - List> threads = new ArrayList<>(); - for (int i = 0; i < numIndexThreads; i++) { - Thread t = new IndexerThread(iw, inReader, dim, vectorEncoding, fieldType, numDocsIndexed, numDocs); - t.setDaemon(true); - threads.add(exec.submit(t)); - } - for (Future t : threads) { - t.get(); + VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize); + try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) { + List> threads = new ArrayList<>(); + for (int i = 0; i < numIndexThreads; i++) { + Thread t = new IndexerThread(iw, inReader, dim, vectorEncoding, fieldType, numDocsIndexed, numDocs); + t.setDaemon(true); + threads.add(exec.submit(t)); + } + for (Future t : threads) { + t.get(); + } + } } } + logger.info("KnnIndexer: indexed {} documents of desired {} numDocs", numDocsIndexed, numDocs); logger.debug("all indexing threads finished, now IndexWriter.commit()"); iw.commit(); ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler) iwc.getMergeScheduler(); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index fb84df66b0138..bb13dd75a4d9e 100644 --- 
a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -98,7 +98,7 @@ class KnnSearcher { - private final Path docPath; + private final List docPath; private final Path indexPath; private final Path queryPath; private final int numDocs; @@ -153,12 +153,6 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th : null ) { long queryPathSizeInBytes = input.size(); - logger.info( - "queryPath size: " - + queryPathSizeInBytes - + " bytes, assuming vector count is " - + (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize)) - ); if (dim == -1) { offsetByteSize = 4; ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); @@ -171,6 +165,17 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim); } } + if (queryPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) { + throw new IllegalArgumentException( + "queryPath \"" + queryPath + "\" does not contain a whole number of vectors? size=" + queryPathSizeInBytes + ); + } + logger.info( + "queryPath size: " + + queryPathSizeInBytes + + " bytes, assuming vector count is " + + (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize + offsetByteSize)) + ); KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize); long startNS; try (MMapDirectory dir = new MMapDirectory(indexPath)) { @@ -368,8 +373,13 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQue } } - private boolean isNewer(Path path, Path... others) throws IOException { + private boolean isNewer(Path path, List paths, Path... 
others) throws IOException { FileTime modified = Files.getLastModifiedTime(path); + for (Path p : paths) { + if (Files.getLastModifiedTime(p).compareTo(modified) >= 0) { + return false; + } + } for (Path other : others) { if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) { return false; diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/cluster.allocation_explain.json b/rest-api-spec/src/main/resources/rest-api-spec/api/cluster.allocation_explain.json index a3922033ec2a8..af9cbbc2c049a 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/cluster.allocation_explain.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/cluster.allocation_explain.json @@ -22,6 +22,22 @@ ] }, "params":{ + "index": { + "type": "string", + "description": "Specifies the name of the index that you would like an explanation for" + }, + "shard": { + "type": "number", + "description": "Specifies the ID of the shard that you would like an explanation for" + }, + "primary": { + "type":"boolean", + "description":"If true, returns explanation for the primary shard for the given shard ID" + }, + "current_node": { + "type": "string", + "description": "Specifies the node ID or the name of the node to only explain a shard that is currently located on the specified node" + }, "master_timeout":{ "type":"time", "description":"Timeout for connection to master node" diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/create.json b/rest-api-spec/src/main/resources/rest-api-spec/api/create.json index 65cb0da4753cc..88b040d1dbee3 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/create.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/create.json @@ -73,6 +73,14 @@ "include_source_on_error": { "type": "boolean", "description": "True or false if to include the document source in the error message in case of parsing errors. Defaults to true." 
+ }, + "require_alias":{ + "type":"boolean", + "description":"When true, requires destination to be an alias. Default is false" + }, + "require_data_stream":{ + "type":"boolean", + "description":"When true, requires destination to be a data stream (existing or to be created). Default is false" } }, "body":{ diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/eql.search.json b/rest-api-spec/src/main/resources/rest-api-spec/api/eql.search.json index 0b1a7ad5a38d3..9164a15be4cc1 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/eql.search.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/eql.search.json @@ -51,6 +51,31 @@ "type":"boolean", "description":"Control whether a sequence query should return partial results or no results at all in case of shard failures. This option has effect only if [allow_partial_search_results] is true.", "default":false + }, + "ccs_minimize_roundtrips":{ + "type":"boolean", + "description":"Indicates whether network round-trips should be minimized as part of cross-cluster search requests execution", + "default":true + }, + "ignore_unavailable":{ + "type":"boolean", + "description":"Whether specified concrete indices should be ignored when unavailable (missing or closed)" + }, + "allow_no_indices":{ + "type":"boolean", + "description":"Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified)" + }, + "expand_wildcards":{ + "type":"enum", + "options":[ + "open", + "closed", + "hidden", + "none", + "all" + ], + "default":"open", + "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both." 
} }, "body":{ diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json b/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json index b1174b89df0bd..fe66541fa9b0b 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json @@ -41,6 +41,28 @@ "type":"boolean", "description":"Display only those recoveries that are currently on-going", "default":false + }, + "ignore_unavailable":{ + "type":"boolean", + "description":"Whether specified concrete indices should be ignored when unavailable (missing or closed)", + "default":false + }, + "allow_no_indices":{ + "type":"boolean", + "description":"Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified)", + "default":true + }, + "expand_wildcards":{ + "type":"enum", + "options":[ + "open", + "closed", + "hidden", + "none", + "all" + ], + "default":"open", + "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both." 
} } } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json b/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json index 359d1e67b07e5..0d8223a71c79d 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json @@ -69,6 +69,38 @@ "type":"boolean", "description":"Indicates whether network round-trips should be minimized as part of cross-cluster search requests execution", "default":"true" + }, + "index":{ + "type":"list", + "description":"A comma-separated list of index names to use as default" + }, + "ignore_unavailable":{ + "type":"boolean", + "description":"Whether specified concrete indices should be ignored when unavailable (missing or closed)" + }, + "ignore_throttled":{ + "type":"boolean", + "description":"Whether specified concrete, expanded or aliased indices should be ignored when throttled", + "deprecated":true + }, + "allow_no_indices":{ + "type":"boolean", + "description":"Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified)" + }, + "expand_wildcards":{ + "type":"enum", + "options": ["open", "closed", "hidden", "none", "all"], + "default":"open", + "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both." 
+ }, + "routing":{ + "type":"list", + "description":"A comma-separated list of specific routing values" + }, + "include_named_queries_score":{ + "type":"boolean", + "description":"Indicates whether hit.matched_queries should be rendered as a map that includes the name of the matched query associated with its score (true) or as an array containing the name of the matched queries (false)", + "default": false } }, "body":{ diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/reindex.json b/rest-api-spec/src/main/resources/rest-api-spec/api/reindex.json index f8038853e4731..c28e2f8417883 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/reindex.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/reindex.json @@ -57,6 +57,11 @@ "max_docs":{ "type":"number", "description":"Maximum number of documents to process (default: all documents)" + }, + "require_alias":{ + "type":"boolean", + "default":false, + "description":"When true, requires destination to be an alias." } }, "body":{ diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/search_mvt.json b/rest-api-spec/src/main/resources/rest-api-spec/api/search_mvt.json index 35ebe3b3f1d16..faa67132f3aca 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/search_mvt.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/search_mvt.json @@ -73,6 +73,15 @@ "description":"Determines the geometry type for features in the aggs layer.", "default":"grid" }, + "grid_agg":{ + "type":"enum", + "options":[ + "geotile", + "geohex" + ], + "description":"Aggregation used to create a grid for `field`.", + "default":"geotile" + }, "size":{ "type":"int", "description":"Maximum number of features to return in the hits layer. 
Accepts 0-10000.", diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/searchable_snapshots.clear_cache.json b/rest-api-spec/src/main/resources/rest-api-spec/api/searchable_snapshots.clear_cache.json index d2d7000195c04..8b39cdcb24218 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/searchable_snapshots.clear_cache.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/searchable_snapshots.clear_cache.json @@ -50,10 +50,6 @@ ], "default": "open", "description": "Whether to expand wildcard expression to concrete indices that are open, closed or both." - }, - "index": { - "type": "list", - "description": "A comma-separated list of index name to limit the operation" } } } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/snapshot.repository_analyze.json b/rest-api-spec/src/main/resources/rest-api-spec/api/snapshot.repository_analyze.json index 2578cd5684d6d..0b2bb66c709d7 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/snapshot.repository_analyze.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/snapshot.repository_analyze.json @@ -36,6 +36,10 @@ "type":"number", "description":"Number of operations to run concurrently during the test. Defaults to 10." }, + "register_operation_count":{ + "type":"number", + "description":"The minimum number of linearizable register operations to perform in total. Defaults to 10." + }, "read_node_count":{ "type":"number", "description":"Number of nodes on which to read a blob after writing. Defaults to 10." 
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml index 1a1fee6e04559..1f5e9d80b7702 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml @@ -25,8 +25,7 @@ - is_true: can_allocate --- -# This test has a valid integer input, but it's above the shard limit, so the index cannot be located -"cluster shard allocation explanation test with max integer shard value": +"cluster shard allocation explanation test with only index provided in the body": - do: indices.create: index: test @@ -34,12 +33,12 @@ - match: { acknowledged: true } - do: - catch: /shard_not_found_exception/ + catch: /action_request_validation_exception/ cluster.allocation_explain: - body: { "index": "test", "shard": 2147483647, "primary": true } + body: { "index": "test"} --- -"cluster shard allocation explanation test with long shard value": +"cluster shard allocation explanation test with only shard provided in the body": - do: indices.create: index: test @@ -47,50 +46,105 @@ - match: { acknowledged: true } - do: - catch: /x_content_parse_exception/ + catch: /action_request_validation_exception/ cluster.allocation_explain: - body: { "index": "test", "shard": 214748364777, "primary": true } + body: { "shard": 0} --- -"cluster shard allocation explanation test with float shard value": +"cluster shard allocation explanation test with only primary provided in the body": - do: indices.create: index: test - body: { "settings": { "index.number_of_shards": 2, "index.number_of_replicas": 0 } } - match: { acknowledged: true } - do: + catch: /action_request_validation_exception/ cluster.allocation_explain: - body: { "index": "test", "shard": 1.0, "primary": true } + body: { "primary": 
true} - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 1 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation +--- +"cluster shard allocation explanation test with only index and shard provided in the body": + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + body: { "index": "test", "shard": 0} --- -"cluster shard allocation explanation test with double shard value": +"cluster shard allocation explanation test with only shard and primary provided in the body": - do: indices.create: index: test - body: { "settings": { "index.number_of_shards": 2, "index.number_of_replicas": 0 } } - match: { acknowledged: true } - do: + catch: /action_request_validation_exception/ cluster.allocation_explain: - body: { "index": "test", "shard": 1.1234567891234567, "primary": true } + body: { "shard": 0, "primary": true } + +--- +"cluster shard allocation explanation test with only index and primary provided in the body": + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + body: { "index": "test", "primary": true } + +--- +"cluster shard allocation explanation test with incorrect index parameter in the body": + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /index_not_found_exception/ + cluster.allocation_explain: + body: { "index": "test2", "shard": 0, "primary": true } + +--- +# This test has a valid integer input, but it's above the shard limit, so the index cannot be located +"cluster shard allocation explanation test with max integer shard value": + - do: + indices.create: + index: test + + - match: { acknowledged: 
true } + + - do: + catch: /shard_not_found_exception/ + cluster.allocation_explain: + body: { "index": "test", "shard": 2147483647, "primary": true } + +--- +"cluster shard allocation explanation test with three valid body parameters": + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + cluster.allocation_explain: + body: { "index": "test", "shard": 0, "primary": true } - match: { current_state: "started" } - is_true: current_node.id - match: { index: "test" } - - match: { shard: 1 } + - match: { shard: 0 } - match: { primary: true } - is_true: can_remain_on_current_node - is_true: can_rebalance_cluster @@ -98,7 +152,7 @@ - is_true: rebalance_explanation --- -"cluster shard allocation explanation test with three valid body parameters": +"cluster shard allocation explanation test with 3 body parameters and all query parameters": - do: indices.create: index: test @@ -108,6 +162,9 @@ - do: cluster.allocation_explain: body: { "index": "test", "shard": 0, "primary": true } + include_disk_info: true + include_yes_decisions: true + master_timeout: 0 - match: { current_state: "started" } - is_true: current_node.id @@ -118,6 +175,9 @@ - is_true: can_rebalance_cluster - is_true: can_rebalance_to_other_node - is_true: rebalance_explanation + # Modified by the existence of the query parameters + - is_true: cluster_info + - is_false: note --- "Cluster shard allocation explanation test with a closed index": @@ -180,3 +240,275 @@ body: { "index": "test", "shard": 0, "primary": true } - is_true: current_node.roles + +# These tests were added as part of https://github.com/elastic/elasticsearch/issues/127028 which added support +# for path parameters alongside body parameters in 9.2.0 + +--- +"cluster shard allocation explanation test with empty body and no URL parameters": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason:
"Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + body: { "settings": { "index.number_of_shards": 1, "index.number_of_replicas": 9 } } + + - do: + cluster.allocation_explain: {} + + - match: { current_state: "unassigned" } + - match: { unassigned_info.reason: "INDEX_CREATED" } + - is_true: unassigned_info.at + - match: { index: "test" } + - match: { shard: 0 } + - match: { primary: false } + +--- +"cluster shard allocation explanation test with only index provided in URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + index: "test" + +--- +"cluster shard allocation explanation test with only shard provided in URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + shard: 0 + +--- +"cluster shard allocation explanation test with only primary provided in URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + primary: true + +--- +"cluster shard allocation explanation test with only index and 
shard provided in URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + index: "test" + shard: 0 + +--- +"cluster shard allocation explanation test with only shard and primary provided in the URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + shard: 0 + primary: true + +--- +"cluster shard allocation explanation test with only index and primary provided in the URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /action_request_validation_exception/ + cluster.allocation_explain: + index: "test" + primary: true + +--- +"cluster shard allocation explanation test with 3 parameters in the URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + cluster.allocation_explain: + index: "test" + shard: 0 + primary: true + + - match: { 
current_state: "started" } + - is_true: current_node.id + - match: { index: "test" } + - match: { shard: 0 } + - match: { primary: true } + - is_true: can_remain_on_current_node + - is_true: can_rebalance_cluster + - is_true: can_rebalance_to_other_node + - is_true: rebalance_explanation + +--- +"cluster shard allocation explanation test with all parameters in the URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + cluster.allocation_explain: + index: "test" + shard: 0 + primary: true + include_disk_info: true + include_yes_decisions: true + master_timeout: 0 + + - match: { current_state: "started" } + - is_true: current_node.id + - match: { index: "test" } + - match: { shard: 0 } + - match: { primary: true } + - is_true: can_remain_on_current_node + - is_true: can_rebalance_cluster + - is_true: can_rebalance_to_other_node + - is_true: rebalance_explanation + # Modified by the existence of the query parameters + - is_true: cluster_info + - is_false: note + +--- +"cluster shard allocation explanation test with parameters passed in both the body and URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [ query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /illegal_argument_exception/ + cluster.allocation_explain: + index: "test" + body: { "shard": 0, "primary": true } + +--- +"cluster shard allocation explanation test with incorrect index parameter passed in URL": + - requires: + capabilities: + - method: GET + path: /_cluster/allocation/explain + capabilities: [
query_parameter_support ] + test_runner_features: [ capabilities ] + reason: "Query parameter support was added in version 9.2.0" + + - do: + indices.create: + index: test + + - match: { acknowledged: true } + + - do: + catch: /index_not_found_exception/ + cluster.allocation_explain: + index: "test2" + shard: 0 + primary: true diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml index 8ad06910ebe4d..0724f3831aeab 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml @@ -17,6 +17,9 @@ indices.recovery: index: [test_1] human: true + ignore_unavailable: false + allow_no_indices: true + expand_wildcards: open - match: { test_1.shards.0.type: "EMPTY_STORE" } - match: { test_1.shards.0.stage: "DONE" } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml index 1052508ca2b88..8ac4ee60f2bbc 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml @@ -1,5 +1,7 @@ --- setup: + - requires: + test_runner_features: allowed_warnings - do: index: @@ -67,6 +69,12 @@ setup: rest_total_hits_as_int: true max_concurrent_shard_requests: 1 max_concurrent_searches: 1 + ignore_unavailable: false + ignore_throttled: false + allow_no_indices: false + expand_wildcards: open + include_named_queries_score: false + index: index_* body: - index: index_* - query: @@ -83,6 +91,8 @@ setup: - {} - query: match_all: {} + allowed_warnings: + - "[ignore_throttled] parameter is deprecated because frozen indices have been deprecated. 
Consider cold or frozen tiers in place of frozen indices." - match: { responses.0.hits.total: 2 } - match: { responses.1.hits.total: 1 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml new file mode 100644 index 0000000000000..5b69a4da98418 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml @@ -0,0 +1,25 @@ +--- +setup: + - do: + index: + index: index_1 + routing: "1" + id: "1" + body: { foo: bar } + + - do: + indices.refresh: {} + +--- +"Routing": + + - do: + msearch: + rest_total_hits_as_int: true + routing: "1" + body: + - {} + - query: + match_all: {} + + - match: { responses.0.hits.total: 1 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml index 047255818df64..fe19a9b8578fb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml @@ -1,7 +1,7 @@ setup: - requires: - cluster_features: "mapper.ivf_nested_support" - reason: 'ivf nested support required' + cluster_features: "mapper.bbq_disk_support" + reason: 'bbq disk support required' - do: indices.create: index: test @@ -24,7 +24,7 @@ setup: index: true similarity: l2_norm index_options: - type: bbq_ivf + type: bbq_disk aliases: my_alias: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 35e59b0e4e31a..e617d08940f84 100644 --- 
a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -670,3 +670,36 @@ setup: properties: embedding: type: dense_vector + + +--- +"Searching with no data dimensions specified": + - requires: + cluster_features: "search.vectors.no_dimensions_bugfix" + reason: "Search with no dimensions bugfix" + + - do: + indices.create: + index: empty-test + body: + mappings: + properties: + vector: + type: dense_vector + index: true + + - do: + search: + index: empty-test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + similarity: 0.1 + + - match: { hits.total.value: 0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml index 1f3c07be2942e..3ce9232fc4ecd 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml @@ -1,12 +1,12 @@ setup: - requires: - cluster_features: ["mapper.ivf_format_cluster_feature"] - reason: Needs mapper.ivf_format_cluster_feature feature + cluster_features: ["mapper.bbq_disk_support"] + reason: Needs mapper.bbq_disk_support feature - skip: features: "headers" - do: indices.create: - index: bbq_ivf + index: bbq_disk body: settings: index: @@ -19,11 +19,11 @@ setup: index: true similarity: max_inner_product index_options: - type: bbq_ivf + type: bbq_disk - do: index: - index: bbq_ivf + index: bbq_disk id: "1" body: vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, @@ -37,11 +37,11 @@ setup: # Flush in order to provoke a merge later - do: indices.flush: - index: 
bbq_ivf + index: bbq_disk - do: index: - index: bbq_ivf + index: bbq_disk id: "2" body: vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, @@ -55,11 +55,11 @@ setup: # Flush in order to provoke a merge later - do: indices.flush: - index: bbq_ivf + index: bbq_disk - do: index: - index: bbq_ivf + index: bbq_disk id: "3" body: name: rabbit.jpg @@ -74,11 +74,11 @@ setup: # Flush in order to provoke a merge later - do: indices.flush: - index: bbq_ivf + index: bbq_disk - do: indices.forcemerge: - index: bbq_ivf + index: bbq_disk max_num_segments: 1 - do: @@ -87,7 +87,7 @@ setup: "Test knn search": - do: search: - index: bbq_ivf + index: bbq_disk body: knn: field: vector @@ -116,7 +116,7 @@ setup: Content-Type: application/json search: rest_total_hits_as_int: true - index: bbq_ivf + index: bbq_disk body: knn: field: vector @@ -182,7 +182,7 @@ setup: element_type: byte index: true index_options: - type: bbq_ivf + type: bbq_disk - do: catch: bad_request @@ -196,7 +196,7 @@ setup: dims: 64 index: false index_options: - type: bbq_ivf + type: bbq_disk --- "Test index configured rescore vector": - skip: @@ -216,7 +216,7 @@ setup: index: true similarity: max_inner_product index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 1.5 @@ -298,7 +298,7 @@ setup: vector: type: dense_vector index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 0 @@ -314,7 +314,7 @@ setup: vector: type: dense_vector index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 1 @@ -326,7 +326,7 @@ setup: vector: type: dense_vector index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 0 @@ -354,7 +354,7 @@ setup: index: true similarity: max_inner_product index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 0 @@ -430,7 +430,7 @@ setup: index: true similarity: max_inner_product index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 2 @@ -470,7 +470,7 @@ setup: index: 
true similarity: max_inner_product index_options: - type: bbq_ivf + type: bbq_disk rescore_vector: oversample: 0 @@ -509,6 +509,6 @@ setup: "default oversample value": - do: indices.get_mapping: - index: bbq_ivf + index: bbq_disk - - match: { bbq_ivf.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } + - match: { bbq_disk.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/40_synonyms_sets_get.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/40_synonyms_sets_get.yml index 9d6540c118ce5..9ba66d3100eb1 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/40_synonyms_sets_get.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/40_synonyms_sets_get.yml @@ -157,3 +157,33 @@ setup: - match: count: 12 + +--- +"Return empty rule set": + - requires: + cluster_features: [ synonyms_set.get.return_empty_synonym_sets ] + reason: "synonyms_set get api return empty synonym sets" + + - do: + synonyms.put_synonym: + id: empty-synonyms + body: + synonyms_set: [] + + - do: + synonyms.get_synonyms_sets: {} + + - match: + count: 4 + + - match: + results: + - synonyms_set: "empty-synonyms" + count: 0 + - synonyms_set: "test-synonyms-1" + count: 3 + - synonyms_set: "test-synonyms-2" + count: 1 + - synonyms_set: "test-synonyms-3" + count: 2 + diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionIT.java index 4c1324cb0378a..cabe428a7487c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionIT.java @@ -24,6 +24,7 @@ import org.elasticsearch.action.search.SearchResponse; 
import org.elasticsearch.cluster.metadata.ComponentTemplate; import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.Template; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Settings; @@ -85,7 +86,8 @@ public void testMappingValidationIndexExists() { assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(0L)); searchResponse.decRef(); ClusterStateResponse clusterStateResponse = admin().cluster().state(new ClusterStateRequest(TEST_REQUEST_TIMEOUT)).actionGet(); - Map indexMapping = clusterStateResponse.getState().metadata().getProject().index(indexName).mapping().sourceAsMap(); + final var project = clusterStateResponse.getState().metadata().getProject(ProjectId.DEFAULT); + Map indexMapping = project.index(indexName).mapping().sourceAsMap(); Map fields = (Map) indexMapping.get("properties"); assertThat(fields.size(), equalTo(1)); } @@ -142,7 +144,8 @@ public void testMappingValidationIndexExistsTemplateSubstitutions() throws IOExc assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(0L)); searchResponse.decRef(); ClusterStateResponse clusterStateResponse = admin().cluster().state(new ClusterStateRequest(TEST_REQUEST_TIMEOUT)).actionGet(); - Map indexMapping = clusterStateResponse.getState().metadata().getProject().index(indexName).mapping().sourceAsMap(); + final var project = clusterStateResponse.getState().metadata().getProject(ProjectId.DEFAULT); + Map indexMapping = project.index(indexName).mapping().sourceAsMap(); Map fields = (Map) indexMapping.get("properties"); assertThat(fields.size(), equalTo(1)); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java index 5f86111d352a9..5d5f2082fb71f 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java @@ -60,13 +60,15 @@ public void testElectOnlyBetweenMasterNodes() throws Exception { logger.info("--> start master node (1)"); final String masterNodeName = internalCluster().startMasterOnlyNode(); - awaitMasterNode(internalCluster().getNonMasterNodeName(), masterNodeName); - awaitMasterNode(internalCluster().getMasterName(), masterNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, masterNodeName); + } logger.info("--> start master node (2)"); final String nextMasterEligableNodeName = internalCluster().startMasterOnlyNode(); - awaitMasterNode(internalCluster().getNonMasterNodeName(), masterNodeName); - awaitMasterNode(internalCluster().getMasterName(), masterNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, masterNodeName); + } logger.info("--> closing master node (1)"); client().execute( @@ -74,12 +76,14 @@ public void testElectOnlyBetweenMasterNodes() throws Exception { new AddVotingConfigExclusionsRequest(TEST_REQUEST_TIMEOUT, masterNodeName) ).get(); // removing the master from the voting configuration immediately triggers the master to step down - awaitMasterNode(internalCluster().getNonMasterNodeName(), nextMasterEligableNodeName); - awaitMasterNode(internalCluster().getMasterName(), nextMasterEligableNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, nextMasterEligableNodeName); + } internalCluster().stopNode(masterNodeName); - awaitMasterNode(internalCluster().getNonMasterNodeName(), nextMasterEligableNodeName); - awaitMasterNode(internalCluster().getMasterName(), nextMasterEligableNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, nextMasterEligableNodeName); + } } public void 
testAliasFilterValidation() { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/UnsafeBootstrapAndDetachCommandIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/UnsafeBootstrapAndDetachCommandIT.java index 3e7b1c37a421f..e29e816e936c5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/UnsafeBootstrapAndDetachCommandIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/UnsafeBootstrapAndDetachCommandIT.java @@ -24,6 +24,7 @@ import org.elasticsearch.gateway.PersistedClusterStateService; import org.elasticsearch.node.Node; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalTestCluster; import java.io.IOException; @@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.notNullValue; @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, autoManageMasterNodes = false) +@ESTestCase.WithoutEntitlements // CLI tools don't run with entitlements enforced public class UnsafeBootstrapAndDetachCommandIT extends ESIntegTestCase { private MockTerminal executeCommand(ElasticsearchNodeCommand command, Environment environment, boolean abort) throws Exception { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorIT.java index a4533a674fe70..bb19c9d477a45 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; import org.elasticsearch.cluster.DiskUsageIntegTestCase; import 
org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; @@ -86,20 +87,9 @@ public void testFloodStageExceeded() throws Exception { // Verify that we can still move shards around even while blocked final String newDataNodeName = internalCluster().startDataOnlyNode(); final String newDataNodeId = clusterAdmin().prepareNodesInfo(newDataNodeName).get().getNodes().get(0).getNode().getId(); - assertBusy(() -> { - final ShardRouting primaryShard = clusterAdmin().prepareState(TEST_REQUEST_TIMEOUT) - .clear() - .setRoutingTable(true) - .setNodes(true) - .setIndices(indexName) - .get() - .getState() - .routingTable() - .index(indexName) - .shard(0) - .primaryShard(); - assertThat(primaryShard.state(), equalTo(ShardRoutingState.STARTED)); - assertThat(primaryShard.currentNodeId(), equalTo(newDataNodeId)); + awaitClusterState(state -> { + final ShardRouting primaryShard = state.routingTable(ProjectId.DEFAULT).index(indexName).shard(0).primaryShard(); + return primaryShard.state() == ShardRoutingState.STARTED && newDataNodeId.equals(primaryShard.currentNodeId()); }); // Verify that the block is removed once the shard migration is complete diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java index 6dafab431500e..5d389ad5ef11a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java @@ -8,6 +8,8 @@ */ package org.elasticsearch.index.shard; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.DirectoryReader; import 
org.apache.lucene.store.LockObtainFailedException; import org.apache.lucene.util.SetOnce; @@ -22,13 +24,17 @@ import org.elasticsearch.cluster.EstimatedHeapUsage; import org.elasticsearch.cluster.EstimatedHeapUsageCollector; import org.elasticsearch.cluster.InternalClusterInfoService; +import org.elasticsearch.cluster.NodeUsageStatsForThreadPools; +import org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.routing.RecoverySource; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.UnassignedInfo; +import org.elasticsearch.cluster.routing.allocation.WriteLoadConstraintSettings; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; @@ -73,6 +79,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.test.InternalSettingsPlugin; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.junit.Assert; @@ -85,6 +92,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -97,6 +105,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomAsciiLettersOfLength; @@ -117,14 +126,20 @@ import static org.hamcrest.Matchers.either; import static 
org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; public class IndexShardIT extends ESSingleNodeTestCase { + private static final Logger logger = LogManager.getLogger(IndexShardIT.class); @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, BogusEstimatedHeapUsagePlugin.class); + return pluginList( + InternalSettingsPlugin.class, + BogusEstimatedHeapUsagePlugin.class, + BogusNodeUsageStatsForThreadPoolsCollectorPlugin.class + ); } public void testLockTryingToDelete() throws Exception { @@ -295,6 +310,109 @@ public void testHeapUsageEstimateIsPresent() { } } + public void testNodeWriteLoadsArePresent() { + InternalClusterInfoService clusterInfoService = (InternalClusterInfoService) getInstanceFromNode(ClusterInfoService.class); + ClusterInfoServiceUtils.refresh(clusterInfoService); + Map nodeThreadPoolStats = clusterInfoService.getClusterInfo() + .getNodeUsageStatsForThreadPools(); + assertNotNull(nodeThreadPoolStats); + /** Not collecting stats yet because allocation write load stats collection is disabled by default. + * see {@link WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING} */ + assertTrue(nodeThreadPoolStats.isEmpty()); + + // Enable collection for node write loads. + updateClusterSettings( + Settings.builder() + .put( + WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey(), + WriteLoadConstraintSettings.WriteLoadDeciderStatus.ENABLED + ) + .build() + ); + try { + // Force a ClusterInfo refresh to run collection of the node thread pool usage stats. + ClusterInfoServiceUtils.refresh(clusterInfoService); + nodeThreadPoolStats = clusterInfoService.getClusterInfo().getNodeUsageStatsForThreadPools(); + + /** Verify that each node has usage stats reported. 
The test {@link BogusNodeUsageStatsForThreadPoolsCollector} implementation + * generates random usage values */ + ClusterState state = getInstanceFromNode(ClusterService.class).state(); + assertEquals(state.nodes().size(), nodeThreadPoolStats.size()); + for (DiscoveryNode node : state.nodes()) { + assertTrue(nodeThreadPoolStats.containsKey(node.getId())); + NodeUsageStatsForThreadPools nodeUsageStatsForThreadPools = nodeThreadPoolStats.get(node.getId()); + assertThat(nodeUsageStatsForThreadPools.nodeId(), equalTo(node.getId())); + NodeUsageStatsForThreadPools.ThreadPoolUsageStats writeThreadPoolStats = nodeUsageStatsForThreadPools + .threadPoolUsageStatsMap() + .get(ThreadPool.Names.WRITE); + assertNotNull(writeThreadPoolStats); + assertThat(writeThreadPoolStats.totalThreadPoolThreads(), greaterThanOrEqualTo(0)); + assertThat(writeThreadPoolStats.averageThreadPoolUtilization(), greaterThanOrEqualTo(0.0f)); + assertThat(writeThreadPoolStats.averageThreadPoolQueueLatencyMillis(), greaterThanOrEqualTo(0L)); + } + } finally { + updateClusterSettings( + Settings.builder().putNull(WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey()).build() + ); + } + } + + public void testShardWriteLoadsArePresent() { + // Create some indices and some write-load + final int numIndices = randomIntBetween(1, 5); + final String indexPrefix = randomIdentifier(); + IntStream.range(0, numIndices).forEach(i -> { + final String indexName = indexPrefix + "_" + i; + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 3)).build()); + IntStream.range(0, randomIntBetween(1, 500)) + .forEach(j -> prepareIndex(indexName).setSource("foo", randomIdentifier(), "bar", randomIdentifier()).get()); + }); + + final InternalClusterInfoService clusterInfoService = (InternalClusterInfoService) getInstanceFromNode(ClusterInfoService.class); + + // Not collecting stats yet because allocation write load stats collection is disabled by 
default. + { + ClusterInfoServiceUtils.refresh(clusterInfoService); + final Map shardWriteLoads = clusterInfoService.getClusterInfo().getShardWriteLoads(); + assertNotNull(shardWriteLoads); + assertTrue(shardWriteLoads.isEmpty()); + } + + // Turn on collection of write-load stats. + updateClusterSettings( + Settings.builder() + .put( + WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey(), + WriteLoadConstraintSettings.WriteLoadDeciderStatus.ENABLED + ) + .build() + ); + + try { + // Force a ClusterInfo refresh to run collection of the write-load stats. + ClusterInfoServiceUtils.refresh(clusterInfoService); + final Map shardWriteLoads = clusterInfoService.getClusterInfo().getShardWriteLoads(); + + // Verify that each shard has write-load reported. + final ClusterState state = getInstanceFromNode(ClusterService.class).state(); + assertEquals(state.projectState(ProjectId.DEFAULT).metadata().getTotalNumberOfShards(), shardWriteLoads.size()); + double maximumLoadRecorded = 0; + for (IndexMetadata indexMetadata : state.projectState(ProjectId.DEFAULT).metadata()) { + for (int i = 0; i < indexMetadata.getNumberOfShards(); i++) { + final ShardId shardId = new ShardId(indexMetadata.getIndex(), i); + assertTrue(shardWriteLoads.containsKey(shardId)); + maximumLoadRecorded = Math.max(shardWriteLoads.get(shardId), maximumLoadRecorded); + } + } + // And that at least one is greater than zero + assertThat(maximumLoadRecorded, greaterThan(0.0)); + } finally { + updateClusterSettings( + Settings.builder().putNull(WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey()).build() + ); + } + } + public void testIndexCanChangeCustomDataPath() throws Exception { final String index = "test-custom-data-path"; final Path sharedDataPath = getInstanceFromNode(Environment.class).sharedDataDir().resolve(randomAsciiLettersOfLength(10)); @@ -875,4 +993,61 @@ public ClusterService getClusterService() { return clusterService.get(); } } + + /** + * A simple {@link 
NodeUsageStatsForThreadPoolsCollector} implementation that creates and returns random + * {@link NodeUsageStatsForThreadPools} for each node in the cluster. + *

+ * Note: there's an 'org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector' file that declares this implementation so that the + * plugin system can pick it up and use it for the test set-up. + */ + public static class BogusNodeUsageStatsForThreadPoolsCollector implements NodeUsageStatsForThreadPoolsCollector { + + private final BogusNodeUsageStatsForThreadPoolsCollectorPlugin plugin; + + public BogusNodeUsageStatsForThreadPoolsCollector(BogusNodeUsageStatsForThreadPoolsCollectorPlugin plugin) { + this.plugin = plugin; + } + + @Override + public void collectUsageStats(ActionListener> listener) { + ActionListener.completeWith( + listener, + () -> plugin.getClusterService() + .state() + .nodes() + .stream() + .collect(Collectors.toUnmodifiableMap(DiscoveryNode::getId, node -> makeRandomNodeUsageStats(node.getId()))) + ); + } + + private NodeUsageStatsForThreadPools makeRandomNodeUsageStats(String nodeId) { + NodeUsageStatsForThreadPools.ThreadPoolUsageStats writeThreadPoolStats = new NodeUsageStatsForThreadPools.ThreadPoolUsageStats( + randomNonNegativeInt(), + randomFloat(), + randomNonNegativeLong() + ); + Map statsForThreadPools = new HashMap<>(); + statsForThreadPools.put(ThreadPool.Names.WRITE, writeThreadPoolStats); + return new NodeUsageStatsForThreadPools(nodeId, statsForThreadPools); + } + } + + /** + * Make a plugin to gain access to the {@link ClusterService} instance. 
+ */ + public static class BogusNodeUsageStatsForThreadPoolsCollectorPlugin extends Plugin implements ClusterPlugin { + + private final SetOnce clusterService = new SetOnce<>(); + + @Override + public Collection createComponents(PluginServices services) { + clusterService.set(services.clusterService()); + return List.of(); + } + + public ClusterService getClusterService() { + return clusterService.get(); + } + } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index 02e17e3395760..b14f067992ba0 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalSettingsPlugin; import org.elasticsearch.test.MockLog; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -41,6 +42,7 @@ import static org.hamcrest.Matchers.is; @LuceneTestCase.SuppressCodecs("*") // only use our own codecs +@ESTestCase.WithoutEntitlements // requires entitlement delegation ES-10920 public class DirectIOIT extends ESIntegTestCase { @BeforeClass diff --git a/server/src/internalClusterTest/java/org/elasticsearch/threadpool/SimpleThreadPoolIT.java b/server/src/internalClusterTest/java/org/elasticsearch/threadpool/SimpleThreadPoolIT.java index b9f2a5eb79f22..fa81ee40cb76d 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/threadpool/SimpleThreadPoolIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/threadpool/SimpleThreadPoolIT.java @@ -234,19 +234,29 @@ public void assertValid(TestTelemetryPlugin testTelemetryPlugin, String metricSu } } - public void 
testWriteThreadpoolEwmaAlphaSetting() { + public void testWriteThreadpoolsEwmaAlphaSetting() { Settings settings = Settings.EMPTY; - var ewmaAlpha = DEFAULT_INDEX_AUTOSCALING_EWMA_ALPHA; + var executionEwmaAlpha = DEFAULT_INDEX_AUTOSCALING_EWMA_ALPHA; if (randomBoolean()) { - ewmaAlpha = randomDoubleBetween(0.0, 1.0, true); - settings = Settings.builder().put(WRITE_THREAD_POOLS_EWMA_ALPHA_SETTING.getKey(), ewmaAlpha).build(); + executionEwmaAlpha = randomDoubleBetween(0.0, 1.0, true); + settings = Settings.builder().put(WRITE_THREAD_POOLS_EWMA_ALPHA_SETTING.getKey(), executionEwmaAlpha).build(); } var nodeName = internalCluster().startNode(settings); var threadPool = internalCluster().getInstance(ThreadPool.class, nodeName); + + // Verify that the write thread pools all use the tracking executor. for (var name : List.of(ThreadPool.Names.WRITE, ThreadPool.Names.SYSTEM_WRITE, ThreadPool.Names.SYSTEM_CRITICAL_WRITE)) { assertThat(threadPool.executor(name), instanceOf(TaskExecutionTimeTrackingEsThreadPoolExecutor.class)); final var executor = (TaskExecutionTimeTrackingEsThreadPoolExecutor) threadPool.executor(name); - assertThat(Double.compare(executor.getEwmaAlpha(), ewmaAlpha), CoreMatchers.equalTo(0)); + assertThat(Double.compare(executor.getExecutionEwmaAlpha(), executionEwmaAlpha), CoreMatchers.equalTo(0)); + + // Only the WRITE thread pool should enable further tracking. + if (name.equals(ThreadPool.Names.WRITE) == false) { + assertFalse(executor.trackingMaxQueueLatency()); + } else { + // Verify that the WRITE thread pool has extra tracking enabled. 
+ assertTrue(executor.trackingMaxQueueLatency()); + } } } } diff --git a/server/src/internalClusterTest/resources/META-INF/services/org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector b/server/src/internalClusterTest/resources/META-INF/services/org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector new file mode 100644 index 0000000000000..787ce436c3ca6 --- /dev/null +++ b/server/src/internalClusterTest/resources/META-INF/services/org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector @@ -0,0 +1,10 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the "Elastic License +# 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side +# Public License v 1"; you may not use this file except in compliance with, at +# your election, the "Elastic License 2.0", the "GNU Affero General Public +# License v3.0 only", or the "Server Side Public License, v 1". +# + +org.elasticsearch.index.shard.IndexShardIT$BogusNodeUsageStatsForThreadPoolsCollector diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index ef3d5da1c9531..90cd3c669a52c 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -429,6 +429,7 @@ org.elasticsearch.index.mapper.MapperFeatures, org.elasticsearch.index.IndexFeatures, org.elasticsearch.search.SearchFeatures, + org.elasticsearch.synonyms.SynonymFeatures, org.elasticsearch.script.ScriptFeatures, org.elasticsearch.search.retriever.RetrieversFeatures, org.elasticsearch.action.admin.cluster.stats.ClusterStatsFeatures, @@ -483,4 +484,5 @@ exports org.elasticsearch.index.codec.perfield; exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; + exports org.elasticsearch.inference.telemetry; } diff --git 
a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 35f5423df0ffb..30017f53aba18 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -211,6 +211,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED_8_19 = def(8_841_0_61); public static final TransportVersion ESQL_PROFILE_INCLUDE_PLAN_8_19 = def(8_841_0_62); public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES_8_19 = def(8_841_0_63); + public static final TransportVersion ESQL_FIXED_INDEX_LIKE_8_19 = def(8_841_0_64); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -328,6 +329,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_PROFILE_INCLUDE_PLAN = def(9_111_0_00); public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00); public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES_9_1 = def(9_112_0_01); + public static final TransportVersion ESQL_FIXED_INDEX_LIKE_9_1 = def(9_112_0_02); // Below is the first version in 9.2 and NOT in 9.1. 
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00); public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00); @@ -337,6 +339,13 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_118_0_00); public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00); public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00); + public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00); + public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00); + public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00); + public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00); + public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00); + public static final TransportVersion SHARD_WRITE_LOAD_IN_CLUSTER_INFO = def(9_126_0_00); + public static final TransportVersion ESQL_EMIT_EMPTY_BUCKETS = def(9_127_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequest.java index 445bac4cd659e..fd4c7c69ad57c 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequest.java @@ -28,12 +28,19 @@ */ public class ClusterAllocationExplainRequest extends MasterNodeRequest { + public static final String INDEX_PARAMETER_NAME = "index"; + public static final String SHARD_PARAMETER_NAME = "shard"; + public static final String PRIMARY_PARAMETER_NAME = "primary"; + public static final String CURRENT_NODE_PARAMETER_NAME = "current_node"; + public static final String INCLUDE_YES_DECISIONS_PARAMETER_NAME = "include_yes_decisions"; + public static final String INCLUDE_DISK_INFO_PARAMETER_NAME = "include_disk_info"; + private static final ObjectParser PARSER = new ObjectParser<>("cluster/allocation/explain"); static { - PARSER.declareString(ClusterAllocationExplainRequest::setIndex, new ParseField("index")); - PARSER.declareInt(ClusterAllocationExplainRequest::setShard, new ParseField("shard")); - PARSER.declareBoolean(ClusterAllocationExplainRequest::setPrimary, new ParseField("primary")); - PARSER.declareString(ClusterAllocationExplainRequest::setCurrentNode, new ParseField("current_node")); + PARSER.declareString(ClusterAllocationExplainRequest::setIndex, new ParseField(INDEX_PARAMETER_NAME)); + PARSER.declareInt(ClusterAllocationExplainRequest::setShard, new ParseField(SHARD_PARAMETER_NAME)); + PARSER.declareBoolean(ClusterAllocationExplainRequest::setPrimary, new ParseField(PRIMARY_PARAMETER_NAME)); + PARSER.declareString(ClusterAllocationExplainRequest::setCurrentNode, new ParseField(CURRENT_NODE_PARAMETER_NAME)); } @Nullable @@ -221,14 
+228,15 @@ public String toString() { if (this.useAnyUnassignedShard()) { sb.append("useAnyUnassignedShard=true"); } else { - sb.append("index=").append(index); - sb.append(",shard=").append(shard); - sb.append(",primary?=").append(primary); + sb.append(INDEX_PARAMETER_NAME).append("=").append(index); + sb.append(",").append(SHARD_PARAMETER_NAME).append("=").append(shard); + sb.append(",").append(PRIMARY_PARAMETER_NAME).append("?=").append(primary); if (currentNode != null) { - sb.append(",currentNode=").append(currentNode); + sb.append(",").append(CURRENT_NODE_PARAMETER_NAME).append("=").append(currentNode); } } - sb.append(",includeYesDecisions?=").append(includeYesDecisions); + sb.append(",").append(INCLUDE_YES_DECISIONS_PARAMETER_NAME).append("?=").append(includeYesDecisions); + sb.append(",").append(INCLUDE_DISK_INFO_PARAMETER_NAME).append("?=").append(includeDiskInfo); return sb.toString(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java index 96fbb2c5d6649..f003cd3fc107d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java @@ -26,7 +26,6 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.ComponentTemplate; import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.project.ProjectResolver; @@ -34,7 +33,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Assertions; -import org.elasticsearch.core.FixForMultiProject; import 
org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexingPressure; @@ -193,34 +191,33 @@ private boolean applyPipelines(Task task, BulkRequest bulkRequest, Executor exec boolean hasIndexRequestsWithPipelines = false; ClusterState state = clusterService.state(); ProjectId projectId = projectResolver.getProjectId(); - final Metadata metadata; + final ProjectMetadata project; Map componentTemplateSubstitutions = bulkRequest.getComponentTemplateSubstitutions(); Map indexTemplateSubstitutions = bulkRequest.getIndexTemplateSubstitutions(); if (bulkRequest.isSimulated() && (componentTemplateSubstitutions.isEmpty() == false || indexTemplateSubstitutions.isEmpty() == false)) { /* - * If this is a simulated request, and there are template substitutions, then we want to create and use a new metadata that has + * If this is a simulated request, and there are template substitutions, then we want to create and use a new project that has * those templates. That is, we want to add the new templates (which will replace any that already existed with the same name), * and remove the indices and data streams that are referred to from the bulkRequest so that we get settings from the templates * rather than from the indices/data streams. 
*/ - Metadata originalMetadata = state.metadata(); - @FixForMultiProject // properly ensure simulated actions work with MP - Metadata.Builder simulatedMetadataBuilder = Metadata.builder(originalMetadata); + ProjectMetadata originalProject = state.metadata().getProject(projectId); + ProjectMetadata.Builder simulatedMetadataBuilder = ProjectMetadata.builder(originalProject); if (componentTemplateSubstitutions.isEmpty() == false) { Map updatedComponentTemplates = new HashMap<>(); - updatedComponentTemplates.putAll(originalMetadata.getProject(projectId).componentTemplates()); + updatedComponentTemplates.putAll(originalProject.componentTemplates()); updatedComponentTemplates.putAll(componentTemplateSubstitutions); simulatedMetadataBuilder.componentTemplates(updatedComponentTemplates); } if (indexTemplateSubstitutions.isEmpty() == false) { Map updatedIndexTemplates = new HashMap<>(); - updatedIndexTemplates.putAll(originalMetadata.getProject(projectId).templatesV2()); + updatedIndexTemplates.putAll(originalProject.templatesV2()); updatedIndexTemplates.putAll(indexTemplateSubstitutions); simulatedMetadataBuilder.indexTemplates(updatedIndexTemplates); } /* - * We now remove the index from the simulated metadata to force the templates to be used. Note that simulated requests are + * We now remove the index from the simulated project to force the templates to be used. Note that simulated requests are * always index requests -- no other type of request is supported. 
*/ for (DocWriteRequest actionRequest : bulkRequest.requests) { @@ -236,12 +233,11 @@ private boolean applyPipelines(Task task, BulkRequest bulkRequest, Executor exec } } } - metadata = simulatedMetadataBuilder.build(); + project = simulatedMetadataBuilder.build(); } else { - metadata = state.getMetadata(); + project = state.metadata().getProject(projectId); } - ProjectMetadata project = metadata.getProject(projectId); Map resolvedPipelineCache = new HashMap<>(); for (DocWriteRequest actionRequest : bulkRequest.requests) { IndexRequest indexRequest = getIndexWriteRequest(actionRequest); diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java index 2e24858d9781f..ec30886b1acbf 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java @@ -18,7 +18,13 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.BoostingQueryBuilder; +import org.elasticsearch.index.query.ConstantScoreQueryBuilder; +import org.elasticsearch.index.query.DisMaxQueryBuilder; +import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; @@ -268,9 +274,53 @@ public ActionRequestValidationException validate() { if (fields == null || fields.length == 0) { validationException = ValidateActions.addValidationError("no fields specified", validationException); } + + // Band-aid fix for 
https://github.com/elastic/elasticsearch/issues/116106. + // Semantic queries are high-recall queries, making them poor filters and effectively the same as an exists query when used in that + // context. + if (containsSemanticQuery(indexFilter)) { + validationException = ValidateActions.addValidationError( + "index filter cannot contain semantic queries. Use an exists query instead.", + validationException + ); + } + return validationException; } + /** + * Recursively checks if a query builder contains any semantic queries + */ + private static boolean containsSemanticQuery(QueryBuilder queryBuilder) { + boolean containsSemanticQuery = false; + + if (queryBuilder == null) { + return containsSemanticQuery; + } + + if ("semantic".equals(queryBuilder.getWriteableName())) { + containsSemanticQuery = true; + } else if (queryBuilder instanceof BoolQueryBuilder boolQuery) { + containsSemanticQuery = boolQuery.must().stream().anyMatch(FieldCapabilitiesRequest::containsSemanticQuery) + || boolQuery.mustNot().stream().anyMatch(FieldCapabilitiesRequest::containsSemanticQuery) + || boolQuery.should().stream().anyMatch(FieldCapabilitiesRequest::containsSemanticQuery) + || boolQuery.filter().stream().anyMatch(FieldCapabilitiesRequest::containsSemanticQuery); + } else if (queryBuilder instanceof DisMaxQueryBuilder disMaxQuery) { + containsSemanticQuery = disMaxQuery.innerQueries().stream().anyMatch(FieldCapabilitiesRequest::containsSemanticQuery); + } else if (queryBuilder instanceof NestedQueryBuilder nestedQuery) { + containsSemanticQuery = containsSemanticQuery(nestedQuery.query()); + } else if (queryBuilder instanceof BoostingQueryBuilder boostingQuery) { + containsSemanticQuery = containsSemanticQuery(boostingQuery.positiveQuery()) + || containsSemanticQuery(boostingQuery.negativeQuery()); + } else if (queryBuilder instanceof ConstantScoreQueryBuilder constantScoreQuery) { + containsSemanticQuery = containsSemanticQuery(constantScoreQuery.innerQuery()); + } else if 
(queryBuilder instanceof FunctionScoreQueryBuilder functionScoreQuery) { + containsSemanticQuery = containsSemanticQuery(functionScoreQuery.query()); + } + + return containsSemanticQuery; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java index b3ffc564d848c..f8243097e80a3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java @@ -59,7 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("id", Base64.getUrlEncoder().encodeToString(BytesReference.toBytes(pointInTimeId))); - buildBroadcastShardsHeader(builder, params, totalShards, successfulShards, failedShards, skippedShards, null); + buildBroadcastShardsHeader(builder, params, totalShards, successfulShards, skippedShards, failedShards, null); builder.endObject(); return builder; } diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java index e8cb0421979c5..33172e30fb107 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java @@ -41,7 +41,7 @@ /** * ClusterInfo is an object representing a map of nodes to {@link DiskUsage} - * and a map of shard ids to shard sizes, see + * and a map of shard ids to shard sizes and shard write-loads, see * InternalClusterInfoService.shardIdentifierFromRouting(String) * for the key used in the shardSizes map */ @@ -58,9 +58,11 @@ public class ClusterInfo implements ChunkedToXContent, Writeable { final Map dataPath; final Map reservedSpace; 
final Map estimatedHeapUsages; + final Map nodeUsageStatsForThreadPools; + final Map shardWriteLoads; protected ClusterInfo() { - this(Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); + this(Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); } /** @@ -73,6 +75,7 @@ protected ClusterInfo() { * @param dataPath the shard routing to datapath mapping * @param reservedSpace reserved space per shard broken down by node and data path * @param estimatedHeapUsages estimated heap usage broken down by node + * @param nodeUsageStatsForThreadPools node-level usage stats (operational load) broken down by node * @see #shardIdentifierFromRouting */ public ClusterInfo( @@ -82,7 +85,9 @@ public ClusterInfo( Map shardDataSetSizes, Map dataPath, Map reservedSpace, - Map estimatedHeapUsages + Map estimatedHeapUsages, + Map nodeUsageStatsForThreadPools, + Map shardWriteLoads ) { this.leastAvailableSpaceUsage = Map.copyOf(leastAvailableSpaceUsage); this.mostAvailableSpaceUsage = Map.copyOf(mostAvailableSpaceUsage); @@ -91,6 +96,8 @@ public ClusterInfo( this.dataPath = Map.copyOf(dataPath); this.reservedSpace = Map.copyOf(reservedSpace); this.estimatedHeapUsages = Map.copyOf(estimatedHeapUsages); + this.nodeUsageStatsForThreadPools = Map.copyOf(nodeUsageStatsForThreadPools); + this.shardWriteLoads = Map.copyOf(shardWriteLoads); } public ClusterInfo(StreamInput in) throws IOException { @@ -107,6 +114,16 @@ public ClusterInfo(StreamInput in) throws IOException { } else { this.estimatedHeapUsages = Map.of(); } + if (in.getTransportVersion().onOrAfter(TransportVersions.NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO)) { + this.nodeUsageStatsForThreadPools = in.readImmutableMap(NodeUsageStatsForThreadPools::new); + } else { + this.nodeUsageStatsForThreadPools = Map.of(); + } + if (in.getTransportVersion().onOrAfter(TransportVersions.SHARD_WRITE_LOAD_IN_CLUSTER_INFO)) { + this.shardWriteLoads = in.readImmutableMap(ShardId::new, 
StreamInput::readDouble); + } else { + this.shardWriteLoads = Map.of(); + } } @Override @@ -124,6 +141,12 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.HEAP_USAGE_IN_CLUSTER_INFO)) { out.writeMap(this.estimatedHeapUsages, StreamOutput::writeWriteable); } + if (out.getTransportVersion().onOrAfter(TransportVersions.NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO)) { + out.writeMap(this.nodeUsageStatsForThreadPools, StreamOutput::writeWriteable); + } + if (out.getTransportVersion().onOrAfter(TransportVersions.SHARD_WRITE_LOAD_IN_CLUSTER_INFO)) { + out.writeMap(this.shardWriteLoads, StreamOutput::writeWriteable, StreamOutput::writeDouble); + } } /** @@ -204,8 +227,8 @@ public Iterator toXContentChunked(ToXContent.Params params return builder.endObject(); // NodeAndPath }), endArray() // end "reserved_sizes" - // NOTE: We don't serialize estimatedHeapUsages at this stage, to avoid - // committing to API payloads until the feature is settled + // NOTE: We don't serialize estimatedHeapUsages/nodeUsageStatsForThreadPools/shardWriteLoads at this stage, to avoid + // committing to API payloads until the features are settled ); } @@ -220,6 +243,13 @@ public Map getEstimatedHeapUsages() { return estimatedHeapUsages; } + /** + * Returns a map containing thread pool usage stats for each node, keyed by node ID. + */ + public Map getNodeUsageStatsForThreadPools() { + return nodeUsageStatsForThreadPools; + } + /** * Returns a node id to disk usage mapping for the path that has the least available space on the node. * Note that this does not take account of reserved space: there may be another path with less available _and unreserved_ space. @@ -236,6 +266,16 @@ public Map getNodeMostAvailableDiskUsages() { return this.mostAvailableSpaceUsage; } + /** + * Returns a map of shard IDs to the write-loads for use in balancing. 
The write-loads can be interpreted + * as the average number of threads that ingestion to the shard will consume. + * This information may be partial or missing altogether under some circumstances. The absence of a shard + * write load from the map should be interpreted as "unknown". + */ + public Map getShardWriteLoads() { + return shardWriteLoads; + } + /** * Returns the shard size for the given shardId or null if that metric is not available. */ @@ -311,12 +351,25 @@ public boolean equals(Object o) { && shardSizes.equals(that.shardSizes) && shardDataSetSizes.equals(that.shardDataSetSizes) && dataPath.equals(that.dataPath) - && reservedSpace.equals(that.reservedSpace); + && reservedSpace.equals(that.reservedSpace) + && estimatedHeapUsages.equals(that.estimatedHeapUsages) + && nodeUsageStatsForThreadPools.equals(that.nodeUsageStatsForThreadPools) + && shardWriteLoads.equals(that.shardWriteLoads); } @Override public int hashCode() { - return Objects.hash(leastAvailableSpaceUsage, mostAvailableSpaceUsage, shardSizes, shardDataSetSizes, dataPath, reservedSpace); + return Objects.hash( + leastAvailableSpaceUsage, + mostAvailableSpaceUsage, + shardSizes, + shardDataSetSizes, + dataPath, + reservedSpace, + estimatedHeapUsages, + nodeUsageStatsForThreadPools, + shardWriteLoads + ); } @Override @@ -424,4 +477,79 @@ public Builder add(ShardId shardId, long reservedBytes) { } } } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Map leastAvailableSpaceUsage = Map.of(); + private Map mostAvailableSpaceUsage = Map.of(); + private Map shardSizes = Map.of(); + private Map shardDataSetSizes = Map.of(); + private Map dataPath = Map.of(); + private Map reservedSpace = Map.of(); + private Map estimatedHeapUsages = Map.of(); + private Map nodeUsageStatsForThreadPools = Map.of(); + private Map shardWriteLoads = Map.of(); + + public ClusterInfo build() { + return new ClusterInfo( + leastAvailableSpaceUsage, + 
mostAvailableSpaceUsage, + shardSizes, + shardDataSetSizes, + dataPath, + reservedSpace, + estimatedHeapUsages, + nodeUsageStatsForThreadPools, + shardWriteLoads + ); + } + + public Builder leastAvailableSpaceUsage(Map leastAvailableSpaceUsage) { + this.leastAvailableSpaceUsage = leastAvailableSpaceUsage; + return this; + } + + public Builder mostAvailableSpaceUsage(Map mostAvailableSpaceUsage) { + this.mostAvailableSpaceUsage = mostAvailableSpaceUsage; + return this; + } + + public Builder shardSizes(Map shardSizes) { + this.shardSizes = shardSizes; + return this; + } + + public Builder shardDataSetSizes(Map shardDataSetSizes) { + this.shardDataSetSizes = shardDataSetSizes; + return this; + } + + public Builder dataPath(Map dataPath) { + this.dataPath = dataPath; + return this; + } + + public Builder reservedSpace(Map reservedSpace) { + this.reservedSpace = reservedSpace; + return this; + } + + public Builder estimatedHeapUsages(Map estimatedHeapUsages) { + this.estimatedHeapUsages = estimatedHeapUsages; + return this; + } + + public Builder nodeUsageStatsForThreadPools(Map nodeUsageStatsForThreadPools) { + this.nodeUsageStatsForThreadPools = nodeUsageStatsForThreadPools; + return this; + } + + public Builder shardWriteLoads(Map shardWriteLoads) { + this.shardWriteLoads = shardWriteLoads; + return this; + } + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java b/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java index b47b15f545ed8..fd9c62daebd29 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java @@ -34,6 +34,7 @@ public class ClusterInfoSimulator { private final Map shardDataSetSizes; private final Map dataPath; private final Map estimatedHeapUsages; + private final Map nodeThreadPoolUsageStats; public ClusterInfoSimulator(RoutingAllocation allocation) { this.allocation = allocation; @@ -43,6 
+44,7 @@ public ClusterInfoSimulator(RoutingAllocation allocation) { this.shardDataSetSizes = Map.copyOf(allocation.clusterInfo().shardDataSetSizes); this.dataPath = Map.copyOf(allocation.clusterInfo().dataPath); this.estimatedHeapUsages = allocation.clusterInfo().getEstimatedHeapUsages(); + this.nodeThreadPoolUsageStats = allocation.clusterInfo().getNodeUsageStatsForThreadPools(); } /** @@ -156,7 +158,9 @@ public ClusterInfo getClusterInfo() { shardDataSetSizes, dataPath, Map.of(), - estimatedHeapUsages + estimatedHeapUsages, + nodeThreadPoolUsageStats, + allocation.clusterInfo().getShardWriteLoads() ); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java index 066667dfaba84..d4ecec83ebc8c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java +++ b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java @@ -29,6 +29,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.allocation.DiskThresholdSettings; +import org.elasticsearch.cluster.routing.allocation.WriteLoadConstraintSettings.WriteLoadDeciderStatus; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; @@ -37,6 +38,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.IndexingStats; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.store.StoreStats; import org.elasticsearch.threadpool.ThreadPool; @@ -50,6 +52,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.Consumer; +import static 
org.elasticsearch.cluster.routing.allocation.WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING; import static org.elasticsearch.core.Strings.format; /** @@ -92,6 +95,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt private volatile boolean diskThresholdEnabled; private volatile boolean estimatedHeapThresholdEnabled; + private volatile WriteLoadDeciderStatus writeLoadConstraintEnabled; private volatile TimeValue updateFrequency; private volatile TimeValue fetchTimeout; @@ -99,6 +103,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt private volatile Map mostAvailableSpaceUsages; private volatile Map maxHeapPerNode; private volatile Map estimatedHeapUsagePerNode; + private volatile Map nodeThreadPoolUsageStatsPerNode; private volatile IndicesStatsSummary indicesStatsSummary; private final ThreadPool threadPool; @@ -108,6 +113,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt private final Object mutex = new Object(); private final List> nextRefreshListeners = new ArrayList<>(); private final EstimatedHeapUsageCollector estimatedHeapUsageCollector; + private final NodeUsageStatsForThreadPoolsCollector nodeUsageStatsForThreadPoolsCollector; private AsyncRefresh currentRefresh; private RefreshScheduler refreshScheduler; @@ -118,16 +124,19 @@ public InternalClusterInfoService( ClusterService clusterService, ThreadPool threadPool, Client client, - EstimatedHeapUsageCollector estimatedHeapUsageCollector + EstimatedHeapUsageCollector estimatedHeapUsageCollector, + NodeUsageStatsForThreadPoolsCollector nodeUsageStatsForThreadPoolsCollector ) { this.leastAvailableSpaceUsages = Map.of(); this.mostAvailableSpaceUsages = Map.of(); this.maxHeapPerNode = Map.of(); this.estimatedHeapUsagePerNode = Map.of(); + this.nodeThreadPoolUsageStatsPerNode = Map.of(); this.indicesStatsSummary = IndicesStatsSummary.EMPTY; this.threadPool = threadPool; this.client = client; 
this.estimatedHeapUsageCollector = estimatedHeapUsageCollector; + this.nodeUsageStatsForThreadPoolsCollector = nodeUsageStatsForThreadPoolsCollector; this.updateFrequency = INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING.get(settings); this.fetchTimeout = INTERNAL_CLUSTER_INFO_TIMEOUT_SETTING.get(settings); this.diskThresholdEnabled = DiskThresholdSettings.CLUSTER_ROUTING_ALLOCATION_DISK_THRESHOLD_ENABLED_SETTING.get(settings); @@ -142,6 +151,8 @@ public InternalClusterInfoService( CLUSTER_ROUTING_ALLOCATION_ESTIMATED_HEAP_THRESHOLD_DECIDER_ENABLED, this::setEstimatedHeapThresholdEnabled ); + + clusterSettings.initializeAndWatch(WRITE_LOAD_DECIDER_ENABLED_SETTING, this::setWriteLoadConstraintEnabled); } private void setDiskThresholdEnabled(boolean diskThresholdEnabled) { @@ -152,6 +163,10 @@ private void setEstimatedHeapThresholdEnabled(boolean estimatedHeapThresholdEnab this.estimatedHeapThresholdEnabled = estimatedHeapThresholdEnabled; } + private void setWriteLoadConstraintEnabled(WriteLoadDeciderStatus writeLoadConstraintEnabled) { + this.writeLoadConstraintEnabled = writeLoadConstraintEnabled; + } + private void setFetchTimeout(TimeValue fetchTimeout) { this.fetchTimeout = fetchTimeout; } @@ -201,9 +216,10 @@ void execute() { logger.trace("starting async refresh"); try (var ignoredRefs = fetchRefs) { - maybeFetchIndicesStats(diskThresholdEnabled); + maybeFetchIndicesStats(diskThresholdEnabled || writeLoadConstraintEnabled == WriteLoadDeciderStatus.ENABLED); maybeFetchNodeStats(diskThresholdEnabled || estimatedHeapThresholdEnabled); maybeFetchNodesEstimatedHeapUsage(estimatedHeapThresholdEnabled); + maybeFetchNodesUsageStatsForThreadPools(writeLoadConstraintEnabled); } } @@ -242,6 +258,32 @@ private void maybeFetchNodesEstimatedHeapUsage(boolean shouldFetch) { } } + private void maybeFetchNodesUsageStatsForThreadPools(WriteLoadDeciderStatus writeLoadConstraintEnabled) { + if (writeLoadConstraintEnabled != WriteLoadDeciderStatus.DISABLED) { + try (var ignored = 
threadPool.getThreadContext().clearTraceContext()) { + fetchNodesUsageStatsForThreadPools(); + } + } else { + logger.trace("skipping collecting shard/node write load estimates from cluster, feature currently disabled"); + nodeThreadPoolUsageStatsPerNode = Map.of(); + } + } + + private void fetchNodesUsageStatsForThreadPools() { + nodeUsageStatsForThreadPoolsCollector.collectUsageStats(ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(Map writeLoads) { + nodeThreadPoolUsageStatsPerNode = writeLoads; + } + + @Override + public void onFailure(Exception e) { + logger.warn("failed to fetch write load estimates for nodes", e); + nodeThreadPoolUsageStatsPerNode = Map.of(); + } + }, fetchRefs.acquire())); + } + private void fetchNodesEstimatedHeapUsage() { estimatedHeapUsageCollector.collectClusterHeapUsage(ActionListener.releaseAfter(new ActionListener<>() { @Override @@ -260,7 +302,14 @@ public void onFailure(Exception e) { private void fetchIndicesStats() { final IndicesStatsRequest indicesStatsRequest = new IndicesStatsRequest(); indicesStatsRequest.clear(); - indicesStatsRequest.store(true); + if (diskThresholdEnabled) { + // This returns the shard sizes on disk + indicesStatsRequest.store(true); + } + if (writeLoadConstraintEnabled == WriteLoadDeciderStatus.ENABLED) { + // This returns the shard write-loads + indicesStatsRequest.indexing(true); + } indicesStatsRequest.indicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_CLOSED_HIDDEN); indicesStatsRequest.timeout(fetchTimeout); client.admin() @@ -309,6 +358,7 @@ public void onResponse(IndicesStatsResponse indicesStatsResponse) { } final ShardStats[] stats = indicesStatsResponse.getShards(); + final Map shardWriteLoadByIdentifierBuilder = new HashMap<>(); final Map shardSizeByIdentifierBuilder = new HashMap<>(); final Map shardDataSetSizeBuilder = new HashMap<>(); final Map dataPath = new HashMap<>(); @@ -316,6 +366,7 @@ public void onResponse(IndicesStatsResponse 
indicesStatsResponse) { new HashMap<>(); buildShardLevelInfo( adjustShardStats(stats), + shardWriteLoadByIdentifierBuilder, shardSizeByIdentifierBuilder, shardDataSetSizeBuilder, dataPath, @@ -329,7 +380,8 @@ public void onResponse(IndicesStatsResponse indicesStatsResponse) { Map.copyOf(shardSizeByIdentifierBuilder), Map.copyOf(shardDataSetSizeBuilder), Map.copyOf(dataPath), - Map.copyOf(reservedSpace) + Map.copyOf(reservedSpace), + Map.copyOf(shardWriteLoadByIdentifierBuilder) ); } @@ -493,7 +545,9 @@ public ClusterInfo getClusterInfo() { indicesStatsSummary.shardDataSetSizes, indicesStatsSummary.dataPath, indicesStatsSummary.reservedSpace, - estimatedHeapUsages + estimatedHeapUsages, + nodeThreadPoolUsageStatsPerNode, + indicesStatsSummary.shardWriteLoads() ); } @@ -523,6 +577,7 @@ public void addListener(Consumer clusterInfoConsumer) { static void buildShardLevelInfo( ShardStats[] stats, + Map shardWriteLoads, Map shardSizes, Map shardDataSetSizeBuilder, Map dataPathByShard, @@ -533,25 +588,31 @@ static void buildShardLevelInfo( dataPathByShard.put(ClusterInfo.NodeAndShard.from(shardRouting), s.getDataPath()); final StoreStats storeStats = s.getStats().getStore(); - if (storeStats == null) { - continue; - } - final long size = storeStats.sizeInBytes(); - final long dataSetSize = storeStats.totalDataSetSizeInBytes(); - final long reserved = storeStats.reservedSizeInBytes(); - - final String shardIdentifier = ClusterInfo.shardIdentifierFromRouting(shardRouting); - logger.trace("shard: {} size: {} reserved: {}", shardIdentifier, size, reserved); - shardSizes.put(shardIdentifier, size); - if (dataSetSize > shardDataSetSizeBuilder.getOrDefault(shardRouting.shardId(), -1L)) { - shardDataSetSizeBuilder.put(shardRouting.shardId(), dataSetSize); + if (storeStats != null) { + final long size = storeStats.sizeInBytes(); + final long dataSetSize = storeStats.totalDataSetSizeInBytes(); + final long reserved = storeStats.reservedSizeInBytes(); + + final String shardIdentifier 
= ClusterInfo.shardIdentifierFromRouting(shardRouting); + logger.trace("shard: {} size: {} reserved: {}", shardIdentifier, size, reserved); + shardSizes.put(shardIdentifier, size); + if (dataSetSize > shardDataSetSizeBuilder.getOrDefault(shardRouting.shardId(), -1L)) { + shardDataSetSizeBuilder.put(shardRouting.shardId(), dataSetSize); + } + if (reserved != StoreStats.UNKNOWN_RESERVED_BYTES) { + final ClusterInfo.ReservedSpace.Builder reservedSpaceBuilder = reservedSpaceByShard.computeIfAbsent( + new ClusterInfo.NodeAndPath(shardRouting.currentNodeId(), s.getDataPath()), + t -> new ClusterInfo.ReservedSpace.Builder() + ); + reservedSpaceBuilder.add(shardRouting.shardId(), reserved); + } } - if (reserved != StoreStats.UNKNOWN_RESERVED_BYTES) { - final ClusterInfo.ReservedSpace.Builder reservedSpaceBuilder = reservedSpaceByShard.computeIfAbsent( - new ClusterInfo.NodeAndPath(shardRouting.currentNodeId(), s.getDataPath()), - t -> new ClusterInfo.ReservedSpace.Builder() - ); - reservedSpaceBuilder.add(shardRouting.shardId(), reserved); + final IndexingStats indexingStats = s.getStats().getIndexing(); + if (indexingStats != null) { + final double shardWriteLoad = indexingStats.getTotal().getPeakWriteLoad(); + if (shardWriteLoad > shardWriteLoads.getOrDefault(shardRouting.shardId(), -1.0)) { + shardWriteLoads.put(shardRouting.shardId(), shardWriteLoad); + } } } } @@ -579,9 +640,10 @@ private record IndicesStatsSummary( Map shardSizes, Map shardDataSetSizes, Map dataPath, - Map reservedSpace + Map reservedSpace, + Map shardWriteLoads ) { - static final IndicesStatsSummary EMPTY = new IndicesStatsSummary(Map.of(), Map.of(), Map.of(), Map.of()); + static final IndicesStatsSummary EMPTY = new IndicesStatsSummary(Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPools.java b/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPools.java new file mode 100644 index 
0000000000000..5e84f29af8412 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPools.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.cluster; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * Record of a node's thread pool usage stats (operation load). Maps thread pool stats by thread pool name. + * + * @param nodeId The node ID. + * @param threadPoolUsageStatsMap A map of thread pool name ({@link org.elasticsearch.threadpool.ThreadPool.Names}) to the thread pool's + * usage stats ({@link ThreadPoolUsageStats}). 
+ */ +public record NodeUsageStatsForThreadPools(String nodeId, Map threadPoolUsageStatsMap) implements Writeable { + + public NodeUsageStatsForThreadPools(StreamInput in) throws IOException { + this(in.readString(), in.readMap(ThreadPoolUsageStats::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.nodeId); + out.writeMap(threadPoolUsageStatsMap, StreamOutput::writeWriteable); + } + + @Override + public int hashCode() { + return Objects.hash(nodeId, threadPoolUsageStatsMap); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NodeUsageStatsForThreadPools other = (NodeUsageStatsForThreadPools) o; + for (var entry : other.threadPoolUsageStatsMap.entrySet()) { + var loadStats = threadPoolUsageStatsMap.get(entry.getKey()); + if (loadStats == null || loadStats.equals(entry.getValue()) == false) { + return false; + } + } + return true; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(getClass().getSimpleName() + "{nodeId=" + nodeId + ", threadPoolUsageStatsMap=["); + for (var entry : threadPoolUsageStatsMap.entrySet()) { + builder.append("{ThreadPool.Names=" + entry.getKey() + ", ThreadPoolUsageStats=" + entry.getValue() + "}"); + } + builder.append("]}"); + return builder.toString(); + } + + /** + * Record of usage stats for a thread pool. + * + * @param totalThreadPoolThreads Total number of threads in the thread pool. + * @param averageThreadPoolUtilization Percent of thread pool threads that are in use, averaged over some period of time. + * @param averageThreadPoolQueueLatencyMillis How much time tasks spend in the thread pool queue. Zero if there is nothing being queued + * in the write thread pool. 
+ */ + public record ThreadPoolUsageStats( + int totalThreadPoolThreads, + float averageThreadPoolUtilization, + long averageThreadPoolQueueLatencyMillis + ) implements Writeable { + + public ThreadPoolUsageStats(StreamInput in) throws IOException { + this(in.readVInt(), in.readFloat(), in.readVLong()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(this.totalThreadPoolThreads); + out.writeFloat(this.averageThreadPoolUtilization); + out.writeVLong(this.averageThreadPoolQueueLatencyMillis); + } + + @Override + public int hashCode() { + return Objects.hash(totalThreadPoolThreads, averageThreadPoolUtilization, averageThreadPoolQueueLatencyMillis); + } + + @Override + public String toString() { + return "[totalThreadPoolThreads=" + + totalThreadPoolThreads + + ", averageThreadPoolUtilization=" + + averageThreadPoolUtilization + + ", averageThreadPoolQueueLatencyMillis=" + + averageThreadPoolQueueLatencyMillis + + "]"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ThreadPoolUsageStats other = (ThreadPoolUsageStats) o; + return totalThreadPoolThreads == other.totalThreadPoolThreads + && averageThreadPoolUtilization == other.averageThreadPoolUtilization + && averageThreadPoolQueueLatencyMillis == other.averageThreadPoolQueueLatencyMillis; + } + + } // ThreadPoolUsageStats + +} diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPoolsCollector.java b/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPoolsCollector.java new file mode 100644 index 0000000000000..e302a4abed559 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/NodeUsageStatsForThreadPoolsCollector.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.cluster; + +import org.elasticsearch.action.ActionListener; + +import java.util.Map; + +/** + * Collects the usage stats (like write thread pool load) estimations for each node in the cluster. + *

+ * Results are returned as a map of node ID to node usage stats. + */ +public interface NodeUsageStatsForThreadPoolsCollector { + /** + * This will be used when there is no NodeUsageLoadCollector available. + */ + NodeUsageStatsForThreadPoolsCollector EMPTY = listener -> listener.onResponse(Map.of()); + + /** + * Collects the write load estimates from the cluster. + * + * @param listener The listener to receive the write load results. + */ + void collectUsageStats(ActionListener> listener); +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFailureStoreDefinition.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFailureStoreDefinition.java index 6b99dcbf22417..5f6221fc15872 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFailureStoreDefinition.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFailureStoreDefinition.java @@ -10,6 +10,7 @@ package org.elasticsearch.cluster.metadata; import org.elasticsearch.cluster.routing.allocation.DataTier; +import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -40,7 +41,9 @@ public class DataStreamFailureStoreDefinition { IndexMetadata.SETTING_NUMBER_OF_SHARDS, IndexMetadata.SETTING_NUMBER_OF_REPLICAS, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, - IndexSettings.INDEX_REFRESH_INTERVAL_SETTING.getKey() + IndexSettings.INDEX_REFRESH_INTERVAL_SETTING.getKey(), + // Different recovery implementations may be provided on the index which need to be preserved. 
+ ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey() ); public static final Set SUPPORTED_USER_SETTINGS_PREFIXES = Set.of( IndexMetadata.INDEX_ROUTING_REQUIRE_GROUP_PREFIX + ".", diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index a6a37f8bec332..d4bc58c299435 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -835,10 +835,6 @@ private Iterator toXContentChunkedWithSingleProjectFormat( ); } - private static final DiffableUtils.KeySerializer PROJECT_ID_SERIALIZER = DiffableUtils.getWriteableKeySerializer( - ProjectId.READER - ); - private static class MetadataDiff implements Diff { private final long version; @@ -880,7 +876,7 @@ private static class MetadataDiff implements Diff { multiProject = null; } else { singleProject = null; - multiProject = DiffableUtils.diff(before.projectMetadata, after.projectMetadata, PROJECT_ID_SERIALIZER); + multiProject = DiffableUtils.diff(before.projectMetadata, after.projectMetadata, ProjectId.PROJECT_ID_SERIALIZER); } if (empty) { @@ -1004,7 +1000,7 @@ private MetadataDiff(StreamInput in) throws IOException { singleProject = null; multiProject = DiffableUtils.readJdkMapDiff( in, - PROJECT_ID_SERIALIZER, + ProjectId.PROJECT_ID_SERIALIZER, ProjectMetadata::readFrom, ProjectMetadata.ProjectMetadataDiff::new ); @@ -1059,7 +1055,7 @@ public void writeTo(StreamOutput out) throws IOException { if (multiProject != null) { multiProject.writeTo(out); } else { - DiffableUtils.singleEntryDiff(DEFAULT_PROJECT_ID, singleProject, PROJECT_ID_SERIALIZER).writeTo(out); + DiffableUtils.singleEntryDiff(DEFAULT_PROJECT_ID, singleProject, ProjectId.PROJECT_ID_SERIALIZER).writeTo(out); } } } @@ -1230,6 +1226,7 @@ public static Metadata readFrom(StreamInput in) throws IOException { } private static void 
readBwcCustoms(StreamInput in, Builder builder) throws IOException { + final ProjectMetadata.Builder projectBuilder = builder.getProject(ProjectId.DEFAULT); final Set clusterScopedNames = in.namedWriteableRegistry().getReaders(ClusterCustom.class).keySet(); final Set projectScopedNames = in.namedWriteableRegistry().getReaders(ProjectCustom.class).keySet(); final int count = in.readVInt(); @@ -1245,9 +1242,9 @@ private static void readBwcCustoms(StreamInput in, Builder builder) throws IOExc if (custom instanceof PersistentTasksCustomMetadata persistentTasksCustomMetadata) { final var tuple = persistentTasksCustomMetadata.split(); builder.putCustom(tuple.v1().getWriteableName(), tuple.v1()); - builder.putProjectCustom(tuple.v2().getWriteableName(), tuple.v2()); + projectBuilder.putCustom(tuple.v2().getWriteableName(), tuple.v2()); } else { - builder.putProjectCustom(custom.getWriteableName(), custom); + projectBuilder.putCustom(custom.getWriteableName(), custom); } } else { throw new IllegalArgumentException("Unknown custom name [" + name + "]"); @@ -1499,12 +1496,6 @@ public Builder put(String name, ComponentTemplate componentTemplate) { return this; } - @Deprecated(forRemoval = true) - public Builder removeComponentTemplate(String name) { - getSingleProject().removeComponentTemplate(name); - return this; - } - @Deprecated(forRemoval = true) public Builder componentTemplates(Map componentTemplates) { getSingleProject().componentTemplates(componentTemplates); @@ -1523,12 +1514,6 @@ public Builder put(String name, ComposableIndexTemplate indexTemplate) { return this; } - @Deprecated(forRemoval = true) - public Builder removeIndexTemplate(String name) { - getSingleProject().removeIndexTemplate(name); - return this; - } - @Deprecated(forRemoval = true) public Builder dataStreams(Map dataStreams, Map dataStreamAliases) { getSingleProject().dataStreams(dataStreams, dataStreamAliases); @@ -1557,11 +1542,6 @@ public Builder removeDataStream(String name) { return this; } - 
@Deprecated(forRemoval = true) - public boolean removeDataStreamAlias(String aliasName, String dataStreamName, boolean mustExist) { - return getSingleProject().removeDataStreamAlias(aliasName, dataStreamName, mustExist); - } - public Builder putCustom(String type, ClusterCustom custom) { customs.put(type, Objects.requireNonNull(custom, type)); return this; @@ -1569,7 +1549,8 @@ public Builder putCustom(String type, ClusterCustom custom) { @Deprecated(forRemoval = true) public Builder putCustom(String type, ProjectCustom custom) { - return putProjectCustom(type, custom); + getSingleProject().putCustom(type, Objects.requireNonNull(custom, type)); + return this; } public ClusterCustom getCustom(String type) { @@ -1592,12 +1573,6 @@ public Builder customs(Map clusterCustoms) { return this; } - @Deprecated(forRemoval = true) - public Builder putProjectCustom(String type, ProjectCustom custom) { - getSingleProject().putCustom(type, Objects.requireNonNull(custom, type)); - return this; - } - @Deprecated(forRemoval = true) public Builder projectCustoms(Map projectCustoms) { projectCustoms.forEach((key, value) -> Objects.requireNonNull(value, key)); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java index 94fa0164b5fbe..88f314ea6cbfe 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.metadata; +import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -26,6 +27,7 @@ public class ProjectId implements Writeable, ToXContent { private static final String DEFAULT_STRING = "default"; public static final ProjectId DEFAULT = new ProjectId(DEFAULT_STRING); public static final 
Reader READER = ProjectId::readFrom; + public static final DiffableUtils.KeySerializer PROJECT_ID_SERIALIZER = DiffableUtils.getWriteableKeySerializer(READER); private static final int MAX_LENGTH = 128; private final String id; diff --git a/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java b/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java index 2876ebc13c70c..014ee37724cbc 100644 --- a/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java +++ b/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java @@ -13,15 +13,23 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.AbstractNamedDiffable; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterState.Custom; +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.NamedDiffable; +import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Collections; @@ -30,22 +38,29 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; /** * Represents a registry for managing and retrieving project-specific state in the cluster state. 
*/ -public class ProjectStateRegistry extends AbstractNamedDiffable implements ClusterState.Custom { +public class ProjectStateRegistry extends AbstractNamedDiffable implements Custom, NamedDiffable { public static final String TYPE = "projects_registry"; public static final ProjectStateRegistry EMPTY = new ProjectStateRegistry(Collections.emptyMap(), Collections.emptySet(), 0); + private static final Entry EMPTY_ENTRY = new Entry(Settings.EMPTY); - private final Map projectsSettings; + private final Map projectsEntries; // Projects that have been marked for deletion based on their file-based setting private final Set projectsMarkedForDeletion; // A counter that is incremented each time one or more projects are marked for deletion. private final long projectsMarkedForDeletionGeneration; public ProjectStateRegistry(StreamInput in) throws IOException { - projectsSettings = in.readMap(ProjectId::readFrom, Settings::readSettingsFromStream); + if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + projectsEntries = in.readMap(ProjectId::readFrom, Entry::readFrom); + } else { + Map settingsMap = in.readMap(ProjectId::readFrom, Settings::readSettingsFromStream); + projectsEntries = settingsMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> new Entry(e.getValue()))); + } if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_RECORDS_DELETIONS)) { projectsMarkedForDeletion = in.readCollectionAsImmutableSet(ProjectId::readFrom); projectsMarkedForDeletionGeneration = in.readVLong(); @@ -56,11 +71,11 @@ public ProjectStateRegistry(StreamInput in) throws IOException { } private ProjectStateRegistry( - Map projectsSettings, + Map projectEntries, Set projectsMarkedForDeletion, long projectsMarkedForDeletionGeneration ) { - this.projectsSettings = projectsSettings; + this.projectsEntries = projectEntries; this.projectsMarkedForDeletion = projectsMarkedForDeletion; 
this.projectsMarkedForDeletionGeneration = projectsMarkedForDeletionGeneration; } @@ -75,7 +90,11 @@ private ProjectStateRegistry( */ public static Settings getProjectSettings(ProjectId projectId, ClusterState clusterState) { ProjectStateRegistry registry = clusterState.custom(TYPE, EMPTY); - return registry.projectsSettings.getOrDefault(projectId, Settings.EMPTY); + return registry.getProjectSettings(projectId); + } + + public Settings getProjectSettings(ProjectId projectId) { + return projectsEntries.getOrDefault(projectId, EMPTY_ENTRY).settings; } public boolean isProjectMarkedForDeletion(ProjectId projectId) { @@ -91,12 +110,10 @@ public Iterator toXContentChunked(ToXContent.Params params return Iterators.concat( Iterators.single((builder, p) -> builder.startArray("projects")), - Iterators.map(projectsSettings.entrySet().iterator(), entry -> (builder, p) -> { + Iterators.map(projectsEntries.entrySet().iterator(), entry -> (builder, p) -> { builder.startObject(); builder.field("id", entry.getKey()); - builder.startObject("settings"); - entry.getValue().toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("flat_settings", "true"))); - builder.endObject(); + entry.getValue().toXContent(builder, params); builder.field("marked_for_deletion", projectsMarkedForDeletion.contains(entry.getKey())); return builder.endObject(); }), @@ -105,8 +122,19 @@ public Iterator toXContentChunked(ToXContent.Params params ); } - public static NamedDiff readDiffFrom(StreamInput in) throws IOException { - return readDiffFrom(ClusterState.Custom.class, TYPE, in); + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + return new ProjectStateRegistryDiff(in); + } + return readDiffFrom(Custom.class, TYPE, in); + } + + @Override + public Diff diff(Custom previousState) { + if (this.equals(previousState)) { + return SimpleDiffable.empty(); + } + return new 
ProjectStateRegistryDiff((ProjectStateRegistry) previousState, this); } @Override @@ -121,7 +149,14 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(projectsSettings); + if (out.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + out.writeMap(projectsEntries); + } else { + Map settingsMap = projectsEntries.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().settings())); + out.writeMap(settingsMap); + } if (out.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_RECORDS_DELETIONS)) { out.writeCollection(projectsMarkedForDeletion); out.writeVLong(projectsMarkedForDeletionGeneration); @@ -133,7 +168,7 @@ public void writeTo(StreamOutput out) throws IOException { } public int size() { - return projectsSettings.size(); + return projectsEntries.size(); } public long getProjectsMarkedForDeletionGeneration() { @@ -141,15 +176,15 @@ public long getProjectsMarkedForDeletionGeneration() { } // visible for testing - Map getProjectsSettings() { - return Collections.unmodifiableMap(projectsSettings); + Set knownProjects() { + return projectsEntries.keySet(); } @Override public String toString() { return "ProjectStateRegistry[" - + "projectsSettings=" - + projectsSettings + + "entities=" + + projectsEntries + ", projectsMarkedForDeletion=" + projectsMarkedForDeletion + ", projectsMarkedForDeletionGeneration=" @@ -163,13 +198,13 @@ public boolean equals(Object o) { if (o instanceof ProjectStateRegistry == false) return false; ProjectStateRegistry that = (ProjectStateRegistry) o; return projectsMarkedForDeletionGeneration == that.projectsMarkedForDeletionGeneration - && Objects.equals(projectsSettings, that.projectsSettings) + && Objects.equals(projectsEntries, that.projectsEntries) && Objects.equals(projectsMarkedForDeletion, that.projectsMarkedForDeletion); } @Override public int hashCode() { - 
return Objects.hash(projectsSettings, projectsMarkedForDeletion, projectsMarkedForDeletionGeneration); + return Objects.hash(projectsEntries, projectsMarkedForDeletion, projectsMarkedForDeletionGeneration); } public static Builder builder(ClusterState original) { @@ -185,26 +220,86 @@ public static Builder builder() { return new Builder(); } + static class ProjectStateRegistryDiff implements NamedDiff { + private static final DiffableUtils.DiffableValueReader VALUE_READER = new DiffableUtils.DiffableValueReader<>( + Entry::readFrom, + Entry.EntryDiff::readFrom + ); + + private final DiffableUtils.MapDiff> projectsEntriesDiff; + private final Set projectsMarkedForDeletion; + private final long projectsMarkedForDeletionGeneration; + + ProjectStateRegistryDiff(StreamInput in) throws IOException { + projectsEntriesDiff = DiffableUtils.readJdkMapDiff(in, ProjectId.PROJECT_ID_SERIALIZER, VALUE_READER); + projectsMarkedForDeletion = in.readCollectionAsImmutableSet(ProjectId.READER); + projectsMarkedForDeletionGeneration = in.readVLong(); + } + + ProjectStateRegistryDiff(ProjectStateRegistry previousState, ProjectStateRegistry currentState) { + projectsEntriesDiff = DiffableUtils.diff( + previousState.projectsEntries, + currentState.projectsEntries, + ProjectId.PROJECT_ID_SERIALIZER, + VALUE_READER + ); + projectsMarkedForDeletion = currentState.projectsMarkedForDeletion; + projectsMarkedForDeletionGeneration = currentState.projectsMarkedForDeletionGeneration; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.PROJECT_STATE_REGISTRY_ENTRY; + } + + @Override + public Custom apply(Custom part) { + return new ProjectStateRegistry( + projectsEntriesDiff.apply(((ProjectStateRegistry) part).projectsEntries), + projectsMarkedForDeletion, + projectsMarkedForDeletionGeneration + ); + } + + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
projectsEntriesDiff.writeTo(out); + out.writeCollection(projectsMarkedForDeletion); + out.writeVLong(projectsMarkedForDeletionGeneration); + } + } + public static class Builder { - private final ImmutableOpenMap.Builder projectsSettings; + private final ImmutableOpenMap.Builder projectsEntries; private final Set projectsMarkedForDeletion; private final long projectsMarkedForDeletionGeneration; private boolean newProjectMarkedForDeletion = false; private Builder() { - this.projectsSettings = ImmutableOpenMap.builder(); + this.projectsEntries = ImmutableOpenMap.builder(); projectsMarkedForDeletion = new HashSet<>(); projectsMarkedForDeletionGeneration = 0; } private Builder(ProjectStateRegistry original) { - this.projectsSettings = ImmutableOpenMap.builder(original.projectsSettings); + this.projectsEntries = ImmutableOpenMap.builder(original.projectsEntries); this.projectsMarkedForDeletion = new HashSet<>(original.projectsMarkedForDeletion); this.projectsMarkedForDeletionGeneration = original.projectsMarkedForDeletionGeneration; } public Builder putProjectSettings(ProjectId projectId, Settings settings) { - projectsSettings.put(projectId, settings); + Entry entry = projectsEntries.get(projectId); + if (entry == null) { + entry = new Entry(settings); + } else { + entry = entry.withSettings(settings); + } + projectsEntries.put(projectId, entry); return this; } @@ -216,17 +311,63 @@ public Builder markProjectForDeletion(ProjectId projectId) { } public ProjectStateRegistry build() { - final var unknownButUnderDeletion = Sets.difference(projectsMarkedForDeletion, projectsSettings.keys()); + final var unknownButUnderDeletion = Sets.difference(projectsMarkedForDeletion, projectsEntries.keys()); if (unknownButUnderDeletion.isEmpty() == false) { throw new IllegalArgumentException( "Cannot mark projects for deletion that are not in the registry: " + unknownButUnderDeletion ); } return new ProjectStateRegistry( - projectsSettings.build(), + projectsEntries.build(), 
projectsMarkedForDeletion, newProjectMarkedForDeletion ? projectsMarkedForDeletionGeneration + 1 : projectsMarkedForDeletionGeneration ); } } + + private record Entry(Settings settings) implements Writeable, Diffable { + + public static Entry readFrom(StreamInput in) throws IOException { + return new Entry(Settings.readSettingsFromStream(in)); + } + + public Entry withSettings(Settings settings) { + return new Entry(settings); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(settings); + } + + public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject("settings"); + settings.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("flat_settings", "true"))); + builder.endObject(); + } + + @Override + public Diff diff(Entry previousState) { + if (this == previousState) { + return SimpleDiffable.empty(); + } + return new EntryDiff(settings.diff(previousState.settings)); + } + + private record EntryDiff(Diff settingsDiff) implements Diff { + public static EntryDiff readFrom(StreamInput in) throws IOException { + return new EntryDiff(Settings.readSettingsDiffFromStream(in)); + } + + @Override + public Entry apply(Entry part) { + return part.withSettings(settingsDiff.apply(part.settings)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(settingsDiff); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java index 3df0d2d65b657..050181802af8d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java @@ -15,6 +15,8 @@ import org.elasticsearch.action.RoutingMissingException; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.metadata.IndexMetadata; 
+import org.elasticsearch.cluster.metadata.IndexReshardingMetadata; +import org.elasticsearch.cluster.metadata.IndexReshardingState; import org.elasticsearch.cluster.metadata.MappingMetadata; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; @@ -73,11 +75,13 @@ public static IndexRouting fromIndexMetadata(IndexMetadata metadata) { protected final String indexName; private final int routingNumShards; private final int routingFactor; + private final IndexReshardingMetadata indexReshardingMetadata; private IndexRouting(IndexMetadata metadata) { this.indexName = metadata.getIndex().getName(); this.routingNumShards = metadata.getRoutingNumShards(); this.routingFactor = metadata.getRoutingFactor(); + this.indexReshardingMetadata = metadata.getReshardingMetadata(); } /** @@ -149,6 +153,23 @@ private static int effectiveRoutingToHash(String effectiveRouting) { */ public void checkIndexSplitAllowed() {} + /** + * If this index is in the process of resharding, and the shard to which this request is being routed, + * is a target shard that is not yet in HANDOFF state, then route it to the source shard. 
+ * @param shardId shardId to which the current document is routed based on hashing + * @return Updated shardId + */ + protected final int rerouteIfResharding(int shardId) { + if (indexReshardingMetadata != null && indexReshardingMetadata.getSplit().isTargetShard(shardId)) { + assert indexReshardingMetadata.isSplit() : "Index resharding state is not a split"; + if (indexReshardingMetadata.getSplit() + .targetStateAtLeast(shardId, IndexReshardingState.Split.TargetShardState.HANDOFF) == false) { + return indexReshardingMetadata.getSplit().sourceShard(shardId); + } + } + return shardId; + } + private abstract static class IdAndRoutingOnly extends IndexRouting { private final boolean routingRequired; private final IndexVersion creationVersion; @@ -195,19 +216,22 @@ public int indexShard(String id, @Nullable String routing, XContentType sourceTy throw new IllegalStateException("id is required and should have been set by process"); } checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override public int updateShard(String id, @Nullable String routing) { checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override public int deleteShard(String id, @Nullable String routing) { checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override @@ -314,7 +338,8 @@ public int indexShard(String id, @Nullable String routing, XContentType sourceTy assert Transports.assertNotTransportThread("parsing the _source can get slow"); checkNoRouting(routing); hash = hashSource(sourceType, source).buildHash(IndexRouting.ExtractFromSource::defaultOnEmpty); - return hashToShardId(hash); + int shardId = hashToShardId(hash); + return (rerouteIfResharding(shardId)); } public String createId(XContentType sourceType, 
BytesReference source, byte[] suffix) { @@ -454,13 +479,15 @@ public int updateShard(String id, @Nullable String routing) { @Override public int deleteShard(String id, @Nullable String routing) { checkNoRouting(routing); - return idToHash(id); + int shardId = idToHash(id); + return (rerouteIfResharding(shardId)); } @Override public int getShard(String id, @Nullable String routing) { checkNoRouting(routing); - return idToHash(id); + int shardId = idToHash(id); + return (rerouteIfResharding(shardId)); } private void checkNoRouting(@Nullable String routing) { diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/RecoverySource.java b/server/src/main/java/org/elasticsearch/cluster/routing/RecoverySource.java index 838d3cf539b3b..4f8f1c2d8d64e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/RecoverySource.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/RecoverySource.java @@ -355,7 +355,7 @@ protected void writeAdditionalFields(StreamOutput out) throws IOException { @Override public void addAdditionalFields(XContentBuilder builder, Params params) throws IOException { - sourceShardId.toXContent(builder, params); + builder.field("sourceShardId", sourceShardId); } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java index 3b68004d3e00e..c737b89b80f73 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java @@ -1134,22 +1134,22 @@ private boolean tryRelocateShard(ModelNode minNode, ModelNode maxNode, ProjectIn continue; } - final Decision decision = new Decision.Multi().add(allocationDecision).add(rebalanceDecision); + final Decision.Type canAllocateOrRebalance = 
Decision.Type.min(allocationDecision.type(), rebalanceDecision.type()); maxNode.removeShard(projectIndex(shard), shard); long shardSize = allocation.clusterInfo().getShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE); - assert decision.type() == Type.YES || decision.type() == Type.THROTTLE : decision.type(); + assert canAllocateOrRebalance == Type.YES || canAllocateOrRebalance == Type.THROTTLE : canAllocateOrRebalance; logger.debug( "decision [{}]: relocate [{}] from [{}] to [{}]", - decision.type(), + canAllocateOrRebalance, shard, maxNode.getNodeId(), minNode.getNodeId() ); minNode.addShard( projectIndex(shard), - decision.type() == Type.YES + canAllocateOrRebalance == Type.YES /* only allocate on the cluster if we are not throttled */ ? routingNodes.relocateShard(shard, minNode.getNodeId(), shardSize, "rebalance", allocation.changes()).v1() : shard.relocate(minNode.getNodeId(), shardSize) diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java index 38c2806778dff..8e69d72777f04 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.routing.UnassignedInfo.AllocationStatus; import org.elasticsearch.cluster.routing.allocation.RoutingAllocation; import org.elasticsearch.cluster.routing.allocation.decider.Decision; +import org.elasticsearch.common.FrequencyCappedAction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDeciders.java 
b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDeciders.java index c80aa1e69f212..29d8e8051b421 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDeciders.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDeciders.java @@ -205,7 +205,7 @@ private Decision withDeciders( BiFunction logMessageCreator ) { if (debugMode == RoutingAllocation.DebugMode.OFF) { - var result = Decision.YES; + Decision result = Decision.YES; for (AllocationDecider decider : deciders) { var decision = deciderAction.apply(decider); if (decision.type() == Decision.Type.NO) { diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/Decision.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/Decision.java index 560eb2d62278a..b2cfd9cc986aa 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/Decision.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/Decision.java @@ -23,20 +23,18 @@ import java.util.Collections; import java.util.List; import java.util.Locale; -import java.util.Objects; /** - * This abstract class defining basic {@link Decision} used during shard - * allocation process. + * A {@link Decision} used during shard allocation process. 
* * @see AllocationDecider */ -public abstract class Decision implements ToXContent, Writeable { +public sealed interface Decision extends ToXContent, Writeable permits Decision.Single, Decision.Multi { - public static final Decision ALWAYS = new Single(Type.YES); - public static final Decision YES = new Single(Type.YES); - public static final Decision NO = new Single(Type.NO); - public static final Decision THROTTLE = new Single(Type.THROTTLE); + Single ALWAYS = new Single(Type.YES); + Single YES = new Single(Type.YES); + Single NO = new Single(Type.NO); + Single THROTTLE = new Single(Type.THROTTLE); /** * Creates a simple decision @@ -46,40 +44,69 @@ public abstract class Decision implements ToXContent, Writeable { * @param explanationParams additional parameters for the decision * @return new {@link Decision} instance */ - public static Decision single(Type type, @Nullable String label, @Nullable String explanation, @Nullable Object... explanationParams) { + static Decision single(Type type, @Nullable String label, @Nullable String explanation, @Nullable Object... 
explanationParams) { return new Single(type, label, explanation, explanationParams); } - public static Decision readFrom(StreamInput in) throws IOException { + static Decision readFrom(StreamInput in) throws IOException { // Determine whether to read a Single or Multi Decision if (in.readBoolean()) { Multi result = new Multi(); int decisionCount = in.readVInt(); for (int i = 0; i < decisionCount; i++) { - Decision s = readFrom(in); - result.decisions.add(s); + var flag = in.readBoolean(); + assert flag == false : "nested multi decision is not permitted"; + var single = readSingleFrom(in); + result.decisions.add(single); } return result; } else { - final Type type = Type.readFrom(in); - final String label = in.readOptionalString(); - final String explanation = in.readOptionalString(); - if (label == null && explanation == null) { - return switch (type) { - case YES -> YES; - case THROTTLE -> THROTTLE; - case NO -> NO; - }; - } - return new Single(type, label, explanation); + return readSingleFrom(in); + } + } + + private static Single readSingleFrom(StreamInput in) throws IOException { + final Type type = Type.readFrom(in); + final String label = in.readOptionalString(); + final String explanation = in.readOptionalString(); + if (label == null && explanation == null) { + return switch (type) { + case YES -> YES; + case THROTTLE -> THROTTLE; + case NO -> NO; + }; } + return new Single(type, label, explanation); } + /** + * Get the {@link Type} of this decision + * @return {@link Type} of this decision + */ + Type type(); + + /** + * Get the description label for this decision. + */ + @Nullable + String label(); + + /** + * Get the explanation for this decision. 
+ */ + @Nullable + String getExplanation(); + + /** + * Return the list of all decisions that make up this decision + */ + List getDecisions(); + /** * This enumeration defines the * possible types of decisions */ - public enum Type implements Writeable { + enum Type implements Writeable { YES(1), THROTTLE(2), NO(0); @@ -110,45 +137,22 @@ public boolean higherThan(Type other) { return false; } else if (other == NO) { return true; - } else if (other == THROTTLE && this == YES) { - return true; - } - return false; + } else return other == THROTTLE && this == YES; } - } - - /** - * Get the {@link Type} of this decision - * @return {@link Type} of this decision - */ - public abstract Type type(); - - /** - * Get the description label for this decision. - */ - @Nullable - public abstract String label(); - - /** - * Get the explanation for this decision. - */ - @Nullable - public abstract String getExplanation(); + /** + * @return lowest decision by precedence NO->THROTTLE->YES + */ + public static Type min(Type a, Type b) { + return a.higherThan(b) ? b : a; + } - /** - * Return the list of all decisions that make up this decision - */ - public abstract List getDecisions(); + } /** * Simple class representing a single decision */ - public static class Single extends Decision implements ToXContentObject { - private final Type type; - private final String label; - private final String explanationString; - + record Single(Type type, String label, String explanationString) implements Decision, ToXContentObject { /** * Creates a new {@link Single} decision of a given type * @param type {@link Type} of the decision @@ -165,24 +169,13 @@ private Single(Type type) { * @param explanationParams A set of additional parameters */ public Single(Type type, @Nullable String label, @Nullable String explanation, @Nullable Object... 
explanationParams) { - this.type = type; - this.label = label; - if (explanationParams != null && explanationParams.length > 0) { - this.explanationString = String.format(Locale.ROOT, explanation, explanationParams); - } else { - this.explanationString = explanation; - } - } - - @Override - public Type type() { - return this.type; - } - - @Override - @Nullable - public String label() { - return this.label; + this( + type, + label, + explanationParams != null && explanationParams.length > 0 + ? String.format(Locale.ROOT, explanation, explanationParams) + : explanation + ); } @Override @@ -199,29 +192,6 @@ public String getExplanation() { return this.explanationString; } - @Override - public boolean equals(Object object) { - if (this == object) { - return true; - } - - if (object == null || getClass() != object.getClass()) { - return false; - } - - Decision.Single s = (Decision.Single) object; - return this.type == s.type && Objects.equals(label, s.label) && Objects.equals(explanationString, s.explanationString); - } - - @Override - public int hashCode() { - int result = type.hashCode(); - result = 31 * result + (label == null ? 0 : label.hashCode()); - String explanationStr = explanationString; - result = 31 * result + (explanationStr == null ? 
0 : explanationStr.hashCode()); - return result; - } - @Override public String toString() { if (explanationString != null) { @@ -254,9 +224,11 @@ public void writeTo(StreamOutput out) throws IOException { /** * Simple class representing a list of decisions */ - public static class Multi extends Decision implements ToXContentFragment { + record Multi(List decisions) implements Decision, ToXContentFragment { - private final List decisions = new ArrayList<>(); + public Multi() { + this(new ArrayList<>()); + } /** * Add a decision to this {@link Multi}decision instance @@ -264,7 +236,8 @@ public static class Multi extends Decision implements ToXContentFragment { * @return {@link Multi}decision instance with the given decision added */ public Multi add(Decision decision) { - decisions.add(decision); + assert decision instanceof Single; + decisions.add((Single) decision); return this; } @@ -300,26 +273,6 @@ public List getDecisions() { return Collections.unmodifiableList(this.decisions); } - @Override - public boolean equals(final Object object) { - if (this == object) { - return true; - } - - if (object == null || getClass() != object.getClass()) { - return false; - } - - final Decision.Multi m = (Decision.Multi) object; - - return this.decisions.equals(m.decisions); - } - - @Override - public int hashCode() { - return 31 * decisions.hashCode(); - } - @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedAction.java b/server/src/main/java/org/elasticsearch/common/FrequencyCappedAction.java similarity index 95% rename from server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedAction.java rename to server/src/main/java/org/elasticsearch/common/FrequencyCappedAction.java index 299082bcc9c9d..b83a8a3fbc0f7 100644 --- 
a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedAction.java +++ b/server/src/main/java/org/elasticsearch/common/FrequencyCappedAction.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.cluster.routing.allocation.allocator; +package org.elasticsearch.common; import org.elasticsearch.core.TimeValue; diff --git a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java index bf8451b81c55d..41faefaafc636 100644 --- a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java +++ b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java @@ -84,6 +84,7 @@ public enum ReferenceDocs { ALLOCATION_EXPLAIN_MAX_RETRY, SECURE_SETTINGS, CLUSTER_SHARD_LIMIT, + DEPLOY_CLOUD_DIFF_FROM_STATEFUL, // this comment keeps the ';' on the next line so every entry above has a trailing ',' which makes the diff for adding new links cleaner ; diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index 99ed0917b12bf..d4162a3996032 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -739,15 +739,26 @@ public static Version parseVersionLenient(String toParse, Version defaultValue) * If no SegmentReader can be extracted an {@link IllegalStateException} is thrown. */ public static SegmentReader segmentReader(LeafReader reader) { + SegmentReader segmentReader = tryUnwrapSegmentReader(reader); + if (segmentReader == null) { + throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]"); + } + return segmentReader; + } + + /** + * Tries to extract a segment reader from the given index reader. 
Unlike {@link #segmentReader(LeafReader)} this method returns + * null if no SegmentReader can be unwrapped instead of throwing an exception. + */ + public static SegmentReader tryUnwrapSegmentReader(LeafReader reader) { if (reader instanceof SegmentReader) { return (SegmentReader) reader; } else if (reader instanceof final FilterLeafReader fReader) { - return segmentReader(FilterLeafReader.unwrap(fReader)); + return tryUnwrapSegmentReader(FilterLeafReader.unwrap(fReader)); } else if (reader instanceof final FilterCodecReader fReader) { - return segmentReader(FilterCodecReader.unwrap(fReader)); + return tryUnwrapSegmentReader(FilterCodecReader.unwrap(fReader)); } - // hard fail - we can't get a SegmentReader - throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]"); + return null; } @SuppressForbidden(reason = "Version#parseLeniently() used in a central place") diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java index 4fd5225a29167..c39ce209bf875 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java @@ -577,24 +577,43 @@ public void rejectedExecution(Runnable task, ThreadPoolExecutor executor) { } public static class TaskTrackingConfig { - // This is a random starting point alpha. TODO: revisit this with actual testing and/or make it configurable - public static final double DEFAULT_EWMA_ALPHA = 0.3; + // This is a random starting point alpha. 
+ public static final double DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST = 0.3; private final boolean trackExecutionTime; private final boolean trackOngoingTasks; - private final double ewmaAlpha; - - public static final TaskTrackingConfig DO_NOT_TRACK = new TaskTrackingConfig(false, false, DEFAULT_EWMA_ALPHA); - public static final TaskTrackingConfig DEFAULT = new TaskTrackingConfig(true, false, DEFAULT_EWMA_ALPHA); - - public TaskTrackingConfig(boolean trackOngoingTasks, double ewmaAlpha) { - this(true, trackOngoingTasks, ewmaAlpha); - } + private final boolean trackMaxQueueLatency; + private final double executionTimeEwmaAlpha; + + public static final TaskTrackingConfig DO_NOT_TRACK = new TaskTrackingConfig( + false, + false, + false, + DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST + ); + public static final TaskTrackingConfig DEFAULT = new TaskTrackingConfig( + true, + false, + false, + DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST + ); - private TaskTrackingConfig(boolean trackExecutionTime, boolean trackOngoingTasks, double EWMAAlpha) { + /** + * @param trackExecutionTime Whether to track execution stats + * @param trackOngoingTasks Whether to track ongoing task execution time, not just finished tasks + * @param trackMaxQueueLatency Whether to track max queue latency. + * @param executionTimeEWMAAlpha The alpha seed for execution time EWMA (ExponentiallyWeightedMovingAverage). 
+ */ + private TaskTrackingConfig( + boolean trackExecutionTime, + boolean trackOngoingTasks, + boolean trackMaxQueueLatency, + double executionTimeEWMAAlpha + ) { this.trackExecutionTime = trackExecutionTime; this.trackOngoingTasks = trackOngoingTasks; - this.ewmaAlpha = EWMAAlpha; + this.trackMaxQueueLatency = trackMaxQueueLatency; + this.executionTimeEwmaAlpha = executionTimeEWMAAlpha; } public boolean trackExecutionTime() { @@ -605,8 +624,45 @@ public boolean trackOngoingTasks() { return trackOngoingTasks; } - public double getEwmaAlpha() { - return ewmaAlpha; + public boolean trackMaxQueueLatency() { + return trackMaxQueueLatency; + } + + public double getExecutionTimeEwmaAlpha() { + return executionTimeEwmaAlpha; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private boolean trackExecutionTime = false; + private boolean trackOngoingTasks = false; + private boolean trackMaxQueueLatency = false; + private double ewmaAlpha = DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST; + + public Builder() {} + + public Builder trackExecutionTime(double alpha) { + trackExecutionTime = true; + ewmaAlpha = alpha; + return this; + } + + public Builder trackOngoingTasks() { + trackOngoingTasks = true; + return this; + } + + public Builder trackMaxQueueLatency() { + trackMaxQueueLatency = true; + return this; + } + + public TaskTrackingConfig build() { + return new TaskTrackingConfig(trackExecutionTime, trackOngoingTasks, trackMaxQueueLatency, ewmaAlpha); + } } } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java index 2b1a5ff6e9c0c..762a8c280b7f3 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java +++ 
b/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java @@ -27,6 +27,7 @@ import java.util.concurrent.RejectedExecutionHandler; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.LongAccumulator; import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; @@ -37,7 +38,6 @@ * An extension to thread pool executor, which tracks statistics for the task execution time. */ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThreadPoolExecutor { - public static final int QUEUE_LATENCY_HISTOGRAM_BUCKETS = 18; private static final int[] LATENCY_PERCENTILES_TO_REPORT = { 50, 90, 99 }; @@ -47,9 +47,17 @@ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThrea private final boolean trackOngoingTasks; // The set of currently running tasks and the timestamp of when they started execution in the Executor. private final Map ongoingTasks = new ConcurrentHashMap<>(); - private volatile long lastPollTime = System.nanoTime(); - private volatile long lastTotalExecutionTime = 0; private final ExponentialBucketHistogram queueLatencyMillisHistogram = new ExponentialBucketHistogram(QUEUE_LATENCY_HISTOGRAM_BUCKETS); + private final boolean trackMaxQueueLatency; + private LongAccumulator maxQueueLatencyMillisSinceLastPoll = new LongAccumulator(Long::max, 0); + + public enum UtilizationTrackingPurpose { + APM, + ALLOCATION, + } + + private volatile UtilizationTracker apmUtilizationTracker = new UtilizationTracker(); + private volatile UtilizationTracker allocationUtilizationTracker = new UtilizationTracker(); TaskExecutionTimeTrackingEsThreadPoolExecutor( String name, @@ -65,9 +73,11 @@ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThrea TaskTrackingConfig trackingConfig ) { super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler, 
contextHolder); + this.runnableWrapper = runnableWrapper; - this.executionEWMA = new ExponentiallyWeightedMovingAverage(trackingConfig.getEwmaAlpha(), 0); + this.executionEWMA = new ExponentiallyWeightedMovingAverage(trackingConfig.getExecutionTimeEwmaAlpha(), 0); this.trackOngoingTasks = trackingConfig.trackOngoingTasks(); + this.trackMaxQueueLatency = trackingConfig.trackMaxQueueLatency(); } public List setupMetrics(MeterRegistry meterRegistry, String threadPoolName) { @@ -95,7 +105,7 @@ public List setupMetrics(MeterRegistry meterRegistry, String threadP ThreadPool.THREAD_POOL_METRIC_PREFIX + threadPoolName + THREAD_POOL_METRIC_NAME_UTILIZATION, "fraction of maximum thread time utilized for " + threadPoolName, "fraction", - () -> new DoubleWithAttributes(pollUtilization(), Map.of()) + () -> new DoubleWithAttributes(pollUtilization(UtilizationTrackingPurpose.APM), Map.of()) ) ); } @@ -136,24 +146,30 @@ public int getCurrentQueueSize() { return getQueue().size(); } + public long getMaxQueueLatencyMillisSinceLastPollAndReset() { + if (trackMaxQueueLatency == false) { + return 0; + } + return maxQueueLatencyMillisSinceLastPoll.getThenReset(); + } + /** - * Returns the fraction of the maximum possible thread time that was actually used since the last time - * this method was called. + * Returns the fraction of the maximum possible thread time that was actually used since the last time this method was called. + * There are two periodic pulling mechanisms that access utilization reporting: {@link UtilizationTrackingPurpose} distinguishes the + * caller. * - * @return the utilization as a fraction, in the range [0, 1] + * @return the utilization as a fraction, in the range [0, 1]. This may return >1 if a task completed in the time range but started + * earlier, contributing a larger execution time. 
*/ - public double pollUtilization() { - final long currentTotalExecutionTimeNanos = totalExecutionTime.sum(); - final long currentPollTimeNanos = System.nanoTime(); - - final long totalExecutionTimeSinceLastPollNanos = currentTotalExecutionTimeNanos - lastTotalExecutionTime; - final long timeSinceLastPoll = currentPollTimeNanos - lastPollTime; - final long maximumExecutionTimeSinceLastPollNanos = timeSinceLastPoll * getMaximumPoolSize(); - final double utilizationSinceLastPoll = (double) totalExecutionTimeSinceLastPollNanos / maximumExecutionTimeSinceLastPollNanos; - - lastTotalExecutionTime = currentTotalExecutionTimeNanos; - lastPollTime = currentPollTimeNanos; - return utilizationSinceLastPoll; + public double pollUtilization(UtilizationTrackingPurpose utilizationTrackingPurpose) { + switch (utilizationTrackingPurpose) { + case APM: + return apmUtilizationTracker.pollUtilization(); + case ALLOCATION: + return allocationUtilizationTracker.pollUtilization(); + default: + throw new IllegalStateException("No operation defined for [" + utilizationTrackingPurpose + "]"); + } } @Override @@ -161,12 +177,18 @@ protected void beforeExecute(Thread t, Runnable r) { if (trackOngoingTasks) { ongoingTasks.put(r, System.nanoTime()); } + assert super.unwrap(r) instanceof TimedRunnable : "expected only TimedRunnables in queue"; final TimedRunnable timedRunnable = (TimedRunnable) super.unwrap(r); timedRunnable.beforeExecute(); final long taskQueueLatency = timedRunnable.getQueueTimeNanos(); assert taskQueueLatency >= 0; - queueLatencyMillisHistogram.addObservation(TimeUnit.NANOSECONDS.toMillis(taskQueueLatency)); + var queueLatencyMillis = TimeUnit.NANOSECONDS.toMillis(taskQueueLatency); + queueLatencyMillisHistogram.addObservation(queueLatencyMillis); + + if (trackMaxQueueLatency) { + maxQueueLatencyMillisSinceLastPoll.accumulate(queueLatencyMillis); + } } @Override @@ -222,7 +244,39 @@ public Map getOngoingTasks() { } // Used for testing - public double getEwmaAlpha() { + 
public double getExecutionEwmaAlpha() { return executionEWMA.getAlpha(); } + + // Used for testing + public boolean trackingMaxQueueLatency() { + return trackMaxQueueLatency; + } + + /** + * Supports periodic polling for thread pool utilization. Tracks state since the last polling request so that the average utilization + * since the last poll can be calculated for the next polling request. + * + * Uses the difference of {@link #totalExecutionTime} since the last polling request to determine how much activity has occurred. + */ + private class UtilizationTracker { + long lastPollTime = System.nanoTime(); + long lastTotalExecutionTime = 0; + + public synchronized double pollUtilization() { + final long currentTotalExecutionTimeNanos = totalExecutionTime.sum(); + final long currentPollTimeNanos = System.nanoTime(); + + final long totalExecutionTimeSinceLastPollNanos = currentTotalExecutionTimeNanos - lastTotalExecutionTime; + final long timeSinceLastPoll = currentPollTimeNanos - lastPollTime; + + final long maximumExecutionTimeSinceLastPollNanos = timeSinceLastPoll * getMaximumPoolSize(); + final double utilizationSinceLastPoll = (double) totalExecutionTimeSinceLastPollNanos / maximumExecutionTimeSinceLastPollNanos; + + lastTotalExecutionTime = currentTotalExecutionTimeNanos; + lastPollTime = currentPollTimeNanos; + + return utilizationSinceLastPoll; + } + } } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index 641956fce5165..699baea2052be 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -189,37 +189,46 @@ public StoredContext newEmptySystemContext() { * moving tracing-related fields to different names so that a new child span can be started. 
This child span will pick up * the moved fields and use them to establish the parent-child relationship. * + * Response headers will be propagated. If no parent span is in progress (meaning there's no trace context), this will behave exactly + * the same way as {@link #newStoredContextPreservingResponseHeaders}. + * * @return a stored context, which can be restored when this context is no longer needed. */ public StoredContext newTraceContext() { final ThreadContextStruct originalContext = threadLocal.get(); - final Map newRequestHeaders = new HashMap<>(originalContext.requestHeaders); - final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); - final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER); - if (previousTraceParent != null) { - newTransientHeaders.put(Task.PARENT_TRACE_PARENT_HEADER, previousTraceParent); - } + // this is the context when this method returns + final ThreadContextStruct newContext; + if (originalContext.hasTraceContext() == false) { + newContext = originalContext; + } else { + final Map newRequestHeaders = new HashMap<>(originalContext.requestHeaders); + final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); - final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE); - if (previousTraceState != null) { - newTransientHeaders.put(Task.PARENT_TRACE_STATE, previousTraceState); - } + final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER); + if (previousTraceParent != null) { + newTransientHeaders.put(Task.PARENT_TRACE_PARENT_HEADER, previousTraceParent); + } - final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT); - if (previousTraceContext != null) { - newTransientHeaders.put(Task.PARENT_APM_TRACE_CONTEXT, previousTraceContext); - } + final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE); + if (previousTraceState != null) { + 
newTransientHeaders.put(Task.PARENT_TRACE_STATE, previousTraceState); + } - // this is the context when this method returns - final ThreadContextStruct newContext = new ThreadContextStruct( - newRequestHeaders, - originalContext.responseHeaders, - newTransientHeaders, - originalContext.isSystemContext, - originalContext.warningHeadersSize - ); - threadLocal.set(newContext); + final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT); + if (previousTraceContext != null) { + newTransientHeaders.put(Task.PARENT_APM_TRACE_CONTEXT, previousTraceContext); + } + + newContext = new ThreadContextStruct( + newRequestHeaders, + originalContext.responseHeaders, + newTransientHeaders, + originalContext.isSystemContext, + originalContext.warningHeadersSize + ); + threadLocal.set(newContext); + } // Tracing shouldn't interrupt the propagation of response headers, so in the same as // #newStoredContextPreservingResponseHeaders(), pass on any potential changes to the response headers. 
return () -> { @@ -233,10 +242,11 @@ public StoredContext newTraceContext() { } public boolean hasTraceContext() { - final ThreadContextStruct context = threadLocal.get(); - return context.requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER) - || context.requestHeaders.containsKey(Task.TRACE_STATE) - || context.transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + return threadLocal.get().hasTraceContext(); + } + + public boolean hasParentTraceContext() { + return threadLocal.get().hasParentTraceContext(); } /** @@ -727,6 +737,7 @@ public void sanitizeHeaders() { entry -> entry.getKey().equalsIgnoreCase("authorization") || entry.getKey().equalsIgnoreCase("es-secondary-authorization") || entry.getKey().equalsIgnoreCase("ES-Client-Authentication") + || entry.getKey().equalsIgnoreCase("X-Client-Authentication") ); final ThreadContextStruct newContext = new ThreadContextStruct( @@ -853,6 +864,18 @@ private ThreadContextStruct putResponseHeaders(Map> headers) return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext); } + private boolean hasTraceContext() { + return requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER) + || requestHeaders.containsKey(Task.TRACE_STATE) + || transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + } + + private boolean hasParentTraceContext() { + return transientHeaders.containsKey(Task.PARENT_TRACE_PARENT_HEADER) + || transientHeaders.containsKey(Task.PARENT_TRACE_STATE) + || transientHeaders.containsKey(Task.PARENT_APM_TRACE_CONTEXT); + } + private void logWarningHeaderThresholdExceeded(long threshold, Setting thresholdSetting) { // If available, log some selected headers to help identifying the source of the request. // Note: Only Task.HEADERS_TO_COPY are guaranteed to be preserved at this point. 
diff --git a/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java b/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java index 0a7451702ec66..2eeb8c470b5d8 100644 --- a/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java +++ b/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java @@ -16,12 +16,14 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.persistent.AllocatedPersistentTask; @@ -133,10 +135,11 @@ protected HealthNode createTask( * Returns the node id from the eligible health nodes */ @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( HealthNodeTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, candidateNodes, DiscoveryNode::canContainData); if (discoveryNode == null) { diff --git a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java index a01c48332fbb8..9445936fc9984 100644 --- 
a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java @@ -41,6 +41,8 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.tasks.Task; import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.telemetry.metric.LongWithAttributes; +import org.elasticsearch.telemetry.metric.MeterRegistry; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.BindTransportException; @@ -103,6 +105,8 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo private final HttpTracer httpLogger; private final Tracer tracer; + private final MeterRegistry meterRegistry; + private final List metricsToClose = new ArrayList<>(2); private volatile boolean shuttingDown; private final ReadWriteLock shuttingDownRWLock = new StampedLock().asReadWriteLock(); @@ -142,6 +146,7 @@ protected AbstractHttpServerTransport( this.maxContentLength = SETTING_HTTP_MAX_CONTENT_LENGTH.get(settings); this.tracer = telemetryProvider.getTracer(); + this.meterRegistry = telemetryProvider.getMeterRegistry(); this.httpLogger = new HttpTracer(settings, clusterSettings); clusterSettings.addSettingsUpdateConsumer( TransportSettings.SLOW_OPERATION_THRESHOLD_SETTING, @@ -238,6 +243,22 @@ private TransportAddress bindAddress(final InetAddress hostAddress) { @Override protected final void doStart() { + metricsToClose.add( + meterRegistry.registerLongAsyncCounter( + "es.http.connections.total", + "total number of inbound HTTP connections accepted", + "count", + () -> new LongWithAttributes(totalChannelsAccepted.get()) + ) + ); + metricsToClose.add( + meterRegistry.registerLongGauge( + "es.http.connections.current", + "number of inbound HTTP connections currently open", + "count", + () -> new LongWithAttributes(httpChannels.size()) + ) + ); startInternal(); } @@ -328,6 +349,16 
@@ protected final void doStop() { logger.warn("unexpected exception while waiting for http channels to close", e); } } + + for (final var metricToClose : metricsToClose) { + try { + metricToClose.close(); + } catch (Exception e) { + logger.warn("unexpected exception while closing metric [{}]", metricToClose); + assert false : e; + } + } + stopInternal(); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/TrackingPostingsInMemoryBytesCodec.java b/server/src/main/java/org/elasticsearch/index/codec/TrackingPostingsInMemoryBytesCodec.java new file mode 100644 index 0000000000000..92aebd83398ce --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/TrackingPostingsInMemoryBytesCodec.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FieldsConsumer; +import org.apache.lucene.codecs.FieldsProducer; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.codecs.PostingsFormat; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.Fields; +import org.apache.lucene.index.FilterLeafReader; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.internal.hppc.IntIntHashMap; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.FeatureFlag; + +import java.io.IOException; +import java.util.function.IntConsumer; + +/** + * A codec that tracks the length of the min and max written terms. Used to improve memory usage estimates in serverless, since + * {@link org.apache.lucene.codecs.lucene90.blocktree.FieldReader} keeps an in-memory reference to the min and max term. 
+ */ +public class TrackingPostingsInMemoryBytesCodec extends FilterCodec { + public static final FeatureFlag TRACK_POSTINGS_IN_MEMORY_BYTES = new FeatureFlag("track_postings_in_memory_bytes"); + + public static final String IN_MEMORY_POSTINGS_BYTES_KEY = "es.postings.in_memory_bytes"; + + public TrackingPostingsInMemoryBytesCodec(Codec delegate) { + super(delegate.getName(), delegate); + } + + @Override + public PostingsFormat postingsFormat() { + PostingsFormat format = super.postingsFormat(); + + return new PostingsFormat(format.getName()) { + @Override + public FieldsConsumer fieldsConsumer(SegmentWriteState state) throws IOException { + FieldsConsumer consumer = format.fieldsConsumer(state); + return new TrackingLengthFieldsConsumer(state, consumer); + } + + @Override + public FieldsProducer fieldsProducer(SegmentReadState state) throws IOException { + return format.fieldsProducer(state); + } + }; + } + + static final class TrackingLengthFieldsConsumer extends FieldsConsumer { + final SegmentWriteState state; + final FieldsConsumer in; + final IntIntHashMap termsBytesPerField; + + TrackingLengthFieldsConsumer(SegmentWriteState state, FieldsConsumer in) { + this.state = state; + this.in = in; + this.termsBytesPerField = new IntIntHashMap(state.fieldInfos.size()); + } + + @Override + public void write(Fields fields, NormsProducer norms) throws IOException { + in.write(new TrackingLengthFields(fields, termsBytesPerField, state.fieldInfos), norms); + long totalBytes = 0; + for (int bytes : termsBytesPerField.values) { + totalBytes += bytes; + } + state.segmentInfo.putAttribute(IN_MEMORY_POSTINGS_BYTES_KEY, Long.toString(totalBytes)); + } + + @Override + public void close() throws IOException { + in.close(); + } + } + + static final class TrackingLengthFields extends FilterLeafReader.FilterFields { + final IntIntHashMap termsBytesPerField; + final FieldInfos fieldInfos; + + TrackingLengthFields(Fields in, IntIntHashMap termsBytesPerField, FieldInfos fieldInfos) { + 
super(in); + this.termsBytesPerField = termsBytesPerField; + this.fieldInfos = fieldInfos; + } + + @Override + public Terms terms(String field) throws IOException { + Terms terms = super.terms(field); + if (terms == null) { + return null; + } + int fieldNum = fieldInfos.fieldInfo(field).number; + return new TrackingLengthTerms( + terms, + bytes -> termsBytesPerField.put(fieldNum, Math.max(termsBytesPerField.getOrDefault(fieldNum, 0), bytes)) + ); + } + } + + static final class TrackingLengthTerms extends FilterLeafReader.FilterTerms { + final IntConsumer onFinish; + + TrackingLengthTerms(Terms in, IntConsumer onFinish) { + super(in); + this.onFinish = onFinish; + } + + @Override + public TermsEnum iterator() throws IOException { + return new TrackingLengthTermsEnum(super.iterator(), onFinish); + } + } + + static final class TrackingLengthTermsEnum extends FilterLeafReader.FilterTermsEnum { + int maxTermLength = 0; + int minTermLength = 0; + int termCount = 0; + final IntConsumer onFinish; + + TrackingLengthTermsEnum(TermsEnum in, IntConsumer onFinish) { + super(in); + this.onFinish = onFinish; + } + + @Override + public BytesRef next() throws IOException { + final BytesRef term = super.next(); + if (term != null) { + if (termCount == 0) { + minTermLength = term.length; + } + maxTermLength = term.length; + termCount++; + } else { + if (termCount == 1) { + // If the minTerm and maxTerm are the same, only one instance is kept on the heap. 
+ assert minTermLength == maxTermLength; + onFinish.accept(maxTermLength); + } else { + onFinish.accept(maxTermLength + minTermLength); + } + } + return term; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java index b5e1276d0747a..e92ece41077a6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java @@ -9,10 +9,11 @@ package org.elasticsearch.index.codec.vectors; -record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) { +record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) { - CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) { - this(centroids.length, centroids, assignmentsByCluster); - assert centroids.length == assignmentsByCluster.length; + CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) { + this(centroids.length, centroids, assignments, overspillAssignments); + assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0 + : "assignments and overspillAssignments must have the same length"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index ac95f3c8ad0af..304cc57284227 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -67,10 +67,10 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, 
fieldInfo.getVectorDimension()); return new CentroidQueryScorer() { int currentCentroid = -1; - private final float[] centroid = new float[fieldInfo.getVectorDimension()]; + long postingListOffset; private final float[] centroidCorrectiveValues = new float[3]; - private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES); - private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension(); + private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + + Short.BYTES); @Override public int size() { @@ -78,13 +78,13 @@ public int size() { } @Override - public float[] centroid(int centroidOrdinal) throws IOException { + public long postingListOffset(int centroidOrdinal) throws IOException { if (centroidOrdinal != currentCentroid) { - centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal); - centroids.readFloats(centroid, 0, centroid.length); + centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal); + postingListOffset = centroids.readLong(); currentCentroid = centroidOrdinal; } - return centroid; + return postingListOffset; } public void bulkScore(NeighborQueue queue) throws IOException { @@ -181,7 +181,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { int vectors; boolean quantized = false; float centroidDp; - float[] centroid; + final float[] centroid; long slicePos; OptimizedScalarQuantizer.QuantizationResult queryCorrections; DocIdsWriter docIdsWriter = new DocIdsWriter(); @@ -205,7 +205,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { this.entry = entry; this.fieldInfo = fieldInfo; this.needsScoring = needsScoring; - + centroid = new float[fieldInfo.getVectorDimension()]; scratch = new float[target.length]; quantizationScratch = new int[target.length]; final int discretizedDimensions = 
discretize(fieldInfo.getVectorDimension(), 64); @@ -217,12 +217,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { } @Override - public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { + public int resetPostingsScorer(long offset) throws IOException { quantized = false; - indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); - vectors = indexInput.readVInt(); + indexInput.seek(offset); + indexInput.readFloats(centroid, 0, centroid.length); centroidDp = Float.intBitsToFloat(indexInput.readInt()); - this.centroid = centroid; + vectors = indexInput.readVInt(); // read the doc ids docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch; docIdsWriter.readInts(indexInput, vectors, docIdsScratch); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 84abb1bea543f..f47ecc549831a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -14,9 +14,14 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.LongValues; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.IntToIntFunction; +import org.apache.lucene.util.packed.PackedInts; +import org.apache.lucene.util.packed.PackedLongValues; import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans; import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.logging.LogManager; @@ -44,44 +49,200 @@ public 
DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec } @Override - long[] buildAndWritePostingsLists( + LongValues buildAndWritePostingsLists( FieldInfo fieldInfo, CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, IndexOutput postingsOutput, - int[][] assignmentsByCluster + int[] assignments, + int[] overspillAssignments ) throws IOException { + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != -1) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + assignmentsByCluster[c] = new int[centroidVectorCount[c]]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != -1) { + assignmentsByCluster[s][centroidVectorCount[s]++] = i; + } + } + } // write the posting lists - final long[] offsets = new long[centroidSupplier.size()]; - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); DocIdsWriter docIdsWriter = new DocIdsWriter(); - DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter( - ES91OSQVectorsScorer.BULK_SIZE, - quantizer, + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); + OnHeapQuantizedVectors onHeapQuantizedVectors = new 
OnHeapQuantizedVectors( floatVectorValues, - postingsOutput + fieldInfo.getVectorDimension(), + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()) ); + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (int c = 0; c < centroidSupplier.size(); c++) { float[] centroid = centroidSupplier.centroid(c); - // TODO: add back in sorting vectors by distance to centroid int[] cluster = assignmentsByCluster[c]; - // TODO align??? - offsets[c] = postingsOutput.getFilePointer(); + offsets.add(postingsOutput.alignFilePointer(Float.BYTES)); + buffer.asFloatBuffer().put(centroid); + // write raw centroid for quantizing the query vectors + postingsOutput.writeBytes(buffer.array(), buffer.array().length); + // write centroid dot product for quantizing the query vectors + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); int size = cluster.length; + // write docIds postingsOutput.writeVInt(size); - postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]); // TODO we might want to consider putting the docIds in a separate file // to aid with only having to fetch vectors from slower storage when they are required // keeping them in the same file indicates we pull the entire file into cache docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); - bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid); + // write vectors + bulkWriter.writeVectors(onHeapQuantizedVectors); } if (logger.isDebugEnabled()) { printClusterQualityStatistics(assignmentsByCluster); } - return offsets; + return offsets.build(); + } + + @Override + LongValues buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState, 
+ int[] assignments, + int[] overspillAssignments + ) throws IOException { + // first, quantize all the vectors into a temporary file + String quantizedVectorsTempName = null; + IndexOutput quantizedVectorsTemp = null; + boolean success = false; + try { + quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT); + quantizedVectorsTempName = quantizedVectorsTemp.getName(); + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + int[] quantized = new int[fieldInfo.getVectorDimension()]; + byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8]; + float[] overspillScratch = new float[fieldInfo.getVectorDimension()]; + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + float[] centroid = centroidSupplier.centroid(c); + float[] vector = floatVectorValues.vectorValue(i); + boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1; + // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector + // so, make a copy of the vector to avoid mutating it + if (overspill) { + System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension()); + } + + OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + if (overspill) { + int s = overspillAssignments[i]; + // write the overspill vector as well + result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s)); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + } else { + // write a zero vector for the overspill + Arrays.fill(binary, (byte) 0); + 
OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0); + writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult); + } + } + // close the temporary file so we can read it later + quantizedVectorsTemp.close(); + success = true; + } finally { + if (success == false && quantizedVectorsTemp != null) { + mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName()); + } + } + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != -1) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + assignmentsByCluster[c] = new int[centroidVectorCount[c]]; + isOverspillByCluster[c] = new boolean[centroidVectorCount[c]]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != -1) { + assignmentsByCluster[s][centroidVectorCount[s]] = i; + isOverspillByCluster[s][centroidVectorCount[s]++] = true; + } + } + } + // now we can read the quantized vectors from the temporary file + try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) { + final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); + OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors( + quantizedVectorsInput, + 
fieldInfo.getVectorDimension() + ); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int c = 0; c < centroidSupplier.size(); c++) { + float[] centroid = centroidSupplier.centroid(c); + int[] cluster = assignmentsByCluster[c]; + boolean[] isOverspill = isOverspillByCluster[c]; + offsets.add(postingsOutput.alignFilePointer(Float.BYTES)); + // write raw centroid for quantizing the query vectors + buffer.asFloatBuffer().put(centroid); + postingsOutput.writeBytes(buffer.array(), buffer.array().length); + // write centroid dot product for quantizing the query vectors + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // write docIds + int size = cluster.length; + postingsOutput.writeVInt(size); + offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); + // write vectors + bulkWriter.writeVectors(offHeapQuantizedVectors); + } + + if (logger.isDebugEnabled()) { + printClusterQualityStatistics(assignmentsByCluster); + } + return offsets.build(); + } } private static void printClusterQualityStatistics(int[][] clusters) { @@ -119,8 +280,15 @@ CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentro return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); } - static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput 
centroidOutput) - throws IOException { + @Override + void writeCentroids( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + LongValues offsets, + IndexOutput centroidOutput + ) throws IOException { + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); int[] quantizedScratch = new int[fieldInfo.getVectorDimension()]; float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; @@ -128,7 +296,8 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo // TODO do we want to store these distances as well for future use? // TODO: sort centroids by global centroid (was doing so previously here) // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned - for (float[] centroid : centroids) { + for (int i = 0; i < centroidSupplier.size(); i++) { + float[] centroid = centroidSupplier.centroid(i); System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( centroidScratch, @@ -136,54 +305,36 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo (byte) 4, globalCentroid ); - for (int i = 0; i < quantizedScratch.length; i++) { - quantized[i] = (byte) quantizedScratch[i]; + for (int j = 0; j < quantizedScratch.length; j++) { + quantized[j] = (byte) quantizedScratch[j]; } writeQuantizedValue(centroidOutput, quantized, result); } - final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (float[] centroid : centroids) { - buffer.asFloatBuffer().put(centroid); - centroidOutput.writeBytes(buffer.array(), buffer.array().length); + // write the centroid offsets at the end of the file + for (int i = 0; i < centroidSupplier.size(); i++) { + centroidOutput.writeLong(offsets.get(i)); } } - @Override - CentroidAssignments calculateAndWriteCentroids( - 
FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - MergeState mergeState, - float[] globalCentroid - ) throws IOException { - // TODO: take advantage of prior generated clusters from mergeState in the future - return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid); - } - /** - * Calculate the centroids for the given field and write them to the given centroid output. + * Calculate the centroids for the given field. * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments * * @param fieldInfo merging field info * @param floatVectorValues the float vector values to merge - * @param centroidOutput the centroid output * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed * @throws IOException if an I/O error occurs */ @Override - CentroidAssignments calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid - ) throws IOException { + CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) + throws IOException { long nanoTime = System.nanoTime(); // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids - KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); - float[][] centroids = kMeansResult.centroids(); + CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster); + float[][] centroids = centroidAssignments.centroids(); // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative // preliminary tests suggest recall is good using only 
centroids but need to do further evaluation // TODO: push this logic into vector util? @@ -196,47 +347,19 @@ CentroidAssignments calculateAndWriteCentroids( globalCentroid[j] /= centroids.length; } - // write centroids - writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); - if (logger.isDebugEnabled()) { logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0); logger.debug("final centroid count: {}", centroids.length); } - return buildCentroidAssignments(kMeansResult); + return centroidAssignments; } - static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) { + static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException { + KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); float[][] centroids = kMeansResult.centroids(); int[] assignments = kMeansResult.assignments(); int[] soarAssignments = kMeansResult.soarAssignments(); - int[] centroidVectorCount = new int[centroids.length]; - for (int i = 0; i < assignments.length; i++) { - centroidVectorCount[assignments[i]]++; - // if soar assignments are present, count them as well - if (soarAssignments.length > i && soarAssignments[i] != -1) { - centroidVectorCount[soarAssignments[i]]++; - } - } - - int[][] assignmentsByCluster = new int[centroids.length][]; - for (int c = 0; c < centroids.length; c++) { - assignmentsByCluster[c] = new int[centroidVectorCount[c]]; - } - Arrays.fill(centroidVectorCount, 0); - - for (int i = 0; i < assignments.length; i++) { - int c = assignments[i]; - assignmentsByCluster[c][centroidVectorCount[c]++] = i; - // if soar assignments are present, add them to the cluster as well - if (soarAssignments.length > i) { - int s = soarAssignments[i]; - if (s != -1) { - assignmentsByCluster[s][centroidVectorCount[s]++] = i; - } - } - } - return new 
CentroidAssignments(centroids, assignmentsByCluster); + return new CentroidAssignments(centroids, assignments, soarAssignments); } static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) @@ -254,7 +377,6 @@ static class OffHeapCentroidSupplier implements CentroidSupplier { private final int numCentroids; private final int dimension; private final float[] scratch; - private final long rawCentroidOffset; private int currOrd = -1; OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) { @@ -262,7 +384,6 @@ static class OffHeapCentroidSupplier implements CentroidSupplier { this.numCentroids = numCentroids; this.dimension = info.getVectorDimension(); this.scratch = new float[dimension]; - this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids; } @Override @@ -275,10 +396,143 @@ public float[] centroid(int centroidOrdinal) throws IOException { if (centroidOrdinal == currOrd) { return scratch; } - centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.seek((long) centroidOrdinal * dimension * Float.BYTES); centroidsInput.readFloats(scratch, 0, dimension); this.currOrd = centroidOrdinal; return scratch; } } + + interface QuantizedVectorValues { + int count(); + + byte[] next() throws IOException; + + OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException; + } + + interface IntToBooleanFunction { + boolean apply(int ord); + } + + static class OnHeapQuantizedVectors implements QuantizedVectorValues { + private final FloatVectorValues vectorValues; + private final OptimizedScalarQuantizer quantizer; + private final byte[] quantizedVector; + private final int[] quantizedVectorScratch; + private final float[] floatVectorScratch; + private OptimizedScalarQuantizer.QuantizationResult corrections; + private float[] currentCentroid; + private IntToIntFunction 
ordTransformer = null; + private int currOrd = -1; + private int count; + + OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) { + this.vectorValues = vectorValues; + this.quantizer = quantizer; + this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.floatVectorScratch = new float[dimension]; + this.quantizedVectorScratch = new int[dimension]; + this.corrections = null; + } + + private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) { + this.currentCentroid = centroid; + this.ordTransformer = ordTransformer; + this.currOrd = -1; + this.count = count; + } + + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count() - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count()); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + float[] vector = vectorValues.vectorValue(ord); + // Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice + // due to overspill, so we copy the vector to a scratch array + System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length); + corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 1, currentCentroid); + BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector); + return quantizedVector; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call next first"); + } + return corrections; + } + } + + static class OffHeapQuantizedVectors implements QuantizedVectorValues { + private final IndexInput quantizedVectorsInput; + private final byte[] binaryScratch; + private final float[] corrections = new float[3]; + + private final int vectorByteSize; + private short bitSum; 
+ private int currOrd = -1; + private int count; + private IntToBooleanFunction isOverspill = null; + private IntToIntFunction ordTransformer = null; + + OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) { + this.quantizedVectorsInput = quantizedVectorsInput; + this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES); + } + + private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) { + this.count = count; + this.isOverspill = isOverspill; + this.ordTransformer = ordTransformer; + this.currOrd = -1; + } + + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + boolean isOverspill = this.isOverspill.apply(currOrd); + return getVector(ord, isOverspill); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call readQuantizedVector first"); + } + return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum); + } + + byte[] getVector(int ord, boolean isOverspill) throws IOException { + readQuantizedVector(ord, isOverspill); + return binaryScratch; + } + + public void readQuantizedVector(int ord, boolean isOverspill) throws IOException { + long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? 
vectorByteSize : 0); + quantizedVectorsInput.seek(offset); + quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length); + quantizedVectorsInput.readFloats(corrections, 0, 3); + bitSum = quantizedVectorsInput.readShort(); + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java index 6974cd50d4abc..662878270ea09 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -9,34 +9,25 @@ package org.elasticsearch.index.codec.vectors; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.hnsw.IntToIntFunction; import java.io.IOException; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary; - /** * Base class for bulk writers that write vectors to disk using the BBQ encoding. * This class provides the structure for writing vectors in bulk, with specific * implementations for different bit sizes strategies. 
*/ -public abstract class DiskBBQBulkWriter { +abstract class DiskBBQBulkWriter { protected final int bulkSize; - protected final OptimizedScalarQuantizer quantizer; protected final IndexOutput out; - protected final FloatVectorValues fvv; - protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { + protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) { this.bulkSize = bulkSize; - this.quantizer = quantizer; this.out = out; - this.fvv = fvv; } - public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException; + abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException; private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException { for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { @@ -64,39 +55,31 @@ private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult out.writeShort((short) targetComponentSum); } - public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { - private final byte[] binarized; - private final int[] initQuantized; + static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; - public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { - super(bulkSize, quantizer, fvv, out); - this.binarized = new byte[discretize(fvv.dimension(), 64) / 8]; - this.initQuantized = new int[fvv.dimension()]; + OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + super(bulkSize, out); this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } @Override - public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException { - int limit = count - bulkSize + 1; + void 
writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException { + int limit = qvv.count() - bulkSize + 1; int i = 0; for (; i < limit; i += bulkSize) { for (int j = 0; j < bulkSize; j++) { - int ord = ords.apply(i + j); - float[] fv = fvv.vectorValue(ord); - corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); - packAsBinary(initQuantized, binarized); - out.writeBytes(binarized, binarized.length); + byte[] qv = qvv.next(); + corrections[j] = qvv.getCorrections(); + out.writeBytes(qv, qv.length); } writeCorrections(corrections, out); } // write tail - for (; i < count; ++i) { - int ord = ords.apply(i); - float[] fv = fvv.vectorValue(ord); - OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); - packAsBinary(initQuantized, binarized); - out.writeBytes(binarized, binarized.length); + for (; i < qvv.count(); ++i) { + byte[] qv = qvv.next(); + OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections(); + out.writeBytes(qv, qv.length); writeCorrection(correction, out); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index dbcdfd451df95..01cced04a9fcc 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -140,19 +140,6 @@ private void readFields(ChecksumIndexInput meta) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { final VectorEncoding vectorEncoding = readVectorEncoding(input); final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); - final long centroidOffset = input.readLong(); - final long centroidLength = input.readLong(); - final int numPostingLists = input.readVInt(); - final long[] 
postingListOffsets = new long[numPostingLists]; - for (int i = 0; i < numPostingLists; i++) { - postingListOffsets[i] = input.readLong(); - } - final float[] globalCentroid = new float[info.getVectorDimension()]; - float globalCentroidDp = 0; - if (numPostingLists > 0) { - input.readFloats(globalCentroid, 0, globalCentroid.length); - globalCentroidDp = Float.intBitsToFloat(input.readInt()); - } if (similarityFunction != info.getVectorSimilarityFunction()) { throw new IllegalStateException( "Inconsistent vector similarity function for field=\"" @@ -163,12 +150,21 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio + info.getVectorSimilarityFunction() ); } + final int numCentroids = input.readInt(); + final long centroidOffset = input.readLong(); + final long centroidLength = input.readLong(); + final float[] globalCentroid = new float[info.getVectorDimension()]; + float globalCentroidDp = 0; + if (centroidLength > 0) { + input.readFloats(globalCentroid, 0, globalCentroid.length); + globalCentroidDp = Float.intBitsToFloat(input.readInt()); + } return new FieldEntry( similarityFunction, vectorEncoding, + numCentroids, centroidOffset, centroidLength, - postingListOffsets, globalCentroid, globalCentroidDp ); @@ -242,7 +238,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector FieldEntry entry = fields.get(fieldInfo.number); CentroidQueryScorer centroidQueryScorer = getCentroidScorer( fieldInfo, - entry.postingListOffsets.length, + entry.numCentroids, entry.centroidSlice(ivfCentroids), target ); @@ -270,7 +266,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector int centroidOrdinal = centroidQueue.pop(); // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing // is enough? 
- expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal)); actualDocs += scorer.visit(knnCollector); } if (acceptDocs != null) { @@ -279,7 +275,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { int centroidOrdinal = centroidQueue.pop(); - scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal)); actualDocs += scorer.visit(knnCollector); } } @@ -313,9 +309,9 @@ public void close() throws IOException { protected record FieldEntry( VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, + int numCentroids, long centroidOffset, long centroidLength, - long[] postingListOffsets, float[] globalCentroid, float globalCentroidDp ) { @@ -330,7 +326,7 @@ abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postin interface CentroidQueryScorer { int size(); - float[] centroid(int centroidOrdinal) throws IOException; + long postingListOffset(int centroidOrdinal) throws IOException; void bulkScore(NeighborQueue queue) throws IOException; } @@ -339,7 +335,7 @@ interface PostingVisitor { // TODO maybe we can not specifically pass the centroid... 
/** returns the number of documents in the posting list */ - int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; + int resetPostingsScorer(long offset) throws IOException; /** returns the number of scored documents */ int visit(KnnCollector collector) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index e6da0ae1caff0..149db2eb96b83 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -28,6 +28,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.RandomAccessInput; +import org.apache.lucene.util.LongValues; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.SuppressForbidden; @@ -119,27 +120,34 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc return rawVectorDelegate; } - abstract CentroidAssignments calculateAndWriteCentroids( + abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) + throws IOException; + + abstract void writeCentroids( FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - MergeState mergeState, - float[] globalCentroid + CentroidSupplier centroidSupplier, + float[] globalCentroid, + LongValues centroidOffset, + IndexOutput centroidOutput ) throws IOException; - abstract CentroidAssignments calculateAndWriteCentroids( + abstract LongValues buildAndWritePostingsLists( FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid + IndexOutput postingsOutput, + int[] assignments, + int[] 
overspillAssignments ) throws IOException; - abstract long[] buildAndWritePostingsLists( + abstract LongValues buildAndWritePostingsLists( FieldInfo fieldInfo, CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, IndexOutput postingsOutput, - int[][] assignmentsByCluster + MergeState mergeState, + int[] assignments, + int[] overspillAssignments ) throws IOException; abstract CentroidSupplier createCentroidSupplier( @@ -153,31 +161,28 @@ abstract CentroidSupplier createCentroidSupplier( public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorDelegate.flush(maxDoc, sortMap); for (FieldWriter fieldWriter : fieldWriters) { - float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; + final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; // build a float vector values with random access final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); // build centroids - long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - - final CentroidAssignments centroidAssignments = calculateAndWriteCentroids( - fieldWriter.fieldInfo, - floatVectorValues, - ivfCentroids, - globalCentroid - ); - - CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids()); - - long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; - final long[] offsets = buildAndWritePostingsLists( + final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid); + // wrap centroids with a supplier + final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids()); + // write posting lists + final LongValues offsets = buildAndWritePostingsLists( fieldWriter.fieldInfo, centroidSupplier, floatVectorValues, ivfClusters, - centroidAssignments.assignmentsByCluster() + 
centroidAssignments.assignments(), + centroidAssignments.overspillAssignments() ); - // write posting lists - writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + // write centroids + final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids); + final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + // write meta file + writeMeta(fieldWriter.fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, globalCentroid); } } @@ -284,7 +289,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws final long centroidOffset; final long centroidLength; final int numCentroids; - final int[][] assignmentsByCluster; + final int[] assignments; + final int[] overspillAssignments; final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; String centroidTempName = null; IndexOutput centroidTemp = null; @@ -292,15 +298,20 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - CentroidAssignments centroidAssignments = calculateAndWriteCentroids( + CentroidAssignments centroidAssignments = calculateCentroids( fieldInfo, getFloatVectorValues(fieldInfo, docs, vectors, numVectors), - centroidTemp, - mergeState, calculatedGlobalCentroid ); + // write the centroids to a temporary file so we are not holding them on heap + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] centroid : centroidAssignments.centroids()) { + buffer.asFloatBuffer().put(centroid); + centroidTemp.writeBytes(buffer.array(), buffer.array().length); + } numCentroids = centroidAssignments.numCentroids(); - 
assignmentsByCluster = centroidAssignments.assignmentsByCluster(); + assignments = centroidAssignments.assignments(); + overspillAssignments = centroidAssignments.overspillAssignments(); success = true; } finally { if (success == false && centroidTempName != null) { @@ -311,36 +322,37 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws try { if (numCentroids == 0) { centroidOffset = ivfCentroids.getFilePointer(); - writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); + writeMeta(fieldInfo, 0, centroidOffset, 0, null); CodecUtil.writeFooter(centroidTemp); IOUtils.close(centroidTemp); return; } CodecUtil.writeFooter(centroidTemp); IOUtils.close(centroidTemp); - centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { - ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength()); - centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { CentroidSupplier centroidSupplier = createCentroidSupplier( centroidsInput, numCentroids, fieldInfo, calculatedGlobalCentroid ); - - // build a float vector values with random access - // build centroids - final long[] offsets = buildAndWritePostingsLists( + // write posting lists + final LongValues offsets = buildAndWritePostingsLists( fieldInfo, centroidSupplier, floatVectorValues, ivfClusters, - assignmentsByCluster + mergeState, + assignments, + overspillAssignments ); - assert offsets.length == centroidSupplier.size(); - writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); + // write centroids + centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids); + centroidLength = 
ivfCentroids.getFilePointer() - centroidOffset; + // write meta + writeMeta(fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, calculatedGlobalCentroid); } } finally { org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); @@ -423,18 +435,15 @@ private static int writeFloatVectorValues( return numVectors; } - private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid) + private void writeMeta(FieldInfo field, int numCentroids, long centroidOffset, long centroidLength, float[] globalCentroid) throws IOException { ivfMeta.writeInt(field.number); ivfMeta.writeInt(field.getVectorEncoding().ordinal()); ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + ivfMeta.writeInt(numCentroids); ivfMeta.writeLong(centroidOffset); ivfMeta.writeLong(centroidLength); - ivfMeta.writeVInt(offsets.length); - for (long offset : offsets) { - ivfMeta.writeLong(offset); - } - if (offsets.length > 0) { + if (centroidLength > 0) { final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); buffer.asFloatBuffer().put(globalCentroid); ivfMeta.writeBytes(buffer.array(), buffer.array().length); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index 265763714d2db..fc13a4b9faa1a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -57,10 +57,20 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO return new KMeansIntermediate(); } - // if we have a small number of vectors pick one and output that as the centroid + // if we have a small number of vectors calculate the 
centroid directly if (vectors.size() <= targetSize) { float[] centroid = new float[dimension]; - System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension); + // sum the vectors + for (int i = 0; i < vectors.size(); i++) { + float[] vector = vectors.vectorValue(i); + for (int j = 0; j < dimension; j++) { + centroid[j] += vector[j]; + } + } + // average the vectors + for (int j = 0; j < dimension; j++) { + centroid[j] /= vectors.size(); + } return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index 881de29c435db..1fdd1b52cfd51 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -62,6 +62,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.VersionType; import org.elasticsearch.index.codec.FieldInfosWithUsages; +import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec; import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils; import org.elasticsearch.index.mapper.DocumentParser; import org.elasticsearch.index.mapper.LuceneDocument; @@ -275,6 +276,7 @@ protected static ShardFieldStats shardFieldStats(List leaves) int numSegments = 0; int totalFields = 0; long usages = 0; + long totalPostingBytes = 0; for (LeafReaderContext leaf : leaves) { numSegments++; var fieldInfos = leaf.reader().getFieldInfos(); @@ -286,8 +288,19 @@ protected static ShardFieldStats shardFieldStats(List leaves) } else { usages = -1; } + if (TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled()) { + SegmentReader segmentReader = Lucene.tryUnwrapSegmentReader(leaf.reader()); + if (segmentReader != null) { + String postingBytes = segmentReader.getSegmentInfo().info.getAttribute( + 
TrackingPostingsInMemoryBytesCodec.IN_MEMORY_POSTINGS_BYTES_KEY + ); + if (postingBytes != null) { + totalPostingBytes += Long.parseLong(postingBytes); + } + } + } } - return new ShardFieldStats(numSegments, totalFields, usages); + return new ShardFieldStats(numSegments, totalFields, usages, totalPostingBytes); } /** diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index d5b45eb6de429..1ca795f69257f 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.engine; import org.apache.logging.log4j.Logger; +import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexCommit; @@ -79,6 +80,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.VersionType; import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy; +import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec; import org.elasticsearch.index.mapper.DocumentParser; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.LuceneDocument; @@ -2778,7 +2780,13 @@ private IndexWriterConfig getIndexWriterConfig() { iwc.setMaxFullFlushMergeWaitMillis(-1); iwc.setSimilarity(engineConfig.getSimilarity()); iwc.setRAMBufferSizeMB(engineConfig.getIndexingBufferSize().getMbFrac()); - iwc.setCodec(engineConfig.getCodec()); + + Codec codec = engineConfig.getCodec(); + if (TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled()) { + codec = new TrackingPostingsInMemoryBytesCodec(codec); + } + iwc.setCodec(codec); + boolean useCompoundFile = engineConfig.getUseCompoundFile(); iwc.setUseCompoundFile(useCompoundFile); if (useCompoundFile == 
false) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java index 22b198b10a7ad..f419d87d008fe 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java @@ -98,11 +98,11 @@ protected void writeExtent(BlockLoader.IntBuilder builder, Extent extent) { public BlockLoader.AllReader reader(LeafReaderContext context) throws IOException { return new BlockLoader.AllReader() { @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { var binaryDocValues = context.reader().getBinaryDocValues(fieldName); var reader = new GeometryDocValueReader(); - try (var builder = factory.ints(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (var builder = factory.ints(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(binaryDocValues, docs.get(i), reader, builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 363e956f1b211..f95e35a5d0845 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -124,10 +124,10 @@ private static class SingletonLongs extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, 
Docs docs, int offset) throws IOException { + try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -173,9 +173,9 @@ private static class Longs extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -259,10 +259,10 @@ private static class SingletonInts extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -308,9 +308,9 @@ private static class Ints extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count())) { - for (int i = 0; i 
< docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -408,10 +408,10 @@ private static class SingletonDoubles extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -461,9 +461,9 @@ private static class Doubles extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -544,10 +544,10 @@ private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { + public BlockLoader.Block 
read(BlockFactory factory, Docs docs, int offset) throws IOException { // Doubles from doc values ensures that the values are in order - try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count(), dimensions)) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count() - offset, dimensions)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < iterator.docID()) { throw new IllegalStateException("docs within same block must be in order"); @@ -645,19 +645,19 @@ private BlockLoader.Block readSingleDoc(BlockFactory factory, int docId) throws if (ordinals.advanceExact(docId)) { BytesRef v = ordinals.lookupOrd(ordinals.ordValue()); // the returned BytesRef can be reused - return factory.constantBytes(BytesRef.deepCopyOf(v)); + return factory.constantBytes(BytesRef.deepCopyOf(v), 1); } else { - return factory.constantNulls(); + return factory.constantNulls(1); } } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - if (docs.count() == 1) { - return readSingleDoc(factory, docs.get(0)); + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + if (docs.count() - offset == 1) { + return readSingleDoc(factory, docs.get(offset)); } - try (BlockLoader.SingletonOrdinalsBuilder builder = factory.singletonOrdinalsBuilder(ordinals, docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (var builder = factory.singletonOrdinalsBuilder(ordinals, docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < ordinals.docID()) { throw new IllegalStateException("docs within same block must be in order"); @@ -700,14 +700,30 @@ private static class Ordinals extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BytesRefBuilder builder = 
factory.bytesRefsFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + if (docs.count() - offset == 1) { + return readSingleDoc(factory, docs.get(offset)); + } + try (var builder = factory.sortedSetOrdinalsBuilder(ordinals, docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < ordinals.docID()) { throw new IllegalStateException("docs within same block must be in order"); } - read(doc, builder); + if (ordinals.advanceExact(doc) == false) { + builder.appendNull(); + continue; + } + int count = ordinals.docValueCount(); + if (count == 1) { + builder.appendOrd(Math.toIntExact(ordinals.nextOrd())); + } else { + builder.beginPositionEntry(); + for (int c = 0; c < count; c++) { + builder.appendOrd(Math.toIntExact(ordinals.nextOrd())); + } + builder.endPositionEntry(); + } } return builder.build(); } @@ -718,6 +734,26 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build read(docId, (BytesRefBuilder) builder); } + private BlockLoader.Block readSingleDoc(BlockFactory factory, int docId) throws IOException { + if (ordinals.advanceExact(docId) == false) { + return factory.constantNulls(1); + } + int count = ordinals.docValueCount(); + if (count == 1) { + BytesRef v = ordinals.lookupOrd(ordinals.nextOrd()); + return factory.constantBytes(BytesRef.deepCopyOf(v), 1); + } + try (var builder = factory.bytesRefsFromDocValues(count)) { + builder.beginPositionEntry(); + for (int c = 0; c < count; c++) { + BytesRef v = ordinals.lookupOrd(ordinals.nextOrd()); + builder.appendBytesRef(v); + } + builder.endPositionEntry(); + return builder.build(); + } + } + private void read(int docId, BytesRefBuilder builder) throws IOException { if (false == ordinals.advanceExact(docId)) { builder.appendNull(); @@ -780,9 +816,9 @@ private static class BytesRefsFromBinary extends BlockDocValuesReader { 
} @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -879,9 +915,9 @@ private static class DenseVectorFromBinary extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count(), dimensions)) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count() - offset, dimensions)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -963,10 +999,10 @@ private static class SingletonBooleans extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -1012,9 +1048,9 @@ 
private static class Booleans extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java index 640a629410451..a4a498e4048db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java @@ -43,7 +43,7 @@ interface ColumnAtATimeReader extends Reader { /** * Reads the values of all documents in {@code docs}. 
*/ - BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException; + BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException; } interface RowStrideReader extends Reader { @@ -149,8 +149,8 @@ public String toString() { */ class ConstantNullsReader implements AllReader { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - return factory.constantNulls(); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + return factory.constantNulls(docs.count() - offset); } @Override @@ -183,8 +183,8 @@ public Builder builder(BlockFactory factory, int expectedCount) { public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) { return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) { - return factory.constantBytes(value); + public Block read(BlockFactory factory, Docs docs, int offset) { + return factory.constantBytes(value, docs.count() - offset); } @Override @@ -261,8 +261,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - return reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + return reader.read(factory, docs, offset); } @Override @@ -408,20 +408,23 @@ interface BlockFactory { /** * Build a block that contains only {@code null}. */ - Block constantNulls(); + Block constantNulls(int count); /** * Build a block that contains {@code value} repeated * {@code size} times. */ - Block constantBytes(BytesRef value); + Block constantBytes(BytesRef value, int count); /** - * Build a reader for reading keyword ordinals. 
+ * Build a reader for reading {@link SortedDocValues} */ SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocValues ordinals, int count); - // TODO support non-singleton ords + /** + * Build a reader for reading {@link SortedSetDocValues} + */ + SortedSetOrdinalsBuilder sortedSetOrdinalsBuilder(SortedSetDocValues ordinals, int count); AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int count); } @@ -509,6 +512,13 @@ interface SingletonOrdinalsBuilder extends Builder { SingletonOrdinalsBuilder appendOrd(int value); } + interface SortedSetOrdinalsBuilder extends Builder { + /** + * Appends an ordinal to the builder. + */ + SortedSetOrdinalsBuilder appendOrd(int value); + } + interface AggregateMetricDoubleBuilder extends Builder { DoubleBuilder min(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java index a3b10ea901395..3a1a805a25b64 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't emit falses before trues so we conform to the doc values contract and can use booleansFromDocValues - try (BlockLoader.BooleanBuilder builder = factory.booleans(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BooleanBuilder builder = factory.booleans(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git 
a/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java index fb97b0f84c50f..0ec899e19a1cd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't sort the values sort, so we can't use factory.longsFromDocValues - try (BlockLoader.LongBuilder builder = factory.longs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.LongBuilder builder = factory.longs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java index d762acda9f7e4..f01cc65775e6e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't sort the values sort, so we can't use factory.doublesFromDocValues - try (BlockLoader.DoubleBuilder builder = factory.doubles(docs.count())) { - for (int i = 0; i < docs.count(); i++) 
{ + try (BlockLoader.DoubleBuilder builder = factory.doubles(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java index 48d78129b8781..b232a8e1fc45a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't use bytesRefsFromDocValues - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java index cf731cc5cbc65..594b27f029901 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java @@ -513,6 +513,7 @@ public static final class KeywordFieldType extends StringFieldType { private final IndexMode indexMode; private final IndexSortConfig indexSortConfig; private final boolean hasDocValuesSkipper; + private final String originalName; public KeywordFieldType( String name, @@ -541,6 +542,7 @@ public KeywordFieldType( 
this.indexMode = builder.indexMode; this.indexSortConfig = builder.indexSortConfig; this.hasDocValuesSkipper = DocValuesSkipIndexType.NONE.equals(fieldType.docValuesSkipIndexType()) == false; + this.originalName = isSyntheticSource ? name + "._original" : null; } public KeywordFieldType(String name, boolean isIndexed, boolean hasDocValues, Map meta) { @@ -555,6 +557,7 @@ public KeywordFieldType(String name, boolean isIndexed, boolean hasDocValues, Ma this.indexMode = IndexMode.STANDARD; this.indexSortConfig = null; this.hasDocValuesSkipper = false; + this.originalName = null; } public KeywordFieldType(String name) { @@ -580,6 +583,7 @@ public KeywordFieldType(String name, FieldType fieldType) { this.indexMode = IndexMode.STANDARD; this.indexSortConfig = null; this.hasDocValuesSkipper = DocValuesSkipIndexType.NONE.equals(fieldType.docValuesSkipIndexType()) == false; + this.originalName = null; } public KeywordFieldType(String name, NamedAnalyzer analyzer) { @@ -594,6 +598,7 @@ public KeywordFieldType(String name, NamedAnalyzer analyzer) { this.indexMode = IndexMode.STANDARD; this.indexSortConfig = null; this.hasDocValuesSkipper = false; + this.originalName = null; } @Override @@ -1057,6 +1062,15 @@ public Query automatonQuery( ) { return new AutomatonQueryWithDescription(new Term(name()), automatonSupplier.get(), description); } + + /** + * The name used to store "original" that have been ignored + * by {@link KeywordFieldType#ignoreAbove()} so that they can be rebuilt + * for synthetic source. + */ + public String originalName() { + return originalName; + } } private final boolean indexed; @@ -1109,7 +1123,7 @@ private KeywordFieldMapper( this.useDocValuesSkipper = useDocValuesSkipper; this.offsetsFieldName = offsetsFieldName; this.indexSourceKeepMode = indexSourceKeepMode; - this.originalName = isSyntheticSource ? 
fullPath() + "._original" : null; + this.originalName = mappedFieldType.originalName(); } @Override @@ -1169,7 +1183,7 @@ private boolean indexValue(DocumentParserContext context, XContentString value) // Save a copy of the field so synthetic source can load it var utfBytes = value.bytes(); var bytesRef = new BytesRef(utfBytes.bytes(), utfBytes.offset(), utfBytes.length()); - context.doc().add(new StoredField(originalName(), bytesRef)); + context.doc().add(new StoredField(originalName, bytesRef)); } return false; } @@ -1280,15 +1294,6 @@ boolean hasNormalizer() { return normalizerName != null; } - /** - * The name used to store "original" that have been ignored - * by {@link KeywordFieldType#ignoreAbove()} so that they can be rebuilt - * for synthetic source. - */ - private String originalName() { - return originalName; - } - @Override protected SyntheticSourceSupport syntheticSourceSupport() { if (hasNormalizer()) { @@ -1337,7 +1342,7 @@ protected BytesRef preserve(BytesRef value) { } if (fieldType().ignoreAbove != Integer.MAX_VALUE) { - layers.add(new CompositeSyntheticFieldLoader.StoredFieldLayer(originalName()) { + layers.add(new CompositeSyntheticFieldLoader.StoredFieldLayer(originalName) { @Override protected void writeValue(Object value, XContentBuilder b) throws IOException { BytesRef ref = (BytesRef) value; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java index cfc7045a55513..220bba3d3c079 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java @@ -51,10 +51,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, 
BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't use bytesRefsFromDocValues - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java index 0a1a8a86154ab..9c947a17de7b6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't use longsFromDocValues - try (BlockLoader.LongBuilder builder = factory.longs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.LongBuilder builder = factory.longs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 0c6a3dbd00e6e..7ba2dfb9a69f5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -44,6 +44,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature 
NPE_ON_DIMS_UPDATE_FIX = new NodeFeature("mapper.npe_on_dims_update_fix"); static final NodeFeature IVF_FORMAT_CLUSTER_FEATURE = new NodeFeature("mapper.ivf_format_cluster_feature"); static final NodeFeature IVF_NESTED_SUPPORT = new NodeFeature("mapper.ivf_nested_support"); + static final NodeFeature BBQ_DISK_SUPPORT = new NodeFeature("mapper.bbq_disk_support"); static final NodeFeature SEARCH_LOAD_PER_SHARD = new NodeFeature("mapper.search_load_per_shard"); static final NodeFeature PATTERNED_TEXT = new NodeFeature("mapper.patterned_text"); @@ -76,6 +77,7 @@ public Set getTestFeatures() { USE_DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ, IVF_FORMAT_CLUSTER_FEATURE, IVF_NESTED_SUPPORT, + BBQ_DISK_SUPPORT, SEARCH_LOAD_PER_SHARD, SPARSE_VECTOR_INDEX_OPTIONS_FEATURE, PATTERNED_TEXT diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4d1c4fc41526c..f0bf1db65ea72 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -35,6 +35,7 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -1674,7 +1675,7 @@ public boolean supportsDimension(int dims) { return dims >= BBQ_MIN_DIMS; } }, - BBQ_IVF("bbq_ivf", true) { + BBQ_DISK("bbq_disk", true) { @Override public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object clusterSizeNode = indexOptionsMap.remove("cluster_size"); @@ -2294,7 +2295,7 @@ static class BBQIVFIndexOptions extends 
QuantizedIndexOptions { final int defaultNProbe; BBQIVFIndexOptions(int clusterSize, int defaultNProbe, RescoreVector rescoreVector) { - super(VectorIndexType.BBQ_IVF, rescoreVector); + super(VectorIndexType.BBQ_DISK, rescoreVector); this.clusterSize = clusterSize; this.defaultNProbe = defaultNProbe; } @@ -2530,6 +2531,9 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } + if (dims == null) { + return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]"); + } KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy(); return switch (getElementType()) { case BYTE -> createKnnByteQuery( diff --git a/server/src/main/java/org/elasticsearch/index/query/WildcardQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/WildcardQueryBuilder.java index b79731f4ef3d1..a1eadb9ae7a5a 100644 --- a/server/src/main/java/org/elasticsearch/index/query/WildcardQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/WildcardQueryBuilder.java @@ -104,7 +104,7 @@ public WildcardQueryBuilder(StreamInput in) throws IOException { value = in.readString(); rewrite = in.readOptionalString(); caseInsensitive = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_FIXED_INDEX_LIKE)) { + if (expressionTransportSupported(in.getTransportVersion())) { forceStringMatch = in.readBoolean(); } else { forceStringMatch = false; @@ -117,11 +117,20 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(value); out.writeOptionalString(rewrite); out.writeBoolean(caseInsensitive); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_FIXED_INDEX_LIKE)) { + if (expressionTransportSupported(out.getTransportVersion())) { out.writeBoolean(forceStringMatch); } } + /** + * Returns true if the Transport version is compatible with ESQL_FIXED_INDEX_LIKE + */ + public static boolean 
expressionTransportSupported(TransportVersion version) { + return version.onOrAfter(TransportVersions.ESQL_FIXED_INDEX_LIKE) + || version.isPatchFrom(TransportVersions.ESQL_FIXED_INDEX_LIKE_8_19) + || version.isPatchFrom(TransportVersions.ESQL_FIXED_INDEX_LIKE_9_1); + } + @Override public String fieldName() { return fieldName; diff --git a/server/src/main/java/org/elasticsearch/index/shard/ShardFieldStats.java b/server/src/main/java/org/elasticsearch/index/shard/ShardFieldStats.java index 531df89116453..e7ad940b61319 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/ShardFieldStats.java +++ b/server/src/main/java/org/elasticsearch/index/shard/ShardFieldStats.java @@ -17,7 +17,8 @@ * @param totalFields the total number of fields across the segments * @param fieldUsages the number of usages for segment-level fields (e.g., doc_values, postings, norms, points) * -1 if unavailable + * @param postingsInMemoryBytes the total bytes in memory used for postings across all fields */ -public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages) { +public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages, long postingsInMemoryBytes) { } diff --git a/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java b/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java index b884b2c850cc5..0f9b724d965bf 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java +++ b/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java @@ -499,7 +499,7 @@ private void runUnlocked() { totalBytesUsed -= shardAndBytesUsed.bytesUsed; lastShardId = shardAndBytesUsed.shard.shardId(); if (doThrottle && throttled.contains(shardAndBytesUsed.shard) == false) { - logger.debug( + logger.info( "now throttling indexing for shard [{}]: segment writing can't keep up", shardAndBytesUsed.shard.shardId() ); diff --git 
a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 8fdc53e6b795f..528601f201fee 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -950,6 +950,7 @@ public synchronized void verifyIndexMetadata(IndexMetadata metadata, IndexMetada @Override public void createShard( + final ProjectId projectId, final ShardRouting shardRouting, final PeerRecoveryTargetService recoveryTargetService, final PeerRecoveryTargetService.RecoveryListener recoveryListener, @@ -968,26 +969,29 @@ public void createShard( RecoveryState recoveryState = indexService.createRecoveryState(shardRouting, targetNode, sourceNode); IndexShard indexShard = indexService.createShard(shardRouting, globalCheckpointSyncer, retentionLeaseSyncer); indexShard.addShardFailureCallback(onShardFailure); - indexShard.startRecovery( - recoveryState, - recoveryTargetService, - postRecoveryMerger.maybeMergeAfterRecovery(indexService.getMetadata(), shardRouting, recoveryListener), - repositoriesService, - (mapping, listener) -> { - assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS - : "mapping update consumer only required by local shards recovery"; - AcknowledgedRequest putMappingRequestAcknowledgedRequest = new PutMappingRequest() - // concrete index - no name clash, it uses uuid - .setConcreteIndex(shardRouting.index()) - .source(mapping.source().string(), XContentType.JSON); - client.execute( - TransportAutoPutMappingAction.TYPE, - putMappingRequestAcknowledgedRequest.ackTimeout(TimeValue.MAX_VALUE).masterNodeTimeout(TimeValue.MAX_VALUE), - new RefCountAwareThreadedActionListener<>(threadPool.generic(), listener.map(ignored -> null)) - ); - }, - this, - clusterStateVersion + projectResolver.executeOnProject( + projectId, + () -> indexShard.startRecovery( + recoveryState, + 
recoveryTargetService, + postRecoveryMerger.maybeMergeAfterRecovery(indexService.getMetadata(), shardRouting, recoveryListener), + repositoriesService, + (mapping, listener) -> { + assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS + : "mapping update consumer only required by local shards recovery"; + AcknowledgedRequest putMappingRequestAcknowledgedRequest = new PutMappingRequest() + // concrete index - no name clash, it uses uuid + .setConcreteIndex(shardRouting.index()) + .source(mapping.source().string(), XContentType.JSON); + client.execute( + TransportAutoPutMappingAction.TYPE, + putMappingRequestAcknowledgedRequest.ackTimeout(TimeValue.MAX_VALUE).masterNodeTimeout(TimeValue.MAX_VALUE), + new RefCountAwareThreadedActionListener<>(threadPool.generic(), listener.map(ignored -> null)) + ); + }, + this, + clusterStateVersion + ) ); } diff --git a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java index 83be37d553fef..95c462072ae5a 100644 --- a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.ClusterStateApplier; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -781,6 +782,7 @@ private void createShardWhenLockAvailable( try { logger.debug("{} creating shard with primary term [{}], iteration [{}]", shardRouting.shardId(), primaryTerm, iteration); indicesService.createShard( + originalState.metadata().projectFor(shardRouting.index()).id(), 
shardRouting, recoveryTargetService, new RecoveryListener(shardRouting, primaryTerm), @@ -1004,8 +1006,8 @@ private static DiscoveryNode findSourceNodeForReshardSplitRecovery( ShardRouting sourceShardRouting = routingTable.shardRoutingTable(sourceShardId).primaryShard(); if (sourceShardRouting.active() == false) { - assert false : sourceShardRouting.shortSummary(); - logger.trace("can't find reshard split source node because source shard {} is not active.", sourceShardRouting); + // Source shard is unassigned (likely due to failure), we will retry. + logger.trace("can't find reshard split source node because source shard {} is not active.", sourceShardRouting.shortSummary()); return null; } @@ -1330,6 +1332,7 @@ void removeIndex( /** * Creates a shard for the specified shard routing and starts recovery. * + * @param projectId the project for the shard * @param shardRouting the shard routing * @param recoveryTargetService recovery service for the target * @param recoveryListener a callback when recovery changes state (finishes or fails) @@ -1343,6 +1346,7 @@ void removeIndex( * @throws IOException if an I/O exception occurs when creating the shard */ void createShard( + ProjectId projectId, ShardRouting shardRouting, PeerRecoveryTargetService recoveryTargetService, PeerRecoveryTargetService.RecoveryListener recoveryListener, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java index 3274bf571d10a..a6857b82a747f 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.telemetry.InferenceStats; import 
org.elasticsearch.threadpool.ThreadPool; import java.util.List; @@ -23,7 +24,13 @@ public interface InferenceServiceExtension { List getInferenceServiceFactories(); - record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {} + record InferenceServiceFactoryContext( + Client client, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings, + InferenceStats inferenceStats + ) {} interface Factory { /** diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index db31aafc8c190..b6f724e69d40f 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -121,7 +121,7 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para * - Key: {@link #MODEL_FIELD}, Value: modelId * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} */ - public static Params withMaxCompletionTokensTokens(String modelId, Params params) { + public static Params withMaxCompletionTokens(String modelId, Params params) { return new DelegatingMapParams( Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)), params diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java similarity index 65% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java rename to server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java index 17c91b81233fb..e73b1ad9c5ff6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java 
+++ b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; @@ -14,17 +16,17 @@ import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; -import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import java.util.HashMap; import java.util.Map; import java.util.Objects; -public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) { +public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongHistogram deploymentDuration) { public InferenceStats { Objects.requireNonNull(requestCount); Objects.requireNonNull(inferenceDuration); + Objects.requireNonNull(deploymentDuration); } public static InferenceStats create(MeterRegistry meterRegistry) { @@ -38,6 +40,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) { "es.inference.requests.time", "Inference API request counts for a particular service, task type, model ID", "ms" + ), + meterRegistry.registerLongHistogram( + 
"es.inference.trained_model.deployment.time", + "Inference API time spent waiting for Trained Model Deployments", + "ms" ) ); } @@ -54,8 +61,8 @@ public static Map modelAttributes(Model model) { return modelAttributesMap; } - public static Map routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) { - return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest); + public static Map routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) { + return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest); } public static Map modelAttributes(UnparsedModel model) { @@ -73,4 +80,11 @@ public static Map responseAttributes(@Nullable Throwable throwab return Map.of("error.type", throwable.getClass().getSimpleName()); } + + public static Map modelAndResponseAttributes(Model model, @Nullable Throwable throwable) { + var metricAttributes = new HashMap(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(throwable)); + return metricAttributes; + } } diff --git a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java index bfa8f0a01c661..326002c7d346c 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java +++ b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java @@ -14,6 +14,7 @@ import org.elasticsearch.cluster.ClusterInfoService; import org.elasticsearch.cluster.EstimatedHeapUsageCollector; import org.elasticsearch.cluster.InternalClusterInfoService; +import org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -79,12 +80,17 @@ ClusterInfoService newClusterInfoService( EstimatedHeapUsageCollector.class, () -> EstimatedHeapUsageCollector.EMPTY 
); + final NodeUsageStatsForThreadPoolsCollector nodeUsageStatsForThreadPoolsCollector = pluginsService.loadSingletonServiceProvider( + NodeUsageStatsForThreadPoolsCollector.class, + () -> NodeUsageStatsForThreadPoolsCollector.EMPTY + ); final InternalClusterInfoService service = new InternalClusterInfoService( settings, clusterService, threadPool, client, - estimatedHeapUsageCollector + estimatedHeapUsageCollector, + nodeUsageStatsForThreadPoolsCollector ); if (DiscoveryNode.isMasterNode(settings)) { // listen for state changes (this node starts/stops being the elected master, or new nodes are added) diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java index ceeb1a4e27f1b..f3a25caf79bb6 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java @@ -171,9 +171,9 @@ public ClusterState execute(ClusterState currentState) { assert (projectId == null && taskExecutor.scope() == PersistentTasksExecutor.Scope.CLUSTER) || (projectId != null && taskExecutor.scope() == PersistentTasksExecutor.Scope.PROJECT) : "inconsistent project-id [" + projectId + "] and task scope [" + taskExecutor.scope() + "]"; - taskExecutor.validate(taskParams, currentState); + taskExecutor.validate(taskParams, currentState, projectId); - Assignment assignment = createAssignment(taskName, taskParams, currentState); + Assignment assignment = createAssignment(taskName, taskParams, currentState, projectId); logger.debug("creating {} persistent task [{}] with assignment [{}]", taskTypeString(projectId), taskName, assignment); return builder.addTask(taskId, taskName, taskParams, assignment).buildAndUpdate(currentState, projectId); } @@ -449,7 +449,8 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) private 
Assignment createAssignment( final String taskName, final Params taskParams, - final ClusterState currentState + final ClusterState currentState, + @Nullable final ProjectId projectId ) { PersistentTasksExecutor persistentTasksExecutor = registry.getPersistentTaskExecutorSafe(taskName); @@ -468,7 +469,7 @@ private Assignment createAssignment( // Task assignment should not rely on node order Randomness.shuffle(candidateNodes); - final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState); + final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState, projectId); assert assignment != null : "getAssignment() should always return an Assignment object, containing a node or a reason why not"; assert (assignment.getExecutorNode() == null || currentState.metadata().nodeShutdowns().contains(assignment.getExecutorNode()) == false) @@ -540,8 +541,8 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) * persistent tasks changed. 
*/ boolean shouldReassignPersistentTasks(final ClusterChangedEvent event) { - final List allTasks = PersistentTasks.getAllTasks(event.state()).map(Tuple::v2).toList(); - if (allTasks.isEmpty()) { + var projectIdToTasksIterator = PersistentTasks.getAllTasks(event.state()).iterator(); + if (projectIdToTasksIterator.hasNext() == false) { return false; } @@ -553,10 +554,16 @@ boolean shouldReassignPersistentTasks(final ClusterChangedEvent event) { || event.metadataChanged() || masterChanged) { - for (PersistentTasks tasks : allTasks) { - for (PersistentTask task : tasks.tasks()) { + while (projectIdToTasksIterator.hasNext()) { + var projectIdToTasks = projectIdToTasksIterator.next(); + for (PersistentTask task : projectIdToTasks.v2().tasks()) { if (needsReassignment(task.getAssignment(), event.state().nodes())) { - Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), event.state()); + Assignment assignment = createAssignment( + task.getTaskName(), + task.getParams(), + event.state(), + projectIdToTasks.v1() + ); if (Objects.equals(assignment, task.getAssignment()) == false) { return true; } @@ -602,7 +609,7 @@ private ClusterState reassignClusterOrSingleProjectTasks(@Nullable final Project // We need to check if removed nodes were running any of the tasks and reassign them for (PersistentTask task : tasks.tasks()) { if (needsReassignment(task.getAssignment(), nodes)) { - Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), clusterState); + Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), clusterState, projectId); if (Objects.equals(assignment, task.getAssignment()) == false) { logger.trace( "reassigning {} task {} from node {} to node {}", diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java index b58ef7523bf99..96c0767fe65f8 100644 --- 
a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java @@ -10,6 +10,7 @@ package org.elasticsearch.persistent; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; @@ -63,7 +64,31 @@ public Scope scope() { *

* The default implementation returns the least loaded data node from amongst the collection of candidate nodes */ - public Assignment getAssignment(Params params, Collection candidateNodes, ClusterState clusterState) { + public final Assignment getAssignment( + Params params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { + assert (scope() == Scope.PROJECT && projectId != null) || (scope() == Scope.CLUSTER && projectId == null) + : "inconsistent project-id [" + projectId + "] and task scope [" + scope() + "]"; + return doGetAssignment(params, candidateNodes, clusterState, projectId); + } + + /** + * Returns the node id where the params has to be executed, + *

+ * The default implementation returns the least loaded data node from amongst the collection of candidate nodes. + *

+ * If {@link #scope()} returns CLUSTER, then {@link ProjectId} will be null. + * If {@link #scope()} returns PROJECT, then {@link ProjectId} will not be null. + */ + protected Assignment doGetAssignment( + Params params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, candidateNodes, DiscoveryNode::canContainData); if (discoveryNode == null) { return NO_NODE_FOUND; @@ -105,7 +130,7 @@ protected DiscoveryNode selectLeastLoadedNode( *

* Throws an exception if the supplied params cannot be executed on the cluster in the current state. */ - public void validate(Params params, ClusterState clusterState) {} + public void validate(Params params, ClusterState clusterState, @Nullable ProjectId projectId) {} /** * Creates a AllocatedPersistentTask for communicating with task manager diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginsService.java b/server/src/main/java/org/elasticsearch/plugins/PluginsService.java index 78a8650a5e920..348cd7e3cd4b2 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginsService.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginsService.java @@ -13,6 +13,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.admin.cluster.node.info.PluginsAndModules; @@ -477,6 +478,7 @@ static void reloadLuceneSPI(ClassLoader loader) { // Codecs: PostingsFormat.reloadPostingsFormats(loader); DocValuesFormat.reloadDocValuesFormats(loader); + KnnVectorsFormat.reloadKnnVectorsFormat(loader); Codec.reloadCodecs(loader); } diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index 54ad500a9144d..1ae558537e254 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -479,6 +479,18 @@ public int paramAsInt(String key, int defaultValue) { } } + public Integer paramAsInteger(String key, Integer defaultValue) { + String sValue = param(key); + if (sValue == null) { + return defaultValue; + } + try { + return Integer.valueOf(sValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse int parameter 
[" + key + "] with value [" + sValue + "]", e); + } + } + public long paramAsLong(String key, long defaultValue) { String sValue = param(key); if (sValue == null) { diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterAllocationExplainAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterAllocationExplainAction.java index 660d5261de86f..3bbb515d10b0b 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterAllocationExplainAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterAllocationExplainAction.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.List; +import java.util.Set; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestRequest.Method.POST; @@ -47,23 +48,62 @@ public boolean allowSystemIndexAccessByDefault() { return true; } + /* + The Cluster Allocation Explain API supports both query parameters and parameters passed through the request body, but not both. 
+ The API also supports empty requests, which translates to "explain the first unassigned shard you find" + */ @Override public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { - final var req = new ClusterAllocationExplainRequest(RestUtils.getMasterNodeTimeout(request)); + final var clusterAllocationExplainRequest = new ClusterAllocationExplainRequest(RestUtils.getMasterNodeTimeout(request)); + + // A request body was passed if (request.hasContentOrSourceParam()) { try (XContentParser parser = request.contentOrSourceParamParser()) { - ClusterAllocationExplainRequest.parse(req, parser); + ClusterAllocationExplainRequest.parse(clusterAllocationExplainRequest, parser); } - } // else ok, an empty body means "explain the first unassigned shard you find" - req.includeYesDecisions(request.paramAsBoolean("include_yes_decisions", false)); - req.includeDiskInfo(request.paramAsBoolean("include_disk_info", false)); + } + // There is no request body. 
Check for optionally supplied query parameters + else { + clusterAllocationExplainRequest.setIndex( + request.param( + ClusterAllocationExplainRequest.INDEX_PARAMETER_NAME, + // Defaults to the existing value, which was instantiated as null + clusterAllocationExplainRequest.getIndex() + ) + ); + + clusterAllocationExplainRequest.setShard( + request.paramAsInteger(ClusterAllocationExplainRequest.SHARD_PARAMETER_NAME, clusterAllocationExplainRequest.getShard()) + ); + + clusterAllocationExplainRequest.setPrimary( + request.paramAsBoolean(ClusterAllocationExplainRequest.PRIMARY_PARAMETER_NAME, clusterAllocationExplainRequest.isPrimary()) + ); + + clusterAllocationExplainRequest.setCurrentNode( + request.param(ClusterAllocationExplainRequest.CURRENT_NODE_PARAMETER_NAME, clusterAllocationExplainRequest.getCurrentNode()) + ); + } + + clusterAllocationExplainRequest.includeYesDecisions( + request.paramAsBoolean(ClusterAllocationExplainRequest.INCLUDE_YES_DECISIONS_PARAMETER_NAME, false) + ); + clusterAllocationExplainRequest.includeDiskInfo( + request.paramAsBoolean(ClusterAllocationExplainRequest.INCLUDE_DISK_INFO_PARAMETER_NAME, false) + ); + return channel -> client.execute( TransportClusterAllocationExplainAction.TYPE, - req, + clusterAllocationExplainRequest, new RestRefCountedChunkedToXContentListener<>(channel) ); } + @Override + public Set supportedCapabilities() { + return Set.of("query_parameter_support"); + } + @Override public boolean canTripCircuitBreaker() { return false; diff --git a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java index 0c2f7c2aa625b..80ccd4c188538 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java @@ -32,6 +32,7 @@ public Set getFeatures() { public static final NodeFeature INT_SORT_FOR_INT_SHORT_BYTE_FIELDS = new 
NodeFeature("search.sort.int_sort_for_int_short_byte_fields"); static final NodeFeature MULTI_MATCH_CHECKS_POSITIONS = new NodeFeature("search.multi.match.checks.positions"); public static final NodeFeature BBQ_HNSW_DEFAULT_INDEXING = new NodeFeature("search.vectors.mappers.default_bbq_hnsw"); + public static final NodeFeature SEARCH_WITH_NO_DIMENSIONS_BUGFIX = new NodeFeature("search.vectors.no_dimensions_bugfix"); @Override public Set getTestFeatures() { @@ -41,7 +42,8 @@ public Set getTestFeatures() { RESCORER_MISSING_FIELD_BAD_REQUEST, INT_SORT_FOR_INT_SHORT_BYTE_FIELDS, MULTI_MATCH_CHECKS_POSITIONS, - BBQ_HNSW_DEFAULT_INDEXING + BBQ_HNSW_DEFAULT_INDEXING, + SEARCH_WITH_NO_DIMENSIONS_BUGFIX ); } } diff --git a/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java b/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java index 61d20fa72e262..3c5033355b826 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java @@ -599,7 +599,13 @@ private IndexGeoPointFieldData fieldData(SearchExecutionContext context) { throw new IllegalArgumentException("failed to find mapper for [" + fieldName + "] for geo distance based sort"); } } - return context.getForField(fieldType, MappedFieldType.FielddataOperation.SEARCH); + IndexFieldData indexFieldData = context.getForField(fieldType, MappedFieldType.FielddataOperation.SEARCH); + if (indexFieldData instanceof IndexGeoPointFieldData) { + return (IndexGeoPointFieldData) indexFieldData; + } + throw new IllegalArgumentException( + "unable to apply geo distance sort to field [" + fieldName + "] of type [" + fieldType.typeName() + "]" + ); } private Nested nested(SearchExecutionContext context) throws IOException { diff --git a/libs/core/src/main/java/org/elasticsearch/jdk/RuntimeVersionFeature.java b/server/src/main/java/org/elasticsearch/synonyms/SynonymFeatures.java 
similarity index 50% rename from libs/core/src/main/java/org/elasticsearch/jdk/RuntimeVersionFeature.java rename to server/src/main/java/org/elasticsearch/synonyms/SynonymFeatures.java index 682ca6ef19f3b..b42143ed899a3 100644 --- a/libs/core/src/main/java/org/elasticsearch/jdk/RuntimeVersionFeature.java +++ b/server/src/main/java/org/elasticsearch/synonyms/SynonymFeatures.java @@ -7,12 +7,18 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.jdk; +package org.elasticsearch.synonyms; -public class RuntimeVersionFeature { - private RuntimeVersionFeature() {} +import org.elasticsearch.features.FeatureSpecification; +import org.elasticsearch.features.NodeFeature; - public static boolean isSecurityManagerAvailable() { - return Runtime.version().feature() < 24; +import java.util.Set; + +public class SynonymFeatures implements FeatureSpecification { + private static final NodeFeature RETURN_EMPTY_SYNONYM_SETS = new NodeFeature("synonyms_set.get.return_empty_synonym_sets"); + + @Override + public Set getTestFeatures() { + return Set.of(RETURN_EMPTY_SYNONYM_SETS); } } diff --git a/server/src/main/java/org/elasticsearch/synonyms/SynonymsManagementAPIService.java b/server/src/main/java/org/elasticsearch/synonyms/SynonymsManagementAPIService.java index 70b020eb66ab5..06dce1724d392 100644 --- a/server/src/main/java/org/elasticsearch/synonyms/SynonymsManagementAPIService.java +++ b/server/src/main/java/org/elasticsearch/synonyms/SynonymsManagementAPIService.java @@ -50,6 +50,9 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator; import org.elasticsearch.search.aggregations.bucket.terms.Terms; 
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -94,6 +97,8 @@ public class SynonymsManagementAPIService { private static final int MAX_SYNONYMS_SETS = 10_000; private static final String SYNONYM_RULE_ID_FIELD = SynonymRule.ID_FIELD.getPreferredName(); private static final String SYNONYM_SETS_AGG_NAME = "synonym_sets_aggr"; + private static final String RULE_COUNT_AGG_NAME = "rule_count"; + private static final String RULE_COUNT_FILTER_KEY = "synonym_rules"; private static final int SYNONYMS_INDEX_MAPPINGS_VERSION = 1; public static final int INDEX_SEARCHABLE_TIMEOUT_SECONDS = 30; private final int maxSynonymsSets; @@ -185,15 +190,33 @@ private static XContentBuilder mappings() { } } + /** + * Returns all synonym sets with their rule counts, including empty synonym sets. + * @param from The index of the first synonym set to return + * @param size The number of synonym sets to return + * @param listener The listener to return the synonym sets to + */ public void getSynonymsSets(int from, int size, ActionListener> listener) { + BoolQueryBuilder synonymSetQuery = QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(OBJECT_TYPE_FIELD, SYNONYM_SET_OBJECT_TYPE)) + .should(QueryBuilders.termQuery(OBJECT_TYPE_FIELD, SYNONYM_RULE_OBJECT_TYPE)) + .minimumShouldMatch(1); + + // Aggregation query to count only synonym rules (excluding synonym set objects) + FiltersAggregationBuilder ruleCountAggregation = new FiltersAggregationBuilder( + RULE_COUNT_AGG_NAME, + new FiltersAggregator.KeyedFilter(RULE_COUNT_FILTER_KEY, QueryBuilders.termQuery(OBJECT_TYPE_FIELD, SYNONYM_RULE_OBJECT_TYPE)) + ); + client.prepareSearch(SYNONYMS_ALIAS_NAME) .setSize(0) // Retrieves aggregated synonym rules for each synonym set, excluding the synonym set object type - .setQuery(QueryBuilders.termQuery(OBJECT_TYPE_FIELD, SYNONYM_RULE_OBJECT_TYPE)) + .setQuery(synonymSetQuery) .addAggregation( new 
TermsAggregationBuilder(SYNONYM_SETS_AGG_NAME).field(SYNONYMS_SET_FIELD) .order(BucketOrder.key(true)) .size(maxSynonymsSets) + .subAggregation(ruleCountAggregation) ) .setPreference(Preference.LOCAL.type()) .execute(new ActionListener<>() { @@ -201,11 +224,11 @@ public void getSynonymsSets(int from, int size, ActionListener buckets = termsAggregation.getBuckets(); - SynonymSetSummary[] synonymSetSummaries = buckets.stream() - .skip(from) - .limit(size) - .map(bucket -> new SynonymSetSummary(bucket.getDocCount(), bucket.getKeyAsString())) - .toArray(SynonymSetSummary[]::new); + SynonymSetSummary[] synonymSetSummaries = buckets.stream().skip(from).limit(size).map(bucket -> { + Filters ruleCountFilters = bucket.getAggregations().get(RULE_COUNT_AGG_NAME); + Filters.Bucket ruleCountBucket = ruleCountFilters.getBucketByKey(RULE_COUNT_FILTER_KEY); + return new SynonymSetSummary(ruleCountBucket.getDocCount(), bucket.getKeyAsString()); + }).toArray(SynonymSetSummary[]::new); listener.onResponse(new PagedResult<>(buckets.size(), synonymSetSummaries)); } diff --git a/server/src/main/java/org/elasticsearch/tasks/Task.java b/server/src/main/java/org/elasticsearch/tasks/Task.java index f76344fc6ca85..23ac4de8b618a 100644 --- a/server/src/main/java/org/elasticsearch/tasks/Task.java +++ b/server/src/main/java/org/elasticsearch/tasks/Task.java @@ -44,6 +44,7 @@ public class Task implements Traceable { * TRACE_PARENT once parsed in RestController.tryAllHandler is not preserved * has to be declared as a header copied over from http request. * May also be used internally when APM is enabled. + * https://www.w3.org/TR/trace-context-1/#traceparent-header */ public static final String TRACE_PARENT_HTTP_HEADER = "traceparent"; @@ -53,6 +54,10 @@ public class Task implements Traceable { */ public static final String TRACE_ID = "trace.id"; + /** + * Optional request header carrying vendor-specific trace information. 
+ * https://www.w3.org/TR/trace-context-1/#tracestate-header + */ public static final String TRACE_STATE = "tracestate"; public static final String TRACE_START_TIME = "trace.starttime"; diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 1b9b682666158..2ed347c226870 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -170,24 +170,33 @@ public Task register(String type, String action, TaskAwareRequest request, boole Task previousTask = tasks.put(task.getId(), task); assert previousTask == null; if (traceRequest) { - startTrace(threadContext, task); + maybeStartTrace(threadContext, task); } } return task; } - // package private for testing - void startTrace(ThreadContext threadContext, Task task) { + /** + * Start a new trace span if a parent trace context already exists. + * For REST actions this will be the case, otherwise {@link Tracer#startTrace} can be used. + */ + void maybeStartTrace(ThreadContext threadContext, Task task) { + if (threadContext.hasParentTraceContext() == false) { + return; + } TaskId parentTask = task.getParentTaskId(); - Map attributes = Map.of( - Tracer.AttributeKeys.TASK_ID, - task.getId(), - Tracer.AttributeKeys.PARENT_TASK_ID, - parentTask.toString() - ); + Map attributes = parentTask.isSet() + ? 
Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString()) + : Map.of(Tracer.AttributeKeys.TASK_ID, task.getId()); tracer.startTrace(threadContext, task, task.getAction(), attributes); } + void maybeStopTrace(ThreadContext threadContext, Task task) { + if (threadContext.hasTraceContext()) { + tracer.stopTrace(task); + } + } + public Task registerAndExecute( String type, TransportAction action, @@ -250,7 +259,7 @@ private void registerCancellableTask(Task task, long requestId, boolean traceReq CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); cancellableTasks.put(task, requestId, holder); if (traceRequest) { - startTrace(threadPool.getThreadContext(), task); + maybeStartTrace(threadPool.getThreadContext(), task); } // Check if this task was banned before we start it. if (task.getParentTaskId().isSet()) { @@ -349,7 +358,7 @@ public Task unregister(Task task) { return removedTask; } } finally { - tracer.stopTrace(task); + maybeStopTrace(threadPool.getThreadContext(), task); for (RemovedTaskListener listener : removedTaskListeners) { listener.onRemoved(task); } diff --git a/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java b/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java index 7b69e3a164d5b..4ebcae1cc2ac0 100644 --- a/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java +++ b/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java @@ -56,7 +56,11 @@ public Map getBuilders(Settings settings, int allocated allocatedProcessors, // 10,000 for all nodes with 8 cores or fewer. Scale up once we have more than 8 cores. 
Math.max(allocatedProcessors * 750, 10000), - new EsExecutors.TaskTrackingConfig(true, indexAutoscalingEWMA) + EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackMaxQueueLatency() + .trackExecutionTime(indexAutoscalingEWMA) + .build() ) ); int searchOrGetThreadPoolSize = ThreadPool.searchOrGetThreadPoolSize(allocatedProcessors); @@ -81,7 +85,7 @@ public Map getBuilders(Settings settings, int allocated ThreadPool.Names.SEARCH, searchOrGetThreadPoolSize, searchOrGetThreadPoolSize * 1000, - new EsExecutors.TaskTrackingConfig(true, searchAutoscalingEWMA) + EsExecutors.TaskTrackingConfig.builder().trackOngoingTasks().trackExecutionTime(searchAutoscalingEWMA).build() ) ); result.put( @@ -91,7 +95,7 @@ public Map getBuilders(Settings settings, int allocated ThreadPool.Names.SEARCH_COORDINATION, halfProc, 1000, - new EsExecutors.TaskTrackingConfig(true, searchAutoscalingEWMA) + EsExecutors.TaskTrackingConfig.builder().trackOngoingTasks().trackExecutionTime(searchAutoscalingEWMA).build() ) ); result.put( @@ -195,7 +199,7 @@ public Map getBuilders(Settings settings, int allocated ThreadPool.Names.SYSTEM_WRITE, halfProcMaxAt5, 1000, - new EsExecutors.TaskTrackingConfig(true, indexAutoscalingEWMA), + EsExecutors.TaskTrackingConfig.builder().trackOngoingTasks().trackExecutionTime(indexAutoscalingEWMA).build(), true ) ); @@ -228,7 +232,7 @@ public Map getBuilders(Settings settings, int allocated ThreadPool.Names.SYSTEM_CRITICAL_WRITE, halfProcMaxAt5, 1500, - new EsExecutors.TaskTrackingConfig(true, indexAutoscalingEWMA), + EsExecutors.TaskTrackingConfig.builder().trackOngoingTasks().trackExecutionTime(indexAutoscalingEWMA).build(), true ) ); diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index cab15fffa3fd0..b0408ac3c60cc 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ 
b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -433,6 +433,16 @@ private void maybeLogSlowMessage(boolean success) { } }); } catch (RuntimeException ex) { + logger.error( + Strings.format( + "unexpected exception calling sendMessage for transport message [%s] of size [%d] on [%s]", + messageDescription.get(), + messageSize, + channel + ), + ex + ); + assert Thread.currentThread().getName().startsWith("TEST-") : ex; channel.setCloseException(ex); Releasables.closeExpectNoException(() -> listener.onFailure(ex), () -> CloseableChannel.closeChannel(channel)); throw ex; diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java index fdb597b47c137..d1816c7fc1687 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java @@ -247,16 +247,16 @@ public Map groupIndices(Set remoteClusterNames, } public Map groupIndices(IndicesOptions indicesOptions, String[] indices, boolean returnLocalAll) { - return groupIndices(getRemoteClusterNames(), indicesOptions, indices, returnLocalAll); + return groupIndices(getRegisteredRemoteClusterNames(), indicesOptions, indices, returnLocalAll); } public Map groupIndices(IndicesOptions indicesOptions, String[] indices) { - return groupIndices(getRemoteClusterNames(), indicesOptions, indices, true); + return groupIndices(getRegisteredRemoteClusterNames(), indicesOptions, indices, true); } @Override public Set getConfiguredClusters() { - return getRemoteClusterNames(); + return getRegisteredRemoteClusterNames(); } /** @@ -270,7 +270,6 @@ boolean isRemoteClusterRegistered(String clusterName) { * Returns the registered remote cluster names. 
*/ public Set getRegisteredRemoteClusterNames() { - // remoteClusters is unmodifiable so its key set will be unmodifiable too return remoteClusters.keySet(); } @@ -355,10 +354,6 @@ public RemoteClusterConnection getRemoteClusterConnection(String cluster) { return connection; } - Set getRemoteClusterNames() { - return this.remoteClusters.keySet(); - } - @Override public void listenForUpdates(ClusterSettings clusterSettings) { super.listenForUpdates(clusterSettings); @@ -648,7 +643,7 @@ public RemoteClusterClient getRemoteClusterClient( "this node does not have the " + DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE.roleName() + " role" ); } - if (transportService.getRemoteClusterService().getRemoteClusterNames().contains(clusterAlias) == false) { + if (transportService.getRemoteClusterService().getRegisteredRemoteClusterNames().contains(clusterAlias) == false) { throw new NoSuchRemoteClusterException(clusterAlias); } return new RemoteClusterAwareClient(transportService, clusterAlias, responseExecutor, switch (disconnectedStrategy) { diff --git a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification index 5a258a4cd774f..677a5a96891b5 100644 --- a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification +++ b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification @@ -14,6 +14,7 @@ org.elasticsearch.rest.action.admin.cluster.GetSnapshotsFeatures org.elasticsearch.index.IndexFeatures org.elasticsearch.index.mapper.MapperFeatures org.elasticsearch.search.SearchFeatures +org.elasticsearch.synonyms.SynonymFeatures org.elasticsearch.search.retriever.RetrieversFeatures org.elasticsearch.script.ScriptFeatures org.elasticsearch.cluster.routing.RoutingFeatures diff --git a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.txt 
b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.txt index ae99ef33a3716..50732d9289be5 100644 --- a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.txt +++ b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.txt @@ -45,4 +45,5 @@ CIRCUIT_BREAKER_ERRORS troubleshoot/ela ALLOCATION_EXPLAIN_NO_COPIES troubleshoot/elasticsearch/diagnose-unassigned-shards#no-shard-copy ALLOCATION_EXPLAIN_MAX_RETRY troubleshoot/elasticsearch/diagnose-unassigned-shards#maximum-retries-exceeded SECURE_SETTINGS deploy-manage/security/secure-settings -CLUSTER_SHARD_LIMIT reference/elasticsearch/configuration-reference/miscellaneous-cluster-settings#cluster-shard-limit \ No newline at end of file +CLUSTER_SHARD_LIMIT reference/elasticsearch/configuration-reference/miscellaneous-cluster-settings#cluster-shard-limit +DEPLOY_CLOUD_DIFF_FROM_STATEFUL deploy-manage/deploy/elastic-cloud/differences-from-other-elasticsearch-offerings diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequestTests.java index 9301232f86f85..7ea94632bedd8 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequestTests.java @@ -37,4 +37,66 @@ public void testSerialization() throws Exception { assertEquals(request.getCurrentNode(), actual.getCurrentNode()); } + public void testToStringWithEmptyBody() { + ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest(randomTimeValue()); + clusterAllocationExplainRequest.includeYesDecisions(true); + clusterAllocationExplainRequest.includeDiskInfo(false); + + String expected = 
"ClusterAllocationExplainRequest[useAnyUnassignedShard=true," + + "include_yes_decisions?=true,include_disk_info?=false"; + assertEquals(expected, clusterAllocationExplainRequest.toString()); + } + + public void testToStringWithValidBodyButCurrentNodeIsNull() { + String index = "test-index"; + int shard = randomInt(); + boolean primary = randomBoolean(); + ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest( + randomTimeValue(), + index, + shard, + primary, + null + ); + clusterAllocationExplainRequest.includeYesDecisions(false); + clusterAllocationExplainRequest.includeDiskInfo(true); + + String expected = "ClusterAllocationExplainRequest[index=" + + index + + ",shard=" + + shard + + ",primary?=" + + primary + + ",include_yes_decisions?=false" + + ",include_disk_info?=true"; + assertEquals(expected, clusterAllocationExplainRequest.toString()); + } + + public void testToStringWithAllBodyParameters() { + String index = "test-index"; + int shard = randomInt(); + boolean primary = randomBoolean(); + String currentNode = "current_node"; + ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest( + randomTimeValue(), + index, + shard, + primary, + currentNode + ); + clusterAllocationExplainRequest.includeYesDecisions(false); + clusterAllocationExplainRequest.includeDiskInfo(true); + + String expected = "ClusterAllocationExplainRequest[index=" + + index + + ",shard=" + + shard + + ",primary?=" + + primary + + ",current_node=" + + currentNode + + ",include_yes_decisions?=false" + + ",include_disk_info?=true"; + assertEquals(expected, clusterAllocationExplainRequest.toString()); + } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java index c422efed0f254..eab44f1c56b16 100644 --- 
a/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportSimulateBulkActionTests.java @@ -19,7 +19,8 @@ import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; -import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.project.TestProjectResolvers; @@ -229,7 +230,7 @@ public void testIndexDataWithValidation() throws IOException { Map indicesMap = new HashMap<>(); Map v1Templates = new HashMap<>(); Map v2Templates = new HashMap<>(); - Metadata.Builder metadataBuilder = new Metadata.Builder(); + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(ProjectId.DEFAULT); Set indicesWithInvalidMappings = new HashSet<>(); for (int i = 0; i < bulkItemCount; i++) { Map source = Map.of(randomAlphaOfLength(10), randomAlphaOfLength(5)); @@ -275,10 +276,10 @@ public void testIndexDataWithValidation() throws IOException { default -> throw new AssertionError("Illegal branch"); } } - metadataBuilder.indices(indicesMap); - metadataBuilder.templates(v1Templates); - metadataBuilder.indexTemplates(v2Templates); - ClusterServiceUtils.setState(clusterService, new ClusterState.Builder(clusterService.state()).metadata(metadataBuilder)); + projectBuilder.indices(indicesMap); + projectBuilder.templates(v1Templates); + projectBuilder.indexTemplates(v2Templates); + ClusterServiceUtils.setState(clusterService, ClusterState.builder(clusterService.state()).putProjectMetadata(projectBuilder)); AtomicBoolean onResponseCalled = new AtomicBoolean(false); ActionListener listener = new ActionListener<>() { 
@Override diff --git a/server/src/test/java/org/elasticsearch/action/search/OpenPointInTimeResponseTests.java b/server/src/test/java/org/elasticsearch/action/search/OpenPointInTimeResponseTests.java new file mode 100644 index 0000000000000..3ed88419ab8d5 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/OpenPointInTimeResponseTests.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.action.search; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.Base64; +import java.util.Locale; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; + +public class OpenPointInTimeResponseTests extends ESTestCase { + + public void testIdCantBeNull() { + BytesReference pointInTimeId = null; + expectThrows(NullPointerException.class, () -> { new OpenPointInTimeResponse(pointInTimeId, 11, 8, 2, 1); }); + } + + public void testToXContent() throws IOException { + String id = "test-id"; + BytesReference pointInTimeId = new BytesArray(id); + + BytesReference actual; + try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) { + OpenPointInTimeResponse response = new OpenPointInTimeResponse(pointInTimeId, 11, 8, 2, 1); + 
response.toXContent(builder, ToXContent.EMPTY_PARAMS); + actual = BytesReference.bytes(builder); + } + + String encodedId = Base64.getUrlEncoder().encodeToString(BytesReference.toBytes(pointInTimeId)); + BytesReference expected = new BytesArray(String.format(Locale.ROOT, """ + { + "id": "%s", + "_shards": { + "total": 11, + "successful": 8, + "failed": 2, + "skipped": 1 + } + } + """, encodedId)); + assertToXContentEquivalent(expected, actual, XContentType.JSON); + } +} diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java index 25932c9e8b9f3..e0e749aaa2360 100644 --- a/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.Map; @@ -42,10 +43,21 @@ public static ClusterInfo randomClusterInfo() { randomDataSetSizes(), randomRoutingToDataPath(), randomReservedSpace(), - randomNodeHeapUsage() + randomNodeHeapUsage(), + randomNodeUsageStatsForThreadPools(), + randomShardWriteLoad() ); } + private static Map randomShardWriteLoad() { + final int numEntries = randomIntBetween(0, 128); + final Map builder = new HashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + builder.put(randomShardId(), randomDouble()); + } + return builder; + } + private static Map randomNodeHeapUsage() { int numEntries = randomIntBetween(0, 128); Map nodeHeapUsage = new HashMap<>(numEntries); @@ -62,6 +74,23 @@ private static Map randomNodeHeapUsage() { return nodeHeapUsage; } + private static Map randomNodeUsageStatsForThreadPools() { + int numEntries = randomIntBetween(0, 128); + Map nodeUsageStatsForThreadPools = new 
HashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + String nodeIdKey = randomAlphaOfLength(32); + NodeUsageStatsForThreadPools.ThreadPoolUsageStats writeThreadPoolUsageStats = + new NodeUsageStatsForThreadPools.ThreadPoolUsageStats(/* totalThreadPoolThreads= */ randomIntBetween(1, 16), + /* averageThreadPoolUtilization= */ randomFloat(), + /* averageThreadPoolQueueLatencyMillis= */ randomLongBetween(0, 50000) + ); + Map usageStatsForThreadPools = new HashMap<>(); + usageStatsForThreadPools.put(ThreadPool.Names.WRITE, writeThreadPoolUsageStats); + nodeUsageStatsForThreadPools.put(ThreadPool.Names.WRITE, new NodeUsageStatsForThreadPools(nodeIdKey, usageStatsForThreadPools)); + } + return nodeUsageStatsForThreadPools; + } + private static Map randomDiskUsage() { int numEntries = randomIntBetween(0, 128); Map builder = new HashMap<>(numEntries); diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterStateSerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterStateSerializationTests.java deleted file mode 100644 index 59baef0ebe05a..0000000000000 --- a/server/src/test/java/org/elasticsearch/cluster/ClusterStateSerializationTests.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.cluster; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.cluster.metadata.DataStreamTestHelper; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.metadata.ProjectMetadata; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.node.DiscoveryNodes; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.MapMatcher; -import org.elasticsearch.test.TransportVersionUtils; -import org.elasticsearch.test.XContentTestUtils; -import org.hamcrest.Matchers; - -import java.io.IOException; -import java.util.List; -import java.util.Set; - -public class ClusterStateSerializationTests extends ESTestCase { - - public void testSerializationInCurrentVersion() throws IOException { - assertSerializationRoundTrip(TransportVersion.current()); - } - - public void testSerializationPreMultiProject() throws IOException { - assertSerializationRoundTrip(TransportVersionUtils.getPreviousVersion(TransportVersions.MULTI_PROJECT)); - } - - private void assertSerializationRoundTrip(TransportVersion transportVersion) throws IOException { - ClusterState original = randomClusterState(transportVersion); - DiscoveryNode node = original.nodes().getLocalNode(); - assertThat(node, Matchers.notNullValue()); - - final ClusterState deserialized = ESTestCase.copyWriteable( - original, - new NamedWriteableRegistry(ClusterModule.getNamedWriteables()), - in -> ClusterState.readFrom(in, node), - transportVersion - ); - assertEquivalent("For transport version: " + transportVersion, original, deserialized); - } - - private void assertEquivalent(String context, ClusterState expected, ClusterState actual) throws IOException { - if (expected == actual) { - return; - } - 
// The simplest model we have for comparing equivalence is by comparing the XContent of the cluster state - var expectedJson = XContentTestUtils.convertToMap(expected); - var actualJson = XContentTestUtils.convertToMap(actual); - assertThat(context, actualJson, MapMatcher.matchesMap(expectedJson)); - } - - private ClusterState randomClusterState(TransportVersion transportVersion) { - final Set datastreamNames = randomSet(0, 10, () -> randomAlphaOfLengthBetween(4, 18)); - final List> datastreams = datastreamNames.stream() - .map(name -> new Tuple<>(name, randomIntBetween(1, 5))) - .toList(); - final List indices = List.copyOf( - randomSet(0, 10, () -> randomValueOtherThanMany(datastreamNames::contains, () -> randomAlphaOfLengthBetween(3, 12))) - ); - - final DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(); - do { - final String id = randomUUID(); - nodes.add(DiscoveryNodeUtils.create(id)); - nodes.localNodeId(id); - } while (randomBoolean()); - - ProjectMetadata project = DataStreamTestHelper.getProjectWithDataStreams(datastreams, indices); - return ClusterState.builder(ClusterName.DEFAULT).metadata(Metadata.builder().put(project)).nodes(nodes).build(); - } -} diff --git a/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java b/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java index 80c2395ae9644..3551971f0daa0 100644 --- a/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingHelper; import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.IndexingStats; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.store.StoreStats; @@ -107,6 +108,7 @@ public void testFillShardLevelInfo() { Path test0Path = 
createTempDir().resolve("indices").resolve(index.getUUID()).resolve("0"); CommonStats commonStats0 = new CommonStats(); commonStats0.store = new StoreStats(100, 101, 0L); + commonStats0.indexing = randomIndexingStats(); ShardRouting test_1 = ShardRouting.newUnassigned( new ShardId(index, 1), false, @@ -119,8 +121,10 @@ public void testFillShardLevelInfo() { Path test1Path = createTempDir().resolve("indices").resolve(index.getUUID()).resolve("1"); CommonStats commonStats1 = new CommonStats(); commonStats1.store = new StoreStats(1000, 1001, 0L); + commonStats1.indexing = randomIndexingStats(); CommonStats commonStats2 = new CommonStats(); commonStats2.store = new StoreStats(1000, 999, 0L); + commonStats2.indexing = randomIndexingStats(); ShardStats[] stats = new ShardStats[] { new ShardStats(test_0, new ShardPath(false, test0Path, test0Path, test_0.shardId()), commonStats0, null, null, null, false, 0), new ShardStats(test_1, new ShardPath(false, test1Path, test1Path, test_1.shardId()), commonStats1, null, null, null, false, 0), @@ -135,9 +139,17 @@ public void testFillShardLevelInfo() { 0 ) }; Map shardSizes = new HashMap<>(); + Map shardWriteLoads = new HashMap<>(); Map shardDataSetSizes = new HashMap<>(); Map routingToPath = new HashMap<>(); - InternalClusterInfoService.buildShardLevelInfo(stats, shardSizes, shardDataSetSizes, routingToPath, new HashMap<>()); + InternalClusterInfoService.buildShardLevelInfo( + stats, + shardWriteLoads, + shardSizes, + shardDataSetSizes, + routingToPath, + new HashMap<>() + ); assertThat( shardSizes, @@ -158,6 +170,41 @@ public void testFillShardLevelInfo() { hasEntry(ClusterInfo.NodeAndShard.from(test_1), test1Path.getParent().getParent().getParent().toAbsolutePath().toString()) ) ); + + assertThat( + shardWriteLoads, + equalTo( + Map.of( + test_0.shardId(), + commonStats0.indexing.getTotal().getPeakWriteLoad(), + test_1.shardId(), + Math.max(commonStats1.indexing.getTotal().getPeakWriteLoad(), 
commonStats2.indexing.getTotal().getPeakWriteLoad()) + ) + ) + ); + } + + private IndexingStats randomIndexingStats() { + return new IndexingStats( + new IndexingStats.Stats( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomMillisUpToYear9999(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomBoolean(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomDoubleBetween(0d, 10d, true), + randomDoubleBetween(0d, 10d, true) + ) + ); } public void testLeastAndMostAvailableDiskSpace() { diff --git a/server/src/test/java/org/elasticsearch/cluster/InternalClusterInfoServiceSchedulingTests.java b/server/src/test/java/org/elasticsearch/cluster/InternalClusterInfoServiceSchedulingTests.java index ea9c940793778..6e80e0d087993 100644 --- a/server/src/test/java/org/elasticsearch/cluster/InternalClusterInfoServiceSchedulingTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/InternalClusterInfoServiceSchedulingTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.routing.allocation.WriteLoadConstraintSettings; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.FakeThreadPoolMasterService; @@ -55,7 +56,11 @@ public void testScheduling() { final Settings.Builder settingsBuilder = Settings.builder() .put(Node.NODE_NAME_SETTING.getKey(), discoveryNode.getName()) - .put(InternalClusterInfoService.CLUSTER_ROUTING_ALLOCATION_ESTIMATED_HEAP_THRESHOLD_DECIDER_ENABLED.getKey(), true); + .put(InternalClusterInfoService.CLUSTER_ROUTING_ALLOCATION_ESTIMATED_HEAP_THRESHOLD_DECIDER_ENABLED.getKey(), true) + 
.put( + WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey(), + WriteLoadConstraintSettings.WriteLoadDeciderStatus.ENABLED + ); if (randomBoolean()) { settingsBuilder.put(INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING.getKey(), randomIntBetween(10000, 60000) + "ms"); } @@ -79,12 +84,16 @@ protected PrioritizedEsThreadPoolExecutor createThreadPoolExecutor() { final FakeClusterInfoServiceClient client = new FakeClusterInfoServiceClient(threadPool); final EstimatedHeapUsageCollector mockEstimatedHeapUsageCollector = spy(new StubEstimatedEstimatedHeapUsageCollector()); + final NodeUsageStatsForThreadPoolsCollector mockNodeUsageStatsForThreadPoolsCollector = spy( + new StubNodeUsageStatsForThreadPoolsCollector() + ); final InternalClusterInfoService clusterInfoService = new InternalClusterInfoService( settings, clusterService, threadPool, client, - mockEstimatedHeapUsageCollector + mockEstimatedHeapUsageCollector, + mockNodeUsageStatsForThreadPoolsCollector ); clusterService.addListener(clusterInfoService); clusterInfoService.addListener(ignored -> {}); @@ -122,12 +131,14 @@ protected PrioritizedEsThreadPoolExecutor createThreadPoolExecutor() { for (int i = 0; i < 3; i++) { Mockito.clearInvocations(mockEstimatedHeapUsageCollector); + Mockito.clearInvocations(mockNodeUsageStatsForThreadPoolsCollector); final int initialRequestCount = client.requestCount; final long duration = INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING.get(settings).millis(); runFor(deterministicTaskQueue, duration); deterministicTaskQueue.runAllRunnableTasks(); assertThat(client.requestCount, equalTo(initialRequestCount + 2)); // should have run two client requests per interval verify(mockEstimatedHeapUsageCollector).collectClusterHeapUsage(any()); // Should poll for heap usage once per interval + verify(mockNodeUsageStatsForThreadPoolsCollector).collectUsageStats(any()); } final AtomicBoolean failMaster2 = new AtomicBoolean(); @@ -152,6 +163,17 @@ public void 
collectClusterHeapUsage(ActionListener> listener) } } + /** + * Simple for test {@link NodeUsageStatsForThreadPoolsCollector} implementation that returns an empty map of nodeId string to + * {@link NodeUsageStatsForThreadPools}. + */ + private static class StubNodeUsageStatsForThreadPoolsCollector implements NodeUsageStatsForThreadPoolsCollector { + @Override + public void collectUsageStats(ActionListener> listener) { + listener.onResponse(Map.of()); + } + } + private static void runFor(DeterministicTaskQueue deterministicTaskQueue, long duration) { final long endTime = deterministicTaskQueue.getCurrentTimeMillis() + duration; while (deterministicTaskQueue.getCurrentTimeMillis() < endTime diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java index 7c0fc80ea73c4..1c57376a89295 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.cluster.routing.RoutingTable; import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.routing.allocation.DataTier; +import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator; import org.elasticsearch.cluster.routing.allocation.allocator.BalancedShardsAllocator; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider; @@ -791,6 +792,55 @@ public boolean overrulesTemplateAndRequestSettings() { assertThat(aggregatedIndexSettings.get("other_setting"), equalTo("other_value")); } + /** + * When a failure store index is created, we must filter out any unsupported settings from the create request or from the template that + * may have been 
provided by users in the create request or from the original data stream template. An exception to this is any settings + * that have been provided by index setting providers which should be considered default values on indices. + */ + public void testAggregateSettingsProviderIsNotFilteredOnFailureStore() { + IndexTemplateMetadata templateMetadata = addMatchingTemplate(builder -> { + builder.settings(Settings.builder().put("template_setting", "value1")); + }); + ProjectMetadata projectMetadata = ProjectMetadata.builder(projectId).templates(Map.of("template_1", templateMetadata)).build(); + ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).putProjectMetadata(projectMetadata).build(); + var request = new CreateIndexClusterStateUpdateRequest("create index", projectId, "test", "test").settings( + Settings.builder().put("request_setting", "value2").build() + ).isFailureIndex(true); + + Settings aggregatedIndexSettings = aggregateIndexSettings( + clusterState, + request, + templateMetadata.settings(), + null, + null, + Settings.EMPTY, + IndexScopedSettings.DEFAULT_SCOPED_SETTINGS, + randomShardLimitService(), + Set.of(new IndexSettingProvider() { + @Override + public Settings getAdditionalIndexSettings( + String indexName, + String dataStreamName, + IndexMode templateIndexMode, + ProjectMetadata projectMetadata, + Instant resolvedAt, + Settings indexTemplateAndCreateRequestSettings, + List combinedTemplateMappings + ) { + return Settings.builder().put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey(), "override").build(); + } + + @Override + public boolean overrulesTemplateAndRequestSettings() { + return true; + } + }) + ); + + assertThat(aggregatedIndexSettings.get("template_setting"), nullValue()); + assertThat(aggregatedIndexSettings.get(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey()), equalTo("override")); + } + public void testAggregateSettingsProviderOverrulesNullFromRequest() { IndexTemplateMetadata 
templateMetadata = addMatchingTemplate(builder -> { builder.settings(Settings.builder().put("template_setting", "value1")); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataPersistentTasksTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataPersistentTasksTests.java index ac3644b6ecc8a..a1780c8a69b5b 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataPersistentTasksTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataPersistentTasksTests.java @@ -306,13 +306,16 @@ private Tuple randomMetadataAndUpdate() { ClusterPersistentTasksCustomMetadata::new ) ) - .putProjectCustom( - PersistentTasksCustomMetadata.TYPE, - mutatePersistentTasks( - PersistentTasksCustomMetadata.get(before.getProject(Metadata.DEFAULT_PROJECT_ID)), - MetadataPersistentTasksTests::oneProjectPersistentTask, - PersistentTasksCustomMetadata::new - ) + .put( + ProjectMetadata.builder(before.getProject(Metadata.DEFAULT_PROJECT_ID)) + .putCustom( + PersistentTasksCustomMetadata.TYPE, + mutatePersistentTasks( + PersistentTasksCustomMetadata.get(before.getProject(Metadata.DEFAULT_PROJECT_ID)), + MetadataPersistentTasksTests::oneProjectPersistentTask, + PersistentTasksCustomMetadata::new + ) + ) ) .build(); return new Tuple<>(before, after); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataTests.java index 6eecf4921c50e..373831b54804b 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataTests.java @@ -9,11 +9,9 @@ package org.elasticsearch.cluster.metadata; -import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.elasticsearch.TransportVersions; import 
org.elasticsearch.cluster.ClusterModule; -import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.coordination.CoordinationMetadata; import org.elasticsearch.cluster.coordination.CoordinationMetadata.VotingConfigExclusion; @@ -22,7 +20,6 @@ import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -33,22 +30,13 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.FixForMultiProject; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Predicates; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.health.node.selection.HealthNode; import org.elasticsearch.health.node.selection.HealthNodeTaskExecutor; import org.elasticsearch.health.node.selection.HealthNodeTaskParams; import org.elasticsearch.index.Index; -import org.elasticsearch.index.IndexMode; -import org.elasticsearch.index.IndexNotFoundException; -import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.alias.RandomAliasActionsGenerator; -import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; @@ -56,10 +44,9 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import 
org.elasticsearch.persistent.PersistentTasksExecutorRegistry; import org.elasticsearch.persistent.PersistentTasksService; -import org.elasticsearch.plugins.FieldPredicate; -import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.test.rest.ObjectPath; import org.elasticsearch.threadpool.ThreadPool; @@ -89,31 +76,21 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.SortedMap; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.createBackingIndex; -import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.createFirstBackingIndex; -import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.newInstance; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_API; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_SNAPSHOT; -import static org.elasticsearch.cluster.metadata.ProjectMetadata.Builder.assertDataStreams; -import static org.elasticsearch.test.LambdaMatchers.transformedItemsMatch; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresentWith; import static org.hamcrest.Matchers.aMapWithSize; -import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import 
static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -121,472 +98,11 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; -import static org.hamcrest.Matchers.startsWith; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class MetadataTests extends ESTestCase { - public void testFindAliases() { - Metadata metadata = Metadata.builder() - .put( - IndexMetadata.builder("index") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias1").build()) - .putAlias(AliasMetadata.builder("alias2").build()) - ) - .put( - IndexMetadata.builder("index2") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias2").build()) - .putAlias(AliasMetadata.builder("alias3").build()) - ) - .build(); - - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT); - Map> aliases = metadata.getProject().findAliases(request.aliases(), Strings.EMPTY_ARRAY); - assertThat(aliases, anEmptyMap()); - } - { - final GetAliasesRequest request; - if (randomBoolean()) { - request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT); - } else { - request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, randomFrom("alias1", "alias2")); - // replacing with empty aliases behaves as if aliases were unspecified at request building - request.replaceAliases(Strings.EMPTY_ARRAY); - } - Map> aliases = metadata.getProject().findAliases(request.aliases(), new String[] { "index" }); - assertThat(aliases, aMapWithSize(1)); - 
List aliasMetadataList = aliases.get("index"); - assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1", "alias2"))); - } - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, "alias*"); - Map> aliases = metadata.getProject() - .findAliases(request.aliases(), new String[] { "index", "index2" }); - assertThat(aliases, aMapWithSize(2)); - List indexAliasMetadataList = aliases.get("index"); - assertThat(indexAliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1", "alias2"))); - List index2AliasMetadataList = aliases.get("index2"); - assertThat(index2AliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias2", "alias3"))); - } - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, "alias1"); - Map> aliases = metadata.getProject().findAliases(request.aliases(), new String[] { "index" }); - assertThat(aliases, aMapWithSize(1)); - List aliasMetadataList = aliases.get("index"); - assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1"))); - } - { - Map> aliases = metadata.getProject().findAllAliases(new String[] { "index" }); - assertThat(aliases, aMapWithSize(1)); - List aliasMetadataList = aliases.get("index"); - assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1", "alias2"))); - } - { - Map> aliases = metadata.getProject().findAllAliases(Strings.EMPTY_ARRAY); - assertThat(aliases, anEmptyMap()); - } - } - - public void testFindDataStreamAliases() { - Metadata.Builder builder = Metadata.builder(); - - addDataStream("d1", builder); - addDataStream("d2", builder); - addDataStream("d3", builder); - addDataStream("d4", builder); - - builder.put("alias1", "d1", null, null); - builder.put("alias2", "d2", null, null); - builder.put("alias2-part2", "d2", null, null); - - Metadata metadata = builder.build(); - - { - GetAliasesRequest request = new 
GetAliasesRequest(TEST_REQUEST_TIMEOUT); - Map> aliases = metadata.getProject() - .findDataStreamAliases(request.aliases(), Strings.EMPTY_ARRAY); - assertThat(aliases, anEmptyMap()); - } - - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("alias1"); - Map> aliases = metadata.getProject() - .findDataStreamAliases(request.aliases(), new String[] { "index" }); - assertThat(aliases, anEmptyMap()); - } - - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("alias1"); - Map> aliases = metadata.getProject() - .findDataStreamAliases(request.aliases(), new String[] { "index", "d1", "d2" }); - assertEquals(1, aliases.size()); - List found = aliases.get("d1"); - assertThat(found, transformedItemsMatch(DataStreamAlias::getAlias, contains("alias1"))); - } - - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("ali*"); - Map> aliases = metadata.getProject() - .findDataStreamAliases(request.aliases(), new String[] { "index", "d2" }); - assertEquals(1, aliases.size()); - List found = aliases.get("d2"); - assertThat(found, transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); - } - - // test exclusion - { - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("*"); - Map> aliases = metadata.getProject() - .findDataStreamAliases(request.aliases(), new String[] { "index", "d1", "d2", "d3", "d4" }); - assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); - assertThat(aliases.get("d1"), transformedItemsMatch(DataStreamAlias::getAlias, contains("alias1"))); - - request.aliases("*", "-alias1"); - aliases = metadata.getProject().findDataStreamAliases(request.aliases(), new String[] { "index", "d1", "d2", "d3", "d4" }); - assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); - 
assertNull(aliases.get("d1")); - } - } - - public void testDataStreamAliasesByDataStream() { - Metadata.Builder builder = Metadata.builder(); - - addDataStream("d1", builder); - addDataStream("d2", builder); - addDataStream("d3", builder); - addDataStream("d4", builder); - - builder.put("alias1", "d1", null, null); - builder.put("alias2", "d2", null, null); - builder.put("alias2-part2", "d2", null, null); - - Metadata metadata = builder.build(); - - var aliases = metadata.getProject().dataStreamAliasesByDataStream(); - - assertTrue(aliases.containsKey("d1")); - assertTrue(aliases.containsKey("d2")); - assertFalse(aliases.containsKey("d3")); - assertFalse(aliases.containsKey("d4")); - - assertEquals(1, aliases.get("d1").size()); - assertEquals(2, aliases.get("d2").size()); - - assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); - } - - public void testFindAliasWithExclusion() { - Metadata metadata = Metadata.builder() - .put( - IndexMetadata.builder("index") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias1").build()) - .putAlias(AliasMetadata.builder("alias2").build()) - ) - .put( - IndexMetadata.builder("index2") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias1").build()) - .putAlias(AliasMetadata.builder("alias3").build()) - ) - .build(); - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("*", "-alias1"); - Map> aliases = metadata.getProject().findAliases(request.aliases(), new String[] { "index", "index2" }); - assertThat(aliases.get("index"), transformedItemsMatch(AliasMetadata::alias, contains("alias2"))); - assertThat(aliases.get("index2"), transformedItemsMatch(AliasMetadata::alias, 
contains("alias3"))); - } - - public void testFindDataStreams() { - final int numIndices = randomIntBetween(2, 5); - final int numBackingIndices = randomIntBetween(2, 5); - final String dataStreamName = "my-data-stream"; - CreateIndexResult result = createIndices(numIndices, numBackingIndices, dataStreamName); - - List allIndices = new ArrayList<>(result.indices); - allIndices.addAll(result.backingIndices); - String[] concreteIndices = allIndices.stream().map(Index::getName).toArray(String[]::new); - Map dataStreams = result.metadata.getProject().findDataStreams(concreteIndices); - assertThat(dataStreams, aMapWithSize(numBackingIndices)); - for (Index backingIndex : result.backingIndices) { - assertThat(dataStreams, hasKey(backingIndex.getName())); - assertThat(dataStreams.get(backingIndex.getName()).getName(), equalTo(dataStreamName)); - } - } - - public void testFindAliasWithExclusionAndOverride() { - Metadata metadata = Metadata.builder() - .put( - IndexMetadata.builder("index") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("aa").build()) - .putAlias(AliasMetadata.builder("ab").build()) - .putAlias(AliasMetadata.builder("bb").build()) - ) - .build(); - GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("a*", "-*b", "b*"); - List aliases = metadata.getProject().findAliases(request.aliases(), new String[] { "index" }).get("index"); - assertThat(aliases, transformedItemsMatch(AliasMetadata::alias, contains("aa", "bb"))); - } - - public void testAliasCollidingWithAnExistingIndex() { - int indexCount = randomIntBetween(10, 100); - Set indices = Sets.newHashSetWithExpectedSize(indexCount); - for (int i = 0; i < indexCount; i++) { - indices.add(randomAlphaOfLength(10)); - } - Map> aliasToIndices = new HashMap<>(); - for (String alias : randomSubsetOf(randomIntBetween(1, 10), indices)) { - Set indicesInAlias; - 
do { - indicesInAlias = new HashSet<>(randomSubsetOf(randomIntBetween(1, 3), indices)); - indicesInAlias.remove(alias); - } while (indicesInAlias.isEmpty()); - aliasToIndices.put(alias, indicesInAlias); - } - int properAliases = randomIntBetween(0, 3); - for (int i = 0; i < properAliases; i++) { - aliasToIndices.put(randomAlphaOfLength(5), new HashSet<>(randomSubsetOf(randomIntBetween(1, 3), indices))); - } - Metadata.Builder metadataBuilder = Metadata.builder(); - for (String index : indices) { - IndexMetadata.Builder indexBuilder = IndexMetadata.builder(index) - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0); - aliasToIndices.forEach((key, value) -> { - if (value.contains(index)) { - indexBuilder.putAlias(AliasMetadata.builder(key).build()); - } - }); - metadataBuilder.put(indexBuilder); - } - - Exception e = expectThrows(IllegalStateException.class, metadataBuilder::build); - assertThat(e.getMessage(), startsWith("index, alias, and data stream names need to be unique")); - } - - public void testValidateAliasWriteOnly() { - String alias = randomAlphaOfLength(5); - String indexA = randomAlphaOfLength(6); - String indexB = randomAlphaOfLength(7); - Boolean aWriteIndex = randomBoolean() ? 
null : randomBoolean(); - Boolean bWriteIndex; - if (Boolean.TRUE.equals(aWriteIndex)) { - bWriteIndex = randomFrom(Boolean.FALSE, null); - } else { - bWriteIndex = randomFrom(Boolean.TRUE, Boolean.FALSE, null); - } - // when only one index/alias pair exist - Metadata metadata = Metadata.builder().put(buildIndexMetadata(indexA, alias, aWriteIndex)).build(); - - // when alias points to two indices, but valid - // one of the following combinations: [(null, null), (null, true), (null, false), (false, false)] - Metadata.builder(metadata).put(buildIndexMetadata(indexB, alias, bWriteIndex)).build(); - - // when too many write indices - Exception exception = expectThrows(IllegalStateException.class, () -> { - IndexMetadata.Builder metaA = buildIndexMetadata(indexA, alias, true); - IndexMetadata.Builder metaB = buildIndexMetadata(indexB, alias, true); - Metadata.builder().put(metaA).put(metaB).build(); - }); - assertThat(exception.getMessage(), startsWith("alias [" + alias + "] has more than one write index [")); - } - - public void testValidateHiddenAliasConsistency() { - String alias = randomAlphaOfLength(5); - String indexA = randomAlphaOfLength(6); - String indexB = randomAlphaOfLength(7); - - { - Exception ex = expectThrows( - IllegalStateException.class, - () -> buildMetadataWithHiddenIndexMix(alias, indexA, true, indexB, randomFrom(false, null)).build() - ); - assertThat(ex.getMessage(), containsString("has is_hidden set to true on indices")); - } - - { - Exception ex = expectThrows( - IllegalStateException.class, - () -> buildMetadataWithHiddenIndexMix(alias, indexA, randomFrom(false, null), indexB, true).build() - ); - assertThat(ex.getMessage(), containsString("has is_hidden set to true on indices")); - } - } - - private Metadata.Builder buildMetadataWithHiddenIndexMix( - String aliasName, - String indexAName, - Boolean indexAHidden, - String indexBName, - Boolean indexBHidden - ) { - IndexMetadata.Builder indexAMeta = IndexMetadata.builder(indexAName) - 
.settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder(aliasName).isHidden(indexAHidden).build()); - IndexMetadata.Builder indexBMeta = IndexMetadata.builder(indexBName) - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder(aliasName).isHidden(indexBHidden).build()); - return Metadata.builder().put(indexAMeta).put(indexBMeta); - } - - public void testResolveIndexRouting() { - IndexMetadata.Builder builder = IndexMetadata.builder("index") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias0").build()) - .putAlias(AliasMetadata.builder("alias1").routing("1").build()) - .putAlias(AliasMetadata.builder("alias2").routing("1,2").build()); - Metadata metadata = Metadata.builder().put(builder).build(); - - // no alias, no index - assertNull(metadata.getProject().resolveIndexRouting(null, null)); - assertEquals(metadata.getProject().resolveIndexRouting("0", null), "0"); - - // index, no alias - assertNull(metadata.getProject().resolveIndexRouting(null, "index")); - assertEquals(metadata.getProject().resolveIndexRouting("0", "index"), "0"); - - // alias with no index routing - assertNull(metadata.getProject().resolveIndexRouting(null, "alias0")); - assertEquals(metadata.getProject().resolveIndexRouting("0", "alias0"), "0"); - - // alias with index routing. 
- assertEquals(metadata.getProject().resolveIndexRouting(null, "alias1"), "1"); - Exception ex = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveIndexRouting("0", "alias1")); - assertThat( - ex.getMessage(), - is("Alias [alias1] has index routing associated with it [1], and was provided with routing value [0], rejecting operation") - ); - - // alias with invalid index routing. - ex = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveIndexRouting(null, "alias2")); - assertThat( - ex.getMessage(), - is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") - ); - - ex = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveIndexRouting("1", "alias2")); - assertThat( - ex.getMessage(), - is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") - ); - - IndexMetadata.Builder builder2 = IndexMetadata.builder("index2") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(AliasMetadata.builder("alias0").build()); - Metadata metadataTwoIndices = Metadata.builder(metadata).put(builder2).build(); - - // alias with multiple indices - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> metadataTwoIndices.getProject().resolveIndexRouting("1", "alias0") - ); - assertThat(exception.getMessage(), startsWith("Alias [alias0] has more than one index associated with it")); - } - - public void testResolveWriteIndexRouting() { - AliasMetadata.Builder aliasZeroBuilder = AliasMetadata.builder("alias0"); - if (randomBoolean()) { - aliasZeroBuilder.writeIndex(true); - } - IndexMetadata.Builder builder = IndexMetadata.builder("index") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - 
.numberOfShards(1) - .numberOfReplicas(0) - .putAlias(aliasZeroBuilder.build()) - .putAlias(AliasMetadata.builder("alias1").routing("1").build()) - .putAlias(AliasMetadata.builder("alias2").routing("1,2").build()) - .putAlias(AliasMetadata.builder("alias3").writeIndex(false).build()) - .putAlias(AliasMetadata.builder("alias4").routing("1,2").writeIndex(true).build()); - Metadata metadata = Metadata.builder().put(builder).build(); - - // no alias, no index - assertNull(metadata.getProject().resolveWriteIndexRouting(null, null)); - assertEquals(metadata.getProject().resolveWriteIndexRouting("0", null), "0"); - - // index, no alias - assertNull(metadata.getProject().resolveWriteIndexRouting(null, "index")); - assertEquals(metadata.getProject().resolveWriteIndexRouting("0", "index"), "0"); - - // alias with no index routing - assertNull(metadata.getProject().resolveWriteIndexRouting(null, "alias0")); - assertEquals(metadata.getProject().resolveWriteIndexRouting("0", "alias0"), "0"); - - // alias with index routing. - assertEquals(metadata.getProject().resolveWriteIndexRouting(null, "alias1"), "1"); - Exception exception = expectThrows( - IllegalArgumentException.class, - () -> metadata.getProject().resolveWriteIndexRouting("0", "alias1") - ); - assertThat( - exception.getMessage(), - is("Alias [alias1] has index routing associated with it [1], and was provided with routing value [0], rejecting operation") - ); - - // alias with invalid index routing. 
- exception = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveWriteIndexRouting(null, "alias2")); - assertThat( - exception.getMessage(), - is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") - ); - exception = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveWriteIndexRouting("1", "alias2")); - assertThat( - exception.getMessage(), - is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") - ); - exception = expectThrows( - IllegalArgumentException.class, - () -> metadata.getProject().resolveWriteIndexRouting(randomFrom("1", null), "alias4") - ); - assertThat( - exception.getMessage(), - is("index/alias [alias4] provided with routing value [1,2] that resolved to several routing values, rejecting operation") - ); - - // alias with no write index - exception = expectThrows(IllegalArgumentException.class, () -> metadata.getProject().resolveWriteIndexRouting("1", "alias3")); - assertThat(exception.getMessage(), is("alias [alias3] does not have a write index")); - - // aliases with multiple indices - AliasMetadata.Builder aliasZeroBuilderTwo = AliasMetadata.builder("alias0"); - if (randomBoolean()) { - aliasZeroBuilder.writeIndex(false); - } - IndexMetadata.Builder builder2 = IndexMetadata.builder("index2") - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(aliasZeroBuilderTwo.build()) - .putAlias(AliasMetadata.builder("alias1").routing("0").writeIndex(true).build()) - .putAlias(AliasMetadata.builder("alias2").writeIndex(true).build()); - Metadata metadataTwoIndices = Metadata.builder(metadata).put(builder2).build(); - - // verify that new write index is used - assertThat("0", equalTo(metadataTwoIndices.getProject().resolveWriteIndexRouting("0", "alias1"))); - } - public 
void testUnknownFieldClusterMetadata() throws IOException { BytesReference metadata = BytesReference.bytes( JsonXContent.contentBuilder().startObject().startObject("meta-data").field("random", "value").endObject().endObject() @@ -837,7 +353,7 @@ public void testParseXContentFormatBeforeMultiProject() throws IOException { containsInAnyOrder("health-node") ); assertThat( - metadata.getProject().customs().keySet(), + metadata.getProject(ProjectId.DEFAULT).customs().keySet(), containsInAnyOrder("persistent_tasks", "index-graveyard", "component_template", "repositories") ); assertThat(metadata.customs(), not(hasKey("repositories"))); @@ -915,1719 +431,207 @@ public void testParseXContentFormatBeforeRepositoriesMetadataMigration() throws "location": "backup" }, "generation": 42, - "pending_generation": 42 - } - }, - "reserved_state":{ } - } - } - """, IndexVersion.current(), IndexVersion.current()); - - final Metadata metadata = fromJsonXContentStringWithPersistentTasks(json); - assertThat(metadata, notNullValue()); - assertThat(metadata.clusterUUID(), is("aba1aa1ababbbaabaabaab")); - - assertThat(metadata.projects().keySet(), containsInAnyOrder(ProjectId.fromId("default"), ProjectId.fromId("another_project"))); - assertThat(metadata.customs(), not(hasKey("repositories"))); - final var repositoriesMetadata = RepositoriesMetadata.get(metadata.getProject(ProjectId.DEFAULT)); - assertThat( - repositoriesMetadata.repositories(), - equalTo( - List.of( - new RepositoryMetadata("my-repo", "_my-repo-uuid_", "fs", Settings.builder().put("location", "backup").build(), 42, 42) - ) - ) - ); - assertThat(metadata.getProject(ProjectId.fromId("another_project")).customs(), not(hasKey("repositories"))); - } - - private Metadata fromJsonXContentStringWithPersistentTasks(String json) throws IOException { - List registry = new ArrayList<>(); - registry.addAll(ClusterModule.getNamedXWriteables()); - registry.addAll(IndicesModule.getNamedXContents()); - 
registry.addAll(HealthNodeTaskExecutor.getNamedXContentParsers()); - - final var clusterService = mock(ClusterService.class); - when(clusterService.threadPool()).thenReturn(mock(ThreadPool.class)); - final var healthNodeTaskExecutor = HealthNodeTaskExecutor.create( - clusterService, - mock(PersistentTasksService.class), - Settings.EMPTY, - ClusterSettings.createBuiltInClusterSettings() - ); - new PersistentTasksExecutorRegistry(List.of(healthNodeTaskExecutor)); - - XContentParserConfiguration config = XContentParserConfiguration.EMPTY.withRegistry(new NamedXContentRegistry(registry)); - try (XContentParser parser = JsonXContent.jsonXContent.createParser(config, json)) { - return Metadata.fromXContent(parser); - } - } - - public void testGlobalStateEqualsCoordinationMetadata() { - CoordinationMetadata coordinationMetadata1 = new CoordinationMetadata( - randomNonNegativeLong(), - randomVotingConfig(), - randomVotingConfig(), - randomVotingConfigExclusions() - ); - Metadata metadata1 = Metadata.builder().coordinationMetadata(coordinationMetadata1).build(); - CoordinationMetadata coordinationMetadata2 = new CoordinationMetadata( - randomNonNegativeLong(), - randomVotingConfig(), - randomVotingConfig(), - randomVotingConfigExclusions() - ); - Metadata metadata2 = Metadata.builder().coordinationMetadata(coordinationMetadata2).build(); - - assertTrue(Metadata.isGlobalStateEquals(metadata1, metadata1)); - assertFalse(Metadata.isGlobalStateEquals(metadata1, metadata2)); - } - - public void testSerializationWithIndexGraveyard() throws IOException { - final var projectId = randomProjectIdOrDefault(); - final IndexGraveyard graveyard = IndexGraveyardTests.createRandom(); - final Metadata originalMeta = Metadata.builder().put(ProjectMetadata.builder(projectId).indexGraveyard(graveyard)).build(); - final BytesStreamOutput out = new BytesStreamOutput(); - originalMeta.writeTo(out); - NamedWriteableRegistry namedWriteableRegistry = new 
NamedWriteableRegistry(ClusterModule.getNamedWriteables()); - final Metadata fromStreamMeta = Metadata.readFrom( - new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry) - ); - assertThat(fromStreamMeta.getProject(projectId).indexGraveyard(), equalTo(originalMeta.getProject(projectId).indexGraveyard())); - } - - public void testFindMappings() throws IOException { - Metadata metadata = Metadata.builder() - .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(FIND_MAPPINGS_TEST_ITEM)) - .put(IndexMetadata.builder("index2").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(FIND_MAPPINGS_TEST_ITEM)) - .build(); - - { - AtomicInteger onNextIndexCalls = new AtomicInteger(0); - Map mappings = metadata.getProject() - .findMappings(Strings.EMPTY_ARRAY, MapperPlugin.NOOP_FIELD_FILTER, onNextIndexCalls::incrementAndGet); - assertThat(mappings, anEmptyMap()); - assertThat(onNextIndexCalls.get(), equalTo(0)); - } - { - AtomicInteger onNextIndexCalls = new AtomicInteger(0); - Map mappings = metadata.getProject() - .findMappings(new String[] { "index1" }, MapperPlugin.NOOP_FIELD_FILTER, onNextIndexCalls::incrementAndGet); - assertThat(mappings, aMapWithSize(1)); - assertIndexMappingsNotFiltered(mappings, "index1"); - assertThat(onNextIndexCalls.get(), equalTo(1)); - } - { - AtomicInteger onNextIndexCalls = new AtomicInteger(0); - Map mappings = metadata.getProject() - .findMappings(new String[] { "index1", "index2" }, MapperPlugin.NOOP_FIELD_FILTER, onNextIndexCalls::incrementAndGet); - assertThat(mappings, aMapWithSize(2)); - assertIndexMappingsNotFiltered(mappings, "index1"); - assertIndexMappingsNotFiltered(mappings, "index2"); - assertThat(onNextIndexCalls.get(), equalTo(2)); - } - } - - public void testFindMappingsNoOpFilters() throws IOException { - MappingMetadata originalMappingMetadata = new MappingMetadata( - "_doc", - XContentHelper.convertToMap(JsonXContent.jsonXContent, 
FIND_MAPPINGS_TEST_ITEM, true) - ); - - Metadata metadata = Metadata.builder() - .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(originalMappingMetadata)) - .build(); - - { - Map mappings = metadata.getProject() - .findMappings(new String[] { "index1" }, MapperPlugin.NOOP_FIELD_FILTER, Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP); - MappingMetadata mappingMetadata = mappings.get("index1"); - assertSame(originalMappingMetadata, mappingMetadata); - } - { - Map mappings = metadata.getProject() - .findMappings(new String[] { "index1" }, index -> field -> randomBoolean(), Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP); - MappingMetadata mappingMetadata = mappings.get("index1"); - assertNotSame(originalMappingMetadata, mappingMetadata); - } - } - - @SuppressWarnings("unchecked") - public void testFindMappingsWithFilters() throws IOException { - String mapping = FIND_MAPPINGS_TEST_ITEM; - if (randomBoolean()) { - Map stringObjectMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, FIND_MAPPINGS_TEST_ITEM, false); - Map doc = (Map) stringObjectMap.get("_doc"); - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - builder.map(doc); - mapping = Strings.toString(builder); - } - } - - Metadata metadata = Metadata.builder() - .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) - .put(IndexMetadata.builder("index2").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) - .put(IndexMetadata.builder("index3").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) - .build(); - - { - Map mappings = metadata.getProject() - .findMappings(new String[] { "index1", "index2", "index3" }, index -> { - if (index.equals("index1")) { - return field -> field.startsWith("name.") == false - && field.startsWith("properties.key.") == false - && field.equals("age") == false - && field.equals("address.location") == false; - } - if 
(index.equals("index2")) { - return Predicates.never(); - } - return FieldPredicate.ACCEPT_ALL; - }, Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP); - - assertIndexMappingsNoFields(mappings, "index2"); - assertIndexMappingsNotFiltered(mappings, "index3"); - - MappingMetadata docMapping = mappings.get("index1"); - assertNotNull(docMapping); - - Map sourceAsMap = docMapping.getSourceAsMap(); - assertThat(sourceAsMap.keySet(), containsInAnyOrder("properties", "_routing", "_source")); - - Map typeProperties = (Map) sourceAsMap.get("properties"); - assertThat(typeProperties.keySet(), containsInAnyOrder("name", "address", "birth", "ip", "suggest", "properties")); - - Map name = (Map) typeProperties.get("name"); - assertThat(name.keySet(), containsInAnyOrder("properties")); - Map nameProperties = (Map) name.get("properties"); - assertThat(nameProperties, anEmptyMap()); - - Map address = (Map) typeProperties.get("address"); - assertThat(address.keySet(), containsInAnyOrder("type", "properties")); - Map addressProperties = (Map) address.get("properties"); - assertThat(addressProperties.keySet(), containsInAnyOrder("street", "area")); - assertLeafs(addressProperties, "street", "area"); - - Map properties = (Map) typeProperties.get("properties"); - assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); - Map propertiesProperties = (Map) properties.get("properties"); - assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); - assertLeafs(propertiesProperties, "key"); - assertMultiField(propertiesProperties, "value", "keyword"); - } - - { - Map mappings = metadata.getProject() - .findMappings( - new String[] { "index1", "index2", "index3" }, - index -> field -> (index.equals("index3") && field.endsWith("keyword")), - Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP - ); - - assertIndexMappingsNoFields(mappings, "index1"); - assertIndexMappingsNoFields(mappings, "index2"); - MappingMetadata mappingMetadata = mappings.get("index3"); - Map 
sourceAsMap = mappingMetadata.getSourceAsMap(); - assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", "properties")); - Map typeProperties = (Map) sourceAsMap.get("properties"); - assertThat(typeProperties.keySet(), containsInAnyOrder("properties")); - Map properties = (Map) typeProperties.get("properties"); - assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); - Map propertiesProperties = (Map) properties.get("properties"); - assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); - Map key = (Map) propertiesProperties.get("key"); - assertThat(key.keySet(), containsInAnyOrder("properties")); - Map keyProperties = (Map) key.get("properties"); - assertThat(keyProperties.keySet(), containsInAnyOrder("keyword")); - assertLeafs(keyProperties, "keyword"); - Map value = (Map) propertiesProperties.get("value"); - assertThat(value.keySet(), containsInAnyOrder("properties")); - Map valueProperties = (Map) value.get("properties"); - assertThat(valueProperties.keySet(), containsInAnyOrder("keyword")); - assertLeafs(valueProperties, "keyword"); - } - - { - Map mappings = metadata.getProject() - .findMappings( - new String[] { "index1", "index2", "index3" }, - index -> field -> (index.equals("index2")), - Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP - ); - - assertIndexMappingsNoFields(mappings, "index1"); - assertIndexMappingsNoFields(mappings, "index3"); - assertIndexMappingsNotFiltered(mappings, "index2"); - } - } - - public void testOldestIndexComputation() { - Metadata metadata = buildIndicesWithVersions( - IndexVersions.MINIMUM_COMPATIBLE, - IndexVersion.current(), - IndexVersion.fromId(IndexVersion.current().id() + 1) - ).build(); - - assertEquals(IndexVersions.MINIMUM_COMPATIBLE, metadata.getProject().oldestIndexVersion()); - - Metadata.Builder b = Metadata.builder(); - assertEquals(IndexVersion.current(), b.build().getProject().oldestIndexVersion()); - - Throwable ex = expectThrows( - 
IllegalArgumentException.class, - () -> buildIndicesWithVersions( - IndexVersions.MINIMUM_COMPATIBLE, - IndexVersions.ZERO, - IndexVersion.fromId(IndexVersion.current().id() + 1) - ).build() - ); - - assertEquals("[index.version.created] is not present in the index settings for index with UUID [null]", ex.getMessage()); - } - - private Metadata.Builder buildIndicesWithVersions(IndexVersion... indexVersions) { - int lastIndexNum = randomIntBetween(9, 50); - Metadata.Builder b = Metadata.builder(); - for (IndexVersion indexVersion : indexVersions) { - IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName("index", lastIndexNum)) - .settings(settings(indexVersion)) - .numberOfShards(1) - .numberOfReplicas(1) - .build(); - b.put(im, false); - lastIndexNum = randomIntBetween(lastIndexNum + 1, lastIndexNum + 50); - } - - return b; - } - - private static IndexMetadata.Builder buildIndexMetadata(String name, String alias, Boolean writeIndex) { - return IndexMetadata.builder(name) - .settings(settings(IndexVersion.current())) - .creationDate(randomNonNegativeLong()) - .putAlias(AliasMetadata.builder(alias).writeIndex(writeIndex)) - .numberOfShards(1) - .numberOfReplicas(0); - } - - @SuppressWarnings("unchecked") - private static void assertIndexMappingsNoFields(Map mappings, String index) { - MappingMetadata docMapping = mappings.get(index); - assertNotNull(docMapping); - Map sourceAsMap = docMapping.getSourceAsMap(); - assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", "properties")); - Map typeProperties = (Map) sourceAsMap.get("properties"); - assertThat(typeProperties, anEmptyMap()); - } - - @SuppressWarnings("unchecked") - private static void assertIndexMappingsNotFiltered(Map mappings, String index) { - MappingMetadata docMapping = mappings.get(index); - assertNotNull(docMapping); - - Map sourceAsMap = docMapping.getSourceAsMap(); - assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", 
"properties")); - - Map typeProperties = (Map) sourceAsMap.get("properties"); - assertThat(typeProperties.keySet(), containsInAnyOrder("name", "address", "birth", "age", "ip", "suggest", "properties")); - - Map name = (Map) typeProperties.get("name"); - assertThat(name.keySet(), containsInAnyOrder("properties")); - Map nameProperties = (Map) name.get("properties"); - assertThat(nameProperties.keySet(), containsInAnyOrder("first", "last")); - assertLeafs(nameProperties, "first", "last"); - - Map address = (Map) typeProperties.get("address"); - assertThat(address.keySet(), containsInAnyOrder("type", "properties")); - Map addressProperties = (Map) address.get("properties"); - assertThat(addressProperties.keySet(), containsInAnyOrder("street", "location", "area")); - assertLeafs(addressProperties, "street", "location", "area"); - - Map properties = (Map) typeProperties.get("properties"); - assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); - Map propertiesProperties = (Map) properties.get("properties"); - assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); - assertMultiField(propertiesProperties, "key", "keyword"); - assertMultiField(propertiesProperties, "value", "keyword"); - } - - @SuppressWarnings("unchecked") - public static void assertLeafs(Map properties, String... fields) { - assertThat(properties.keySet(), hasItems(fields)); - for (String field : fields) { - Map fieldProp = (Map) properties.get(field); - assertThat(fieldProp, not(hasKey("properties"))); - assertThat(fieldProp, not(hasKey("fields"))); - } - } - - public static void assertMultiField(Map properties, String field, String... 
subFields) { - assertThat(properties, hasKey(field)); - @SuppressWarnings("unchecked") - Map fieldProp = (Map) properties.get(field); - assertThat(fieldProp, hasKey("fields")); - @SuppressWarnings("unchecked") - Map subFieldsDef = (Map) fieldProp.get("fields"); - assertLeafs(subFieldsDef, subFields); - } - - private static final String FIND_MAPPINGS_TEST_ITEM = """ - { - "_doc": { - "_routing": { - "required":true - }, "_source": { - "enabled":false - }, "properties": { - "name": { - "properties": { - "first": { - "type": "keyword" - }, - "last": { - "type": "keyword" - } - } - }, - "birth": { - "type": "date" - }, - "age": { - "type": "integer" - }, - "ip": { - "type": "ip" - }, - "suggest" : { - "type": "completion" - }, - "address": { - "type": "object", - "properties": { - "street": { - "type": "keyword" - }, - "location": { - "type": "geo_point" - }, - "area": { - "type": "geo_shape", \s - "tree": "quadtree", - "precision": "1m" - } - } - }, - "properties": { - "type": "nested", - "properties": { - "key" : { - "type": "text", - "fields": { - "keyword" : { - "type" : "keyword" - } - } - }, - "value" : { - "type": "text", - "fields": { - "keyword" : { - "type" : "keyword" - } - } - } - } - } - } - } - } - }"""; - - public void testTransientSettingsOverridePersistentSettings() { - final Setting setting = Setting.simpleString("key"); - final Metadata metadata = Metadata.builder() - .persistentSettings(Settings.builder().put(setting.getKey(), "persistent-value").build()) - .transientSettings(Settings.builder().put(setting.getKey(), "transient-value").build()) - .build(); - assertThat(setting.get(metadata.settings()), equalTo("transient-value")); - } - - public void testBuilderRejectsNullCustom() { - final Metadata.Builder builder = Metadata.builder(); - final String key = randomAlphaOfLength(10); - assertThat( - expectThrows(NullPointerException.class, () -> builder.putCustom(key, (Metadata.ClusterCustom) null)).getMessage(), - containsString(key) - ); - 
assertThat(expectThrows(NullPointerException.class, () -> builder.putProjectCustom(key, null)).getMessage(), containsString(key)); - } - - public void testBuilderRejectsNullInCustoms() { - final Metadata.Builder builder = Metadata.builder(); - final String key = randomAlphaOfLength(10); - { - final Map map = new HashMap<>(); - map.put(key, null); - assertThat(expectThrows(NullPointerException.class, () -> builder.customs(map)).getMessage(), containsString(key)); - } - { - final Map map = new HashMap<>(); - map.put(key, null); - assertThat(expectThrows(NullPointerException.class, () -> builder.projectCustoms(map)).getMessage(), containsString(key)); - } - } - - public void testCopyAndUpdate() throws IOException { - var metadata = Metadata.builder().clusterUUID(UUIDs.base64UUID()).build(); - var newClusterUuid = UUIDs.base64UUID(); - - var copy = metadata.copyAndUpdate(builder -> builder.clusterUUID(newClusterUuid)); - - assertThat(copy, not(sameInstance(metadata))); - assertThat(copy.clusterUUID(), equalTo(newClusterUuid)); - } - - public void testBuilderRemoveClusterCustomIf() { - var custom1 = new TestClusterCustomMetadata(); - var custom2 = new TestClusterCustomMetadata(); - var builder = Metadata.builder(); - builder.putCustom("custom1", custom1); - builder.putCustom("custom2", custom2); - - builder.removeCustomIf((key, value) -> Objects.equals(key, "custom1")); - - var metadata = builder.build(); - assertThat(metadata.custom("custom1"), nullValue()); - assertThat(metadata.custom("custom2"), sameInstance(custom2)); - } - - public void testBuilderRejectsDataStreamThatConflictsWithIndex() { - final String dataStreamName = "my-data-stream"; - IndexMetadata idx = createFirstBackingIndex(dataStreamName).build(); - Metadata.Builder b = Metadata.builder() - .put(idx, false) - .put( - IndexMetadata.builder(dataStreamName) - .settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - .build(), - false - ) - .put(newInstance(dataStreamName, 
List.of(idx.getIndex()))); - - IllegalStateException e = expectThrows(IllegalStateException.class, b::build); - assertThat( - e.getMessage(), - containsString( - "index, alias, and data stream names need to be unique, but the following duplicates were found [data " - + "stream [" - + dataStreamName - + "] conflicts with index]" - ) - ); - } - - public void testBuilderRejectsDataStreamThatConflictsWithAlias() { - final String dataStreamName = "my-data-stream"; - IndexMetadata idx = createFirstBackingIndex(dataStreamName).putAlias(AliasMetadata.builder(dataStreamName).build()).build(); - Metadata.Builder b = Metadata.builder().put(idx, false).put(newInstance(dataStreamName, List.of(idx.getIndex()))); - - IllegalStateException e = expectThrows(IllegalStateException.class, b::build); - assertThat( - e.getMessage(), - containsString( - "index, alias, and data stream names need to be unique, but the following duplicates were found [" - + dataStreamName - + " (alias of [" - + idx.getIndex().getName() - + "]) conflicts with data stream]" - ) - ); - } - - public void testBuilderRejectsAliasThatRefersToDataStreamBackingIndex() { - final String dataStreamName = "my-data-stream"; - final String conflictingName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); - IndexMetadata idx = createFirstBackingIndex(dataStreamName).putAlias(new AliasMetadata.Builder(conflictingName)).build(); - Metadata.Builder b = Metadata.builder().put(idx, false).put(newInstance(dataStreamName, List.of(idx.getIndex()))); - - AssertionError e = expectThrows(AssertionError.class, b::build); - assertThat(e.getMessage(), containsString("aliases [" + conflictingName + "] cannot refer to backing indices of data streams")); - } - - public void testBuilderForDataStreamWithRandomlyNumberedBackingIndices() { - final String dataStreamName = "my-data-stream"; - final List backingIndices = new ArrayList<>(); - final int numBackingIndices = randomIntBetween(2, 5); - int lastBackingIndexNum = 0; - 
Metadata.Builder b = Metadata.builder(); - for (int k = 1; k <= numBackingIndices; k++) { - lastBackingIndexNum = randomIntBetween(lastBackingIndexNum + 1, lastBackingIndexNum + 50); - IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, lastBackingIndexNum)) - .settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - .build(); - b.put(im, false); - backingIndices.add(im.getIndex()); - } - - b.put(newInstance(dataStreamName, backingIndices, lastBackingIndexNum, null)); - Metadata metadata = b.build(); - assertThat(metadata.getProject().dataStreams().keySet(), containsInAnyOrder(dataStreamName)); - assertThat(metadata.getProject().dataStreams().get(dataStreamName).getName(), equalTo(dataStreamName)); - } - - public void testBuildIndicesLookupForDataStreams() { - Metadata.Builder b = Metadata.builder(); - int numDataStreams = randomIntBetween(2, 8); - for (int i = 0; i < numDataStreams; i++) { - String name = "data-stream-" + i; - addDataStream(name, b); - } - - Metadata metadata = b.build(); - assertThat(metadata.getProject().dataStreams().size(), equalTo(numDataStreams)); - for (int i = 0; i < numDataStreams; i++) { - String name = "data-stream-" + i; - IndexAbstraction value = metadata.getProject().getIndicesLookup().get(name); - assertThat(value, notNullValue()); - DataStream ds = metadata.getProject().dataStreams().get(name); - assertThat(ds, notNullValue()); - - assertThat(value.isHidden(), is(false)); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); - assertThat(value.getIndices(), hasSize(ds.getIndices().size())); - assertThat(value.getWriteIndex().getName(), DataStreamTestHelper.backingIndexEqualTo(name, (int) ds.getGeneration())); - } - } - - public void testBuildIndicesLookupForDataStreamAliases() { - Metadata.Builder b = Metadata.builder(); - - addDataStream("d1", b); - addDataStream("d2", b); - addDataStream("d3", b); - addDataStream("d4", b); - - 
b.put("a1", "d1", null, null); - b.put("a1", "d2", null, null); - b.put("a2", "d3", null, null); - b.put("a3", "d1", null, null); - - Metadata metadata = b.build(); - assertThat(metadata.getProject().dataStreams(), aMapWithSize(4)); - IndexAbstraction value = metadata.getProject().getIndicesLookup().get("d1"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); - - value = metadata.getProject().getIndicesLookup().get("d2"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); - - value = metadata.getProject().getIndicesLookup().get("d3"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); - - value = metadata.getProject().getIndicesLookup().get("d4"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); - - value = metadata.getProject().getIndicesLookup().get("a1"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); - - value = metadata.getProject().getIndicesLookup().get("a2"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); - - value = metadata.getProject().getIndicesLookup().get("a3"); - assertThat(value, notNullValue()); - assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); - } - - public void testDataStreamAliasValidation() { - Metadata.Builder b = Metadata.builder(); - addDataStream("my-alias", b); - b.put("my-alias", "my-alias", null, null); - var e = expectThrows(IllegalStateException.class, b::build); - assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); - - b = Metadata.builder(); - addDataStream("d1", b); - addDataStream("my-alias", b); - b.put("my-alias", "d1", null, null); - e = expectThrows(IllegalStateException.class, b::build); - 
assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); - - b = Metadata.builder(); - b.put( - IndexMetadata.builder("index1") - .settings(indexSettings(IndexVersion.current(), 1, 0)) - .putAlias(new AliasMetadata.Builder("my-alias")) - ); - - addDataStream("d1", b); - b.put("my-alias", "d1", null, null); - e = expectThrows(IllegalStateException.class, b::build); - assertThat(e.getMessage(), containsString("data stream alias and indices alias have the same name (my-alias)")); - } - - public void testDataStreamAliasValidationRestoreScenario() { - Metadata.Builder b = Metadata.builder(); - b.dataStreams( - Map.of("my-alias", createDataStream("my-alias")), - Map.of("my-alias", new DataStreamAlias("my-alias", List.of("my-alias"), null, null)) - ); - var e = expectThrows(IllegalStateException.class, b::build); - assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); - - b = Metadata.builder(); - b.dataStreams( - Map.of("d1", createDataStream("d1"), "my-alias", createDataStream("my-alias")), - Map.of("my-alias", new DataStreamAlias("my-alias", List.of("d1"), null, null)) - ); - e = expectThrows(IllegalStateException.class, b::build); - assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); - - b = Metadata.builder(); - b.put( - IndexMetadata.builder("index1") - .settings(indexSettings(IndexVersion.current(), 1, 0)) - .putAlias(new AliasMetadata.Builder("my-alias")) - ); - b.dataStreams(Map.of("d1", createDataStream("d1")), Map.of("my-alias", new DataStreamAlias("my-alias", List.of("d1"), null, null))); - e = expectThrows(IllegalStateException.class, b::build); - assertThat(e.getMessage(), containsString("data stream alias and indices alias have the same name (my-alias)")); - } - - private void addDataStream(String name, Metadata.Builder b) { - int numBackingIndices = randomIntBetween(1, 4); - List indices = 
new ArrayList<>(numBackingIndices); - for (int j = 1; j <= numBackingIndices; j++) { - IndexMetadata idx = createBackingIndex(name, j).build(); - indices.add(idx.getIndex()); - b.put(idx, true); - } - b.put(newInstance(name, indices)); - } - - private DataStream createDataStream(String name) { - int numBackingIndices = randomIntBetween(1, 4); - List indices = new ArrayList<>(numBackingIndices); - for (int j = 1; j <= numBackingIndices; j++) { - IndexMetadata idx = createBackingIndex(name, j).build(); - indices.add(idx.getIndex()); - } - return newInstance(name, indices); - } - - public void testIndicesLookupRecordsDataStreamForBackingIndices() { - final int numIndices = randomIntBetween(2, 5); - final int numBackingIndices = randomIntBetween(2, 5); - final String dataStreamName = "my-data-stream"; - CreateIndexResult result = createIndices(numIndices, numBackingIndices, dataStreamName); - - SortedMap indicesLookup = result.metadata.getProject().getIndicesLookup(); - assertThat(indicesLookup, aMapWithSize(result.indices.size() + result.backingIndices.size() + 1)); - for (Index index : result.indices) { - assertThat(indicesLookup, hasKey(index.getName())); - assertNull(indicesLookup.get(index.getName()).getParentDataStream()); - } - for (Index index : result.backingIndices) { - assertThat(indicesLookup, hasKey(index.getName())); - assertNotNull(indicesLookup.get(index.getName()).getParentDataStream()); - assertThat(indicesLookup.get(index.getName()).getParentDataStream().getName(), equalTo(dataStreamName)); - } - } - - public void testSerialization() throws IOException { - final Metadata orig = randomMetadata(); - final BytesStreamOutput out = new BytesStreamOutput(); - orig.writeTo(out); - NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); - final Metadata fromStreamMeta = Metadata.readFrom( - new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry) - ); - 
assertTrue(Metadata.isGlobalStateEquals(orig, fromStreamMeta)); - } - - public void testMultiProjectSerialization() throws IOException { - // TODO: this whole suite needs to be updated for multiple projects - ProjectMetadata project1 = randomProject(ProjectId.fromId("1"), 1); - ProjectMetadata project2 = randomProject(ProjectId.fromId("2"), randomIntBetween(2, 10)); - Metadata metadata = randomMetadata(List.of(project1, project2)); - BytesStreamOutput out = new BytesStreamOutput(); - metadata.writeTo(out); - - // check it deserializes ok - NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); - Metadata fromStreamMeta = Metadata.readFrom(new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry)); - - // check it matches the original object - assertThat(fromStreamMeta.projects(), aMapWithSize(2)); - for (var original : List.of(project1, project2)) { - assertThat(fromStreamMeta.projects(), hasKey(original.id())); - final ProjectMetadata fromStreamProject = fromStreamMeta.getProject(original.id()); - assertThat("For project " + original.id(), fromStreamProject.indices().size(), equalTo(original.indices().size())); - assertThat("For project " + original.id(), fromStreamProject.dataStreams().size(), equalTo(original.dataStreams().size())); - assertThat("For project " + original.id(), fromStreamProject.templates().size(), equalTo(original.templates().size())); - assertThat("For project " + original.id(), fromStreamProject.templatesV2().size(), equalTo(original.templatesV2().size())); - original.indices().forEach((name, value) -> { - assertThat(fromStreamProject.indices(), hasKey(name)); - assertThat(fromStreamProject.index(name), equalTo(value)); - }); - original.dataStreams().forEach((name, value) -> { - assertThat(fromStreamProject.dataStreams(), hasKey(name)); - assertThat(fromStreamProject.dataStreams().get(name), equalTo(value)); - }); - } - } - - public void 
testGetNonExistingProjectThrows() { - final List projects = IntStream.range(0, between(1, 3)) - .mapToObj(i -> randomProject(ProjectId.fromId("p_" + i), between(0, 5))) - .toList(); - final Metadata metadata = randomMetadata(projects); - expectThrows(IllegalArgumentException.class, () -> metadata.getProject(randomProjectIdOrDefault())); - } - - public void testValidateDataStreamsNoConflicts() { - Metadata metadata = createIndices(5, 10, "foo-datastream").metadata; - // don't expect any exception when validating a system without indices that would conflict with future backing indices - assertDataStreams( - metadata.getProject().indices(), - (DataStreamMetadata) metadata.getProject().customs().get(DataStreamMetadata.TYPE) - ); - } - - public void testValidateDataStreamsIgnoresIndicesWithoutCounter() { - String dataStreamName = "foo-datastream"; - Metadata metadata = Metadata.builder(createIndices(10, 10, dataStreamName).metadata) - .put( - new IndexMetadata.Builder(dataStreamName + "-index-without-counter").settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - ) - .put( - new IndexMetadata.Builder(dataStreamName + randomAlphaOfLength(10)).settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - - ) - .put( - new IndexMetadata.Builder(randomAlphaOfLength(10)).settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - - ) - .build(); - // don't expect any exception when validating against non-backing indices that don't conform to the backing indices naming - // convention - assertDataStreams( - metadata.getProject().indices(), - (DataStreamMetadata) metadata.getProject().customs().get(DataStreamMetadata.TYPE) - ); - } - - public void testValidateDataStreamsAllowsNamesThatStartsWithPrefix() { - String dataStreamName = "foo-datastream"; - Metadata metadata = Metadata.builder(createIndices(10, 10, dataStreamName).metadata) - .put( - new 
IndexMetadata.Builder(DataStream.BACKING_INDEX_PREFIX + dataStreamName + "-something-100012").settings( - settings(IndexVersion.current()) - ).numberOfShards(1).numberOfReplicas(1) - ) - .build(); - // don't expect any exception when validating against (potentially backing) indices that can't create conflict because of - // additional text before number - assertDataStreams( - metadata.getProject().indices(), - (DataStreamMetadata) metadata.getProject().customs().get(DataStreamMetadata.TYPE) - ); - } - - public void testValidateDataStreamsForNullDataStreamMetadata() { - Metadata metadata = Metadata.builder() - .put(IndexMetadata.builder("foo-index").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1)) - .build(); - - try { - assertDataStreams(metadata.getProject().indices(), DataStreamMetadata.EMPTY); - } catch (Exception e) { - fail("did not expect exception when validating a system without any data streams but got " + e.getMessage()); - } - } - - public void testDataStreamAliases() { - Metadata.Builder mdBuilder = Metadata.builder(); - - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null), is(true)); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-us", null, null), is(true)); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-au", null, null), is(true)); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-au", null, null), is(false)); - - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") - ); - } - - public void 
testDataStreamReferToNonExistingDataStream() { - Metadata.Builder mdBuilder = Metadata.builder(); - - Exception e = expectThrows(IllegalArgumentException.class, () -> mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null)); - assertThat(e.getMessage(), equalTo("alias [logs-postgres] refers to a non existing data stream [logs-postgres-eu]")); - - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); - mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); - } - - public void testDeleteDataStreamShouldUpdateAlias() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); - mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); - mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); - mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") - ); - - mdBuilder = Metadata.builder(metadata); - mdBuilder.removeDataStream("logs-postgres-us"); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-au") - ); - 
- mdBuilder = Metadata.builder(metadata); - mdBuilder.removeDataStream("logs-postgres-au"); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); - - mdBuilder = Metadata.builder(metadata); - mdBuilder.removeDataStream("logs-postgres-eu"); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), nullValue()); - } - - public void testDeleteDataStreamAlias() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); - mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); - mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); - mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") - ); - - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-us", true), is(true)); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-au") - ); - - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-au", true), is(true)); - metadata = 
mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); - - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-eu", true), is(true)); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), nullValue()); - } - - public void testDeleteDataStreamAliasMustExists() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); - mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); - mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); - mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") - ); - - Metadata.Builder mdBuilder2 = Metadata.builder(metadata); - expectThrows(ResourceNotFoundException.class, () -> mdBuilder2.removeDataStreamAlias("logs-mysql", "logs-postgres-us", true)); - assertThat(mdBuilder2.removeDataStreamAlias("logs-mysql", "logs-postgres-us", false), is(false)); - } - - public void testDataStreamWriteAlias() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); - mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null); - - Metadata metadata = mdBuilder.build(); - 
assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null), is(true)); - - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), - equalTo("logs-postgres-replicated") - ); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - } - - public void testDataStreamMultipleWriteAlias() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-foobar")); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-barbaz")); - mdBuilder.put("logs", "logs-foobar", true, null); - mdBuilder.put("logs", "logs-barbaz", true, null); - - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs").getWriteDataStream(), equalTo("logs-barbaz")); - assertThat( - metadata.getProject().dataStreamAliases().get("logs").getDataStreams(), - containsInAnyOrder("logs-foobar", "logs-barbaz") - ); - } - - public void testDataStreamWriteAliasUnset() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); - mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null); - - Metadata metadata = mdBuilder.build(); - 
assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), - equalTo("logs-postgres-replicated") - ); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - - mdBuilder = Metadata.builder(metadata); - // Side check: null value isn't changing anything: - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(false)); - // Unset write flag - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", false, null), is(true)); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - } - - public void testDataStreamWriteAliasChange() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); - - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") - ); - - // 
change write flag: - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", false, null), is(true)); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null), is(true)); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), - equalTo("logs-postgres-replicated") - ); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") - ); - } - - public void testDataStreamWriteRemoveAlias() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); - - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") - ); - - mdBuilder = Metadata.builder(metadata); - assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-primary", randomBoolean()), is(true)); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); - 
assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - } - - public void testDataStreamWriteRemoveDataStream() { - Metadata.Builder mdBuilder = Metadata.builder(); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); - mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); - assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); - - Metadata metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") - ); - - mdBuilder = Metadata.builder(metadata); - mdBuilder.removeDataStream("logs-postgres-primary"); - metadata = mdBuilder.build(); - assertThat(metadata.getProject().dataStreams().keySet(), contains("logs-postgres-replicated")); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres"), notNullValue()); - assertThat(metadata.getProject().dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); - assertThat( - metadata.getProject().dataStreamAliases().get("logs-postgres").getDataStreams(), - containsInAnyOrder("logs-postgres-replicated") - ); - } - - public void testReuseIndicesLookup() { - String indexName = "my-index"; - String aliasName = "my-alias"; - String dataStreamName = "logs-mysql-prod"; - String dataStreamAliasName = "logs-mysql"; - Metadata previous = Metadata.builder().build(); - - // Things that should change indices lookup - { - Metadata.Builder builder = 
Metadata.builder(previous); - IndexMetadata idx = DataStreamTestHelper.createFirstBackingIndex(dataStreamName).build(); - builder.put(idx, true); - DataStream dataStream = newInstance(dataStreamName, List.of(idx.getIndex())); - builder.put(dataStream); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - builder.put(dataStreamAliasName, dataStreamName, false, null); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - builder.put(dataStreamAliasName, dataStreamName, true, null); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - builder.put( - IndexMetadata.builder(indexName) - .settings(settings(IndexVersion.current())) - .creationDate(randomNonNegativeLong()) - .numberOfShards(1) - .numberOfReplicas(0) - ); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); - imBuilder.putAlias(AliasMetadata.builder(aliasName).build()); - builder.put(imBuilder); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - IndexMetadata.Builder imBuilder = 
IndexMetadata.builder(builder.get(indexName)); - imBuilder.putAlias(AliasMetadata.builder(aliasName).writeIndex(true).build()); - builder.put(imBuilder); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); - Settings.Builder sBuilder = Settings.builder() - .put(builder.get(indexName).getSettings()) - .put(IndexMetadata.INDEX_HIDDEN_SETTING.getKey(), true); - imBuilder.settings(sBuilder.build()); - builder.put(imBuilder); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), not(sameInstance(metadata.getProject().getIndicesLookup()))); - previous = metadata; - } - - // Things that shouldn't change indices lookup - { - Metadata.Builder builder = Metadata.builder(previous); - IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); - imBuilder.numberOfReplicas(2); - builder.put(imBuilder); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), sameInstance(metadata.getProject().getIndicesLookup())); - previous = metadata; - } - { - Metadata.Builder builder = Metadata.builder(previous); - IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); - Settings.Builder sBuilder = Settings.builder() - .put(builder.get(indexName).getSettings()) - .put(IndexSettings.DEFAULT_FIELD_SETTING.getKey(), "val"); - imBuilder.settings(sBuilder.build()); - builder.put(imBuilder); - Metadata metadata = builder.build(); - assertThat(previous.getProject().getIndicesLookup(), sameInstance(metadata.getProject().getIndicesLookup())); - previous = metadata; - } - } - - public void testAliasedIndices() { - int numAliases = randomIntBetween(32, 64); - int numIndicesPerAlias = randomIntBetween(8, 16); - 
- Metadata.Builder builder = Metadata.builder(); - for (int i = 0; i < numAliases; i++) { - String aliasName = "alias-" + i; - for (int j = 0; j < numIndicesPerAlias; j++) { - AliasMetadata.Builder alias = new AliasMetadata.Builder(aliasName); - if (j == 0) { - alias.writeIndex(true); - } - - String indexName = aliasName + "-" + j; - builder.put( - IndexMetadata.builder(indexName) - .settings(settings(IndexVersion.current())) - .creationDate(randomNonNegativeLong()) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(alias) - ); - } - } - - Metadata metadata = builder.build(); - for (int i = 0; i < numAliases; i++) { - String aliasName = "alias-" + i; - Set result = metadata.getProject().aliasedIndices(aliasName); - Index[] expected = IntStream.range(0, numIndicesPerAlias) - .mapToObj(j -> aliasName + "-" + j) - .map(name -> new Index(name, ClusterState.UNKNOWN_UUID)) - .toArray(Index[]::new); - assertThat(result, containsInAnyOrder(expected)); - } - - // Add a new alias and index - builder = Metadata.builder(metadata); - String newAliasName = "alias-new"; - { - builder.put( - IndexMetadata.builder(newAliasName + "-1") - .settings(settings(IndexVersion.current())) - .creationDate(randomNonNegativeLong()) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(new AliasMetadata.Builder(newAliasName).writeIndex(true)) - ); - } - metadata = builder.build(); - assertThat(metadata.getProject().aliasedIndices(), hasSize(numAliases + 1)); - assertThat(metadata.getProject().aliasedIndices(newAliasName), contains(new Index(newAliasName + "-1", ClusterState.UNKNOWN_UUID))); - - // Remove the new alias/index - builder = Metadata.builder(metadata); - { - builder.remove(newAliasName + "-1"); - } - metadata = builder.build(); - assertThat(metadata.getProject().aliasedIndices(), hasSize(numAliases)); - assertThat(metadata.getProject().aliasedIndices(newAliasName), empty()); - - // Add a new alias that points to existing indices - builder = Metadata.builder(metadata); - { - 
IndexMetadata.Builder imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-1-0")); - imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); - builder.put(imBuilder); - - imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-2-1")); - imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); - builder.put(imBuilder); - - imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-3-2")); - imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); - builder.put(imBuilder); - } - metadata = builder.build(); - assertThat(metadata.getProject().aliasedIndices(), hasSize(numAliases + 1)); - assertThat( - metadata.getProject().aliasedIndices(newAliasName), - containsInAnyOrder( - new Index("alias-1-0", ClusterState.UNKNOWN_UUID), - new Index("alias-2-1", ClusterState.UNKNOWN_UUID), - new Index("alias-3-2", ClusterState.UNKNOWN_UUID) - ) - ); - - // Remove the new alias that points to existing indices - builder = Metadata.builder(metadata); - { - IndexMetadata.Builder imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-1-0")); - imBuilder.removeAlias(newAliasName); - builder.put(imBuilder); - - imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-2-1")); - imBuilder.removeAlias(newAliasName); - builder.put(imBuilder); - - imBuilder = new IndexMetadata.Builder(metadata.getProject().index("alias-3-2")); - imBuilder.removeAlias(newAliasName); - builder.put(imBuilder); - } - metadata = builder.build(); - assertThat(metadata.getProject().aliasedIndices(), hasSize(numAliases)); - assertThat(metadata.getProject().aliasedIndices(newAliasName), empty()); - } - - public static final String SYSTEM_ALIAS_NAME = "system_alias"; - - public void testHiddenAliasValidation() { - final String hiddenAliasName = "hidden_alias"; - - IndexMetadata hidden1 = buildIndexWithAlias("hidden1", hiddenAliasName, true, IndexVersion.current(), false); - IndexMetadata hidden2 = 
buildIndexWithAlias("hidden2", hiddenAliasName, true, IndexVersion.current(), false); - IndexMetadata hidden3 = buildIndexWithAlias("hidden3", hiddenAliasName, true, IndexVersion.current(), false); - - IndexMetadata nonHidden = buildIndexWithAlias("nonhidden1", hiddenAliasName, false, IndexVersion.current(), false); - IndexMetadata unspecified = buildIndexWithAlias("nonhidden2", hiddenAliasName, null, IndexVersion.current(), false); - - { - // Should be ok: - metadataWithIndices(hidden1, hidden2, hidden3); - } - - { - // Should be ok: - if (randomBoolean()) { - metadataWithIndices(nonHidden, unspecified); - } else { - metadataWithIndices(unspecified, nonHidden); + "pending_generation": 42 + } + }, + "reserved_state":{ } + } } - } - - { - IllegalStateException exception = expectThrows( - IllegalStateException.class, - () -> metadataWithIndices(hidden1, hidden2, hidden3, nonHidden) - ); - assertThat(exception.getMessage(), containsString("alias [" + hiddenAliasName + "] has is_hidden set to true on indices [")); - assertThat( - exception.getMessage(), - allOf( - containsString(hidden1.getIndex().getName()), - containsString(hidden2.getIndex().getName()), - containsString(hidden3.getIndex().getName()) - ) - ); - assertThat( - exception.getMessage(), - containsString( - "but does not have is_hidden set to true on indices [" - + nonHidden.getIndex().getName() - + "]; alias must have the same is_hidden setting on all indices" - ) - ); - } - - { - IllegalStateException exception = expectThrows( - IllegalStateException.class, - () -> metadataWithIndices(hidden1, hidden2, hidden3, unspecified) - ); - assertThat(exception.getMessage(), containsString("alias [" + hiddenAliasName + "] has is_hidden set to true on indices [")); - assertThat( - exception.getMessage(), - allOf( - containsString(hidden1.getIndex().getName()), - containsString(hidden2.getIndex().getName()), - containsString(hidden3.getIndex().getName()) - ) - ); - assertThat( - exception.getMessage(), - 
containsString( - "but does not have is_hidden set to true on indices [" - + unspecified.getIndex().getName() - + "]; alias must have the same is_hidden setting on all indices" - ) - ); - } - - { - final IndexMetadata hiddenIndex = randomFrom(hidden1, hidden2, hidden3); - IllegalStateException exception = expectThrows(IllegalStateException.class, () -> { - if (randomBoolean()) { - metadataWithIndices(nonHidden, unspecified, hiddenIndex); - } else { - metadataWithIndices(unspecified, nonHidden, hiddenIndex); - } - }); - assertThat( - exception.getMessage(), - containsString( - "alias [" - + hiddenAliasName - + "] has is_hidden set to true on " - + "indices [" - + hiddenIndex.getIndex().getName() - + "] but does not have is_hidden set to true on indices [" - ) - ); - assertThat( - exception.getMessage(), - allOf(containsString(unspecified.getIndex().getName()), containsString(nonHidden.getIndex().getName())) - ); - assertThat(exception.getMessage(), containsString("but does not have is_hidden set to true on indices [")); - } - } + """, IndexVersion.current(), IndexVersion.current()); - public void testSystemAliasValidationMixedVersionSystemAndRegularFails() { - final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.V_7_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) - ); - final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); - final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); + final Metadata metadata = fromJsonXContentStringWithPersistentTasks(json); + assertThat(metadata, notNullValue()); + assertThat(metadata.clusterUUID(), is("aba1aa1ababbbaabaabaab")); - IllegalStateException exception = expectThrows( - 
IllegalStateException.class, - () -> metadataWithIndices(currentVersionSystem, oldVersionSystem, regularIndex) - ); + assertThat(metadata.projects().keySet(), containsInAnyOrder(ProjectId.fromId("default"), ProjectId.fromId("another_project"))); + assertThat(metadata.customs(), not(hasKey("repositories"))); + final var repositoriesMetadata = RepositoriesMetadata.get(metadata.getProject(ProjectId.DEFAULT)); assertThat( - exception.getMessage(), - containsString( - "alias [" - + SYSTEM_ALIAS_NAME - + "] refers to both system indices [" - + currentVersionSystem.getIndex().getName() - + "] and non-system indices: [" - + regularIndex.getIndex().getName() - + "], but aliases must refer to either system or non-system indices, not both" + repositoriesMetadata.repositories(), + equalTo( + List.of( + new RepositoryMetadata("my-repo", "_my-repo-uuid_", "fs", Settings.builder().put("location", "backup").build(), 42, 42) + ) ) ); + assertThat(metadata.getProject(ProjectId.fromId("another_project")).customs(), not(hasKey("repositories"))); } - public void testSystemAliasValidationNewSystemAndRegularFails() { - final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); - - IllegalStateException exception = expectThrows( - IllegalStateException.class, - () -> metadataWithIndices(currentVersionSystem, regularIndex) - ); - assertThat( - exception.getMessage(), - containsString( - "alias [" - + SYSTEM_ALIAS_NAME - + "] refers to both system indices [" - + currentVersionSystem.getIndex().getName() - + "] and non-system indices: [" - + regularIndex.getIndex().getName() - + "], but aliases must refer to either system or non-system indices, not both" - ) - ); - } + private Metadata fromJsonXContentStringWithPersistentTasks(String json) throws IOException { + List registry = new ArrayList<>(); 
+ registry.addAll(ClusterModule.getNamedXWriteables()); + registry.addAll(IndicesModule.getNamedXContents()); + registry.addAll(HealthNodeTaskExecutor.getNamedXContentParsers()); - public void testSystemAliasOldSystemAndNewRegular() { - final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.V_7_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + final var clusterService = mock(ClusterService.class); + when(clusterService.threadPool()).thenReturn(mock(ThreadPool.class)); + final var healthNodeTaskExecutor = HealthNodeTaskExecutor.create( + clusterService, + mock(PersistentTasksService.class), + Settings.EMPTY, + ClusterSettings.createBuiltInClusterSettings() ); - final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); - final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); + new PersistentTasksExecutorRegistry(List.of(healthNodeTaskExecutor)); - // Should be ok: - metadataWithIndices(oldVersionSystem, regularIndex); + XContentParserConfiguration config = XContentParserConfiguration.EMPTY.withRegistry(new NamedXContentRegistry(registry)); + try (XContentParser parser = JsonXContent.jsonXContent.createParser(config, json)) { + return Metadata.fromXContent(parser); + } } - public void testSystemIndexValidationAllRegular() { - final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.V_7_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + public void testGlobalStateEqualsCoordinationMetadata() { + CoordinationMetadata coordinationMetadata1 = new CoordinationMetadata( + randomNonNegativeLong(), + randomVotingConfig(), + randomVotingConfig(), + randomVotingConfigExclusions() + ); + Metadata metadata1 = Metadata.builder().coordinationMetadata(coordinationMetadata1).build(); + CoordinationMetadata 
coordinationMetadata2 = new CoordinationMetadata( + randomNonNegativeLong(), + randomVotingConfig(), + randomVotingConfig(), + randomVotingConfigExclusions() ); - final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata currentVersionSystem2 = buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + Metadata metadata2 = Metadata.builder().coordinationMetadata(coordinationMetadata2).build(); - // Should be ok - metadataWithIndices(currentVersionSystem, currentVersionSystem2, oldVersionSystem); + assertTrue(Metadata.isGlobalStateEquals(metadata1, metadata1)); + assertFalse(Metadata.isGlobalStateEquals(metadata1, metadata2)); } - public void testSystemAliasValidationAllSystemSomeOld() { - final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.V_7_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + public void testSerializationWithIndexGraveyard() throws IOException { + final var projectId = randomProjectIdOrDefault(); + final IndexGraveyard graveyard = IndexGraveyardTests.createRandom(); + final Metadata originalMeta = Metadata.builder().put(ProjectMetadata.builder(projectId).indexGraveyard(graveyard)).build(); + final BytesStreamOutput out = new BytesStreamOutput(); + originalMeta.writeTo(out); + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); + final Metadata fromStreamMeta = Metadata.readFrom( + new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry) ); - final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata currentVersionSystem2 = 
buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + assertThat(fromStreamMeta.getProject(projectId).indexGraveyard(), equalTo(originalMeta.getProject(projectId).indexGraveyard())); + } - // Should be ok: - metadataWithIndices(currentVersionSystem, currentVersionSystem2, oldVersionSystem); + private static IndexMetadata.Builder buildIndexMetadata(String name, String alias, Boolean writeIndex) { + return IndexMetadata.builder(name) + .settings(settings(IndexVersion.current())) + .creationDate(randomNonNegativeLong()) + .putAlias(AliasMetadata.builder(alias).writeIndex(writeIndex)) + .numberOfShards(1) + .numberOfReplicas(0); } - public void testSystemAliasValidationAll8x() { - final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); - final IndexMetadata currentVersionSystem2 = buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + public void testTransientSettingsOverridePersistentSettings() { + final Setting setting = Setting.simpleString("key"); + final Metadata metadata = Metadata.builder() + .persistentSettings(Settings.builder().put(setting.getKey(), "persistent-value").build()) + .transientSettings(Settings.builder().put(setting.getKey(), "transient-value").build()) + .build(); + assertThat(setting.get(metadata.settings()), equalTo("transient-value")); + } - // Should be ok - metadataWithIndices(currentVersionSystem, currentVersionSystem2); + public void testBuilderRejectsNullCustom() { + final Metadata.Builder builder = Metadata.builder(); + final String key = randomAlphaOfLength(10); + assertThat( + expectThrows(NullPointerException.class, () -> builder.putCustom(key, (Metadata.ClusterCustom) null)).getMessage(), + containsString(key) + ); } - private void 
metadataWithIndices(IndexMetadata... indices) { - Metadata.Builder builder = Metadata.builder(); - for (var cursor : indices) { - builder.put(cursor, false); + public void testBuilderRejectsNullInCustoms() { + final Metadata.Builder builder = Metadata.builder(); + final String key = randomAlphaOfLength(10); + { + final Map map = new HashMap<>(); + map.put(key, null); + assertThat(expectThrows(NullPointerException.class, () -> builder.customs(map)).getMessage(), containsString(key)); } - builder.build(); } - private IndexMetadata buildIndexWithAlias( - String indexName, - String aliasName, - @Nullable Boolean aliasIsHidden, - IndexVersion indexCreationVersion, - boolean isSystem - ) { - final AliasMetadata.Builder aliasMetadata = new AliasMetadata.Builder(aliasName); - if (aliasIsHidden != null || randomBoolean()) { - aliasMetadata.isHidden(aliasIsHidden); - } - return new IndexMetadata.Builder(indexName).settings(settings(indexCreationVersion)) - .system(isSystem) - .numberOfShards(1) - .numberOfReplicas(0) - .putAlias(aliasMetadata) - .build(); + public void testCopyAndUpdate() throws IOException { + var metadata = Metadata.builder().clusterUUID(UUIDs.base64UUID()).build(); + var newClusterUuid = UUIDs.base64UUID(); + + var copy = metadata.copyAndUpdate(builder -> builder.clusterUUID(newClusterUuid)); + + assertThat(copy, not(sameInstance(metadata))); + assertThat(copy.clusterUUID(), equalTo(newClusterUuid)); } - public void testMappingDuplication() { - final Set randomMappingDefinitions; - { - int numEntries = randomIntBetween(4, 8); - randomMappingDefinitions = Sets.newHashSetWithExpectedSize(numEntries); - for (int i = 0; i < numEntries; i++) { - Map mapping = RandomAliasActionsGenerator.randomMap(2); - String mappingAsString = Strings.toString((builder, params) -> builder.mapContents(mapping)); - randomMappingDefinitions.add(mappingAsString); - } - } + public void testBuilderRemoveClusterCustomIf() { + var custom1 = new TestClusterCustomMetadata(); + var 
custom2 = new TestClusterCustomMetadata(); + var builder = Metadata.builder(); + builder.putCustom("custom1", custom1); + builder.putCustom("custom2", custom2); - Metadata metadata; - int numIndices = randomIntBetween(16, 32); - { - String[] definitions = randomMappingDefinitions.toArray(String[]::new); - Metadata.Builder mb = new Metadata.Builder(); - for (int i = 0; i < numIndices; i++) { - IndexMetadata.Builder indexBuilder = IndexMetadata.builder("index-" + i) - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping(definitions[i % randomMappingDefinitions.size()]) - .numberOfShards(1) - .numberOfReplicas(0); - if (randomBoolean()) { - mb.put(indexBuilder); - } else { - mb.put(indexBuilder.build(), true); - } - } - metadata = mb.build(); - } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); - assertThat( - metadata.getProject().indices().values().stream().map(IndexMetadata::mapping).collect(Collectors.toSet()), - hasSize(metadata.getProject().getMappingsByHash().size()) - ); + builder.removeCustomIf((key, value) -> Objects.equals(key, "custom1")); - // Add a new index with a new index with known mapping: - MappingMetadata mapping = metadata.getProject().indices().get("index-" + randomInt(numIndices - 1)).mapping(); - MappingMetadata entry = metadata.getProject().getMappingsByHash().get(mapping.getSha256()); - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.put( - IndexMetadata.builder("index-" + numIndices) - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping(mapping) - .numberOfShards(1) - .numberOfReplicas(0) - ); - metadata = mb.build(); - } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); - assertThat(metadata.getProject().getMappingsByHash().get(mapping.getSha256()), equalTo(entry)); + var metadata = builder.build(); + 
assertThat(metadata.custom("custom1"), nullValue()); + assertThat(metadata.custom("custom2"), sameInstance(custom2)); + } - // Remove index and ensure mapping cache stays the same - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.remove("index-" + numIndices); - metadata = mb.build(); - } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); - assertThat(metadata.getProject().getMappingsByHash().get(mapping.getSha256()), equalTo(entry)); + public void testSerialization() throws IOException { + final Metadata orig = randomMetadata(); + final BytesStreamOutput out = new BytesStreamOutput(); + orig.writeTo(out); + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); + final Metadata fromStreamMeta = Metadata.readFrom( + new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry) + ); + assertTrue(Metadata.isGlobalStateEquals(orig, fromStreamMeta)); + } - // Update a mapping of an index: - IndexMetadata luckyIndex = metadata.getProject().index("index-" + randomInt(numIndices - 1)); - entry = metadata.getProject().getMappingsByHash().get(luckyIndex.mapping().getSha256()); - MappingMetadata updatedMapping = new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, Map.of("mapping", "updated")); - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.put(IndexMetadata.builder(luckyIndex).putMapping(updatedMapping)); - metadata = mb.build(); - } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size() + 1)); - assertThat(metadata.getProject().getMappingsByHash().get(luckyIndex.mapping().getSha256()), equalTo(entry)); - assertThat(metadata.getProject().getMappingsByHash().get(updatedMapping.getSha256()), equalTo(updatedMapping)); + public void testMultiProjectSerialization() throws IOException { + ProjectMetadata project1 = randomProject(randomProjectIdOrDefault(), 1); + 
ProjectMetadata project2 = randomProject(randomUniqueProjectId(), randomIntBetween(2, 10)); + Metadata metadata = randomMetadata(List.of(project1, project2)); + BytesStreamOutput out = new BytesStreamOutput(); + metadata.writeTo(out); - // Remove the index with updated mapping - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.remove(luckyIndex.getIndex().getName()); - metadata = mb.build(); - } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); - assertThat(metadata.getProject().getMappingsByHash().get(updatedMapping.getSha256()), nullValue()); + // check it deserializes ok + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); + Metadata fromStreamMeta = Metadata.readFrom(new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry)); - // Add an index with new mapping and then later remove it: - MappingMetadata newMapping = new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, Map.of("new", "mapping")); - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.put( - IndexMetadata.builder("index-" + numIndices) - .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping(newMapping) - .numberOfShards(1) - .numberOfReplicas(0) - ); - metadata = mb.build(); + // check it matches the original object + assertThat(fromStreamMeta.projects(), aMapWithSize(2)); + for (var original : List.of(project1, project2)) { + assertThat(fromStreamMeta.projects(), hasKey(original.id())); + final ProjectMetadata fromStreamProject = fromStreamMeta.getProject(original.id()); + assertThat("For project " + original.id(), fromStreamProject.indices().size(), equalTo(original.indices().size())); + assertThat("For project " + original.id(), fromStreamProject.dataStreams().size(), equalTo(original.dataStreams().size())); + assertThat("For project " + original.id(), 
fromStreamProject.templates().size(), equalTo(original.templates().size())); + assertThat("For project " + original.id(), fromStreamProject.templatesV2().size(), equalTo(original.templatesV2().size())); + original.indices().forEach((name, value) -> { + assertThat(fromStreamProject.indices(), hasKey(name)); + assertThat(fromStreamProject.index(name), equalTo(value)); + }); + original.dataStreams().forEach((name, value) -> { + assertThat(fromStreamProject.dataStreams(), hasKey(name)); + assertThat(fromStreamProject.dataStreams().get(name), equalTo(value)); + }); } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size() + 1)); - assertThat(metadata.getProject().getMappingsByHash().get(newMapping.getSha256()), equalTo(newMapping)); + } - { - Metadata.Builder mb = new Metadata.Builder(metadata); - mb.remove("index-" + numIndices); - metadata = mb.build(); + public void testUnableToSerializeNonDefaultProjectBeforeMultiProject() { + final var projectId = randomUniqueProjectId(); + Metadata metadata = Metadata.builder().put(ProjectMetadata.builder(projectId)).build(); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setTransportVersion(TransportVersionUtils.getPreviousVersion(TransportVersions.MULTI_PROJECT)); + var e = assertThrows(UnsupportedOperationException.class, () -> metadata.writeTo(output)); + assertEquals("There is 1 project, but it has id [" + projectId + "] rather than default", e.getMessage()); } - assertThat(metadata.getProject().getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); - assertThat(metadata.getProject().getMappingsByHash().get(newMapping.getSha256()), nullValue()); } - public void testWithLifecycleState() { - String indexName = "my-index"; - String indexUUID = randomAlphaOfLength(10); - Metadata metadata1 = Metadata.builder(randomMetadata()) - .put( - IndexMetadata.builder(indexName) - 
.settings(settings(IndexVersion.current()).put(IndexMetadata.SETTING_INDEX_UUID, indexUUID)) - .creationDate(randomNonNegativeLong()) - .numberOfShards(1) - .numberOfReplicas(0) - ) - .build(); - IndexMetadata index1 = metadata1.getProject().index(indexName); - assertThat(metadata1.getProject().getIndicesLookup(), notNullValue()); - assertThat(index1.getLifecycleExecutionState(), sameInstance(LifecycleExecutionState.EMPTY_STATE)); - - LifecycleExecutionState state = LifecycleExecutionState.builder().setPhase("phase").setAction("action").setStep("step").build(); - Metadata metadata2 = metadata1.withLifecycleState(index1.getIndex(), state); - IndexMetadata index2 = metadata2.getProject().index(indexName); - - // the indices lookups are the same object - assertThat(metadata2.getProject().getIndicesLookup(), sameInstance(metadata1.getProject().getIndicesLookup())); - - // the lifecycle state and version were changed - assertThat(index2.getLifecycleExecutionState().asMap(), is(state.asMap())); - assertThat(index2.getVersion(), is(index1.getVersion() + 1)); - - // but those are the only differences between the two - IndexMetadata.Builder builder = IndexMetadata.builder(index2); - builder.version(builder.version() - 1); - builder.removeCustom(LifecycleExecutionState.ILM_CUSTOM_METADATA_KEY); - assertThat(index1, equalTo(builder.build())); - - // withLifecycleState returns the same reference if nothing changed - Metadata metadata3 = metadata2.withLifecycleState(index2.getIndex(), state); - assertThat(metadata3, sameInstance(metadata2)); - - // withLifecycleState rejects a nonsense Index - String randomUUID = randomValueOtherThan(indexUUID, () -> randomAlphaOfLength(10)); - expectThrows(IndexNotFoundException.class, () -> metadata1.withLifecycleState(new Index(indexName, randomUUID), state)); + public void testGetNonExistingProjectThrows() { + final List projects = IntStream.range(0, between(1, 3)) + .mapToObj(i -> randomProject(ProjectId.fromId("p_" + i), between(0, 5))) + 
.toList(); + final Metadata metadata = randomMetadata(projects); + expectThrows(IllegalArgumentException.class, () -> metadata.getProject(randomProjectIdOrDefault())); } public void testEmptyDiffReturnsSameInstance() throws IOException { @@ -3049,90 +1053,6 @@ public void testEnsureMetadataFieldCheckedForGlobalStateChanges() { assertThat(unclassifiedFields, empty()); } - public void testRetrieveIndexModeFromTemplateTsdb() throws IOException { - // tsdb: - var tsdbTemplate = new Template(Settings.builder().put("index.mode", "time_series").build(), new CompressedXContent("{}"), null); - // Settings in component template: - { - var componentTemplate = new ComponentTemplate(tsdbTemplate, null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.TIME_SERIES)); - } - // Settings in composable index template: - { - var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .template(tsdbTemplate) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.TIME_SERIES)); - } - } - - public void testRetrieveIndexModeFromTemplateLogsdb() throws IOException { - // logsdb: - var logsdbTemplate = new Template(Settings.builder().put("index.mode", "logsdb").build(), new 
CompressedXContent("{}"), null); - // Settings in component template: - { - var componentTemplate = new ComponentTemplate(logsdbTemplate, null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.LOGSDB)); - } - // Settings in composable index template: - { - var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .template(logsdbTemplate) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.LOGSDB)); - } - } - - public void testRetrieveIndexModeFromTemplateEmpty() throws IOException { - // no index mode: - var emptyTemplate = new Template(Settings.EMPTY, new CompressedXContent("{}"), null); - // Settings in component template: - { - var componentTemplate = new ComponentTemplate(emptyTemplate, null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), nullValue()); - 
} - // Settings in composable index template: - { - var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); - var indexTemplate = ComposableIndexTemplate.builder() - .indexPatterns(List.of("test-*")) - .template(emptyTemplate) - .componentTemplates(List.of("component_template_1")) - .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) - .build(); - Metadata m = Metadata.builder().put("component_template_1", componentTemplate).put("index_template_1", indexTemplate).build(); - assertThat(m.getProject().retrieveIndexModeFromTemplate(indexTemplate), nullValue()); - } - } - public void testGetSingleProjectWithCustom() { var type = IngestMetadata.TYPE; { @@ -3312,56 +1232,11 @@ public static ProjectMetadata randomProject(ProjectId id, int numDataStreams) { return project.build(); } - private static CreateIndexResult createIndices(int numIndices, int numBackingIndices, String dataStreamName) { - // create some indices that do not back a data stream - final List indices = new ArrayList<>(); - int lastIndexNum = randomIntBetween(9, 50); - Metadata.Builder b = Metadata.builder(); - for (int k = 1; k <= numIndices; k++) { - IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName("index", lastIndexNum)) - .settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - .build(); - b.put(im, false); - indices.add(im.getIndex()); - lastIndexNum = randomIntBetween(lastIndexNum + 1, lastIndexNum + 50); - } - - // create some backing indices for a data stream - final List backingIndices = new ArrayList<>(); - int lastBackingIndexNum = 0; - for (int k = 1; k <= numBackingIndices; k++) { - lastBackingIndexNum = randomIntBetween(lastBackingIndexNum + 1, lastBackingIndexNum + 50); - IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, lastBackingIndexNum)) - .settings(settings(IndexVersion.current())) - .numberOfShards(1) - .numberOfReplicas(1) - 
.build(); - b.put(im, false); - backingIndices.add(im.getIndex()); - } - b.put(newInstance(dataStreamName, backingIndices, lastBackingIndexNum, null)); - return new CreateIndexResult(indices, backingIndices, b.build()); - } - private static ToXContent.Params formatParams() { return new ToXContent.MapParams(Map.of("binary", "true", Metadata.CONTEXT_MODE_PARAM, Metadata.CONTEXT_MODE_GATEWAY)); } - private static class CreateIndexResult { - final List indices; - final List backingIndices; - final Metadata metadata; - - CreateIndexResult(List indices, List backingIndices, Metadata metadata) { - this.indices = indices; - this.backingIndices = backingIndices; - this.metadata = metadata; - } - } - - private abstract static class AbstractCustomMetadata> implements Metadata.MetadataCustom { + private static class TestClusterCustomMetadata implements Metadata.ClusterCustom { @Override public Iterator toXContentChunked(ToXContent.Params params) { @@ -3369,7 +1244,7 @@ public Iterator toXContentChunked(ToXContent.Params params } @Override - public Diff diff(C previousState) { + public Diff diff(Metadata.ClusterCustom previousState) { return null; } @@ -3393,12 +1268,4 @@ public void writeTo(StreamOutput out) throws IOException { } } - - private static class TestClusterCustomMetadata extends AbstractCustomMetadata - implements - Metadata.ClusterCustom {} - - private static class TestProjectCustomMetadata extends AbstractCustomMetadata - implements - Metadata.ProjectCustom {} } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/ProjectMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/ProjectMetadataTests.java index c39b5caeebce1..bd63b6e70371d 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/ProjectMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/ProjectMetadataTests.java @@ -9,39 +9,2226 @@ package org.elasticsearch.cluster.metadata; +import 
org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.Diff; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Predicates; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.alias.RandomAliasActionsGenerator; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.plugins.FieldPredicate; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import 
java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.createBackingIndex; +import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.createFirstBackingIndex; +import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.newInstance; import static org.elasticsearch.cluster.metadata.MetadataTests.checkChunkSize; import static org.elasticsearch.cluster.metadata.MetadataTests.count; +import static org.elasticsearch.cluster.metadata.ProjectMetadata.Builder.assertDataStreams; +import static org.elasticsearch.test.LambdaMatchers.transformedItemsMatch; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; +import static org.hamcrest.Matchers.startsWith; public class ProjectMetadataTests extends ESTestCase { - private static final Setting PROJECT_SETTING = Setting.intSetting( - "project.setting.value", - 0, - Setting.Property.Dynamic, - 
Setting.Property.NodeScope, - Setting.Property.ProjectScope - ); + public void testFindAliases() { + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put( + IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias1").build()) + .putAlias(AliasMetadata.builder("alias2").build()) + ) + .put( + IndexMetadata.builder("index2") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias2").build()) + .putAlias(AliasMetadata.builder("alias3").build()) + ) + .build(); + + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT); + Map> aliases = project.findAliases(request.aliases(), Strings.EMPTY_ARRAY); + assertThat(aliases, anEmptyMap()); + } + { + final GetAliasesRequest request; + if (randomBoolean()) { + request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT); + } else { + request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, randomFrom("alias1", "alias2")); + // replacing with empty aliases behaves as if aliases were unspecified at request building + request.replaceAliases(Strings.EMPTY_ARRAY); + } + Map> aliases = project.findAliases(request.aliases(), new String[] { "index" }); + assertThat(aliases, aMapWithSize(1)); + List aliasMetadataList = aliases.get("index"); + assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1", "alias2"))); + } + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, "alias*"); + Map> aliases = project.findAliases(request.aliases(), new String[] { "index", "index2" }); + assertThat(aliases, aMapWithSize(2)); + List indexAliasMetadataList = aliases.get("index"); + assertThat(indexAliasMetadataList, transformedItemsMatch(AliasMetadata::alias, 
contains("alias1", "alias2"))); + List index2AliasMetadataList = aliases.get("index2"); + assertThat(index2AliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias2", "alias3"))); + } + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT, "alias1"); + Map> aliases = project.findAliases(request.aliases(), new String[] { "index" }); + assertThat(aliases, aMapWithSize(1)); + List aliasMetadataList = aliases.get("index"); + assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1"))); + } + { + Map> aliases = project.findAllAliases(new String[] { "index" }); + assertThat(aliases, aMapWithSize(1)); + List aliasMetadataList = aliases.get("index"); + assertThat(aliasMetadataList, transformedItemsMatch(AliasMetadata::alias, contains("alias1", "alias2"))); + } + { + Map> aliases = project.findAllAliases(Strings.EMPTY_ARRAY); + assertThat(aliases, anEmptyMap()); + } + } + + public void testFindDataStreamAliases() { + ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + + addDataStream("d1", builder); + addDataStream("d2", builder); + addDataStream("d3", builder); + addDataStream("d4", builder); + + builder.put("alias1", "d1", null, null); + builder.put("alias2", "d2", null, null); + builder.put("alias2-part2", "d2", null, null); + + ProjectMetadata project = builder.build(); + + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT); + Map> aliases = project.findDataStreamAliases(request.aliases(), Strings.EMPTY_ARRAY); + assertThat(aliases, anEmptyMap()); + } + + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("alias1"); + Map> aliases = project.findDataStreamAliases(request.aliases(), new String[] { "index" }); + assertThat(aliases, anEmptyMap()); + } + + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("alias1"); + Map> aliases = project.findDataStreamAliases( + 
request.aliases(), + new String[] { "index", "d1", "d2" } + ); + assertEquals(1, aliases.size()); + List found = aliases.get("d1"); + assertThat(found, transformedItemsMatch(DataStreamAlias::getAlias, contains("alias1"))); + } + + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("ali*"); + Map> aliases = project.findDataStreamAliases(request.aliases(), new String[] { "index", "d2" }); + assertEquals(1, aliases.size()); + List found = aliases.get("d2"); + assertThat(found, transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); + } + + // test exclusion + { + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("*"); + Map> aliases = project.findDataStreamAliases( + request.aliases(), + new String[] { "index", "d1", "d2", "d3", "d4" } + ); + assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); + assertThat(aliases.get("d1"), transformedItemsMatch(DataStreamAlias::getAlias, contains("alias1"))); + + request.aliases("*", "-alias1"); + aliases = project.findDataStreamAliases(request.aliases(), new String[] { "index", "d1", "d2", "d3", "d4" }); + assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); + assertNull(aliases.get("d1")); + } + } + + public void testDataStreamAliasesByDataStream() { + ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + + addDataStream("d1", builder); + addDataStream("d2", builder); + addDataStream("d3", builder); + addDataStream("d4", builder); + + builder.put("alias1", "d1", null, null); + builder.put("alias2", "d2", null, null); + builder.put("alias2-part2", "d2", null, null); + + ProjectMetadata project = builder.build(); + + var aliases = project.dataStreamAliasesByDataStream(); + + assertTrue(aliases.containsKey("d1")); + 
assertTrue(aliases.containsKey("d2")); + assertFalse(aliases.containsKey("d3")); + assertFalse(aliases.containsKey("d4")); + + assertEquals(1, aliases.get("d1").size()); + assertEquals(2, aliases.get("d2").size()); + + assertThat(aliases.get("d2"), transformedItemsMatch(DataStreamAlias::getAlias, containsInAnyOrder("alias2", "alias2-part2"))); + } + + public void testFindAliasWithExclusion() { + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put( + IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias1").build()) + .putAlias(AliasMetadata.builder("alias2").build()) + ) + .put( + IndexMetadata.builder("index2") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias1").build()) + .putAlias(AliasMetadata.builder("alias3").build()) + ) + .build(); + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("*", "-alias1"); + Map> aliases = project.findAliases(request.aliases(), new String[] { "index", "index2" }); + assertThat(aliases.get("index"), transformedItemsMatch(AliasMetadata::alias, contains("alias2"))); + assertThat(aliases.get("index2"), transformedItemsMatch(AliasMetadata::alias, contains("alias3"))); + } + + public void testFindDataStreams() { + final int numIndices = randomIntBetween(2, 5); + final int numBackingIndices = randomIntBetween(2, 5); + final String dataStreamName = "my-data-stream"; + CreateIndexResult result = createIndices(numIndices, numBackingIndices, dataStreamName); + + List allIndices = new ArrayList<>(result.indices); + allIndices.addAll(result.backingIndices); + String[] concreteIndices = allIndices.stream().map(Index::getName).toArray(String[]::new); + Map dataStreams = 
result.project.findDataStreams(concreteIndices); + assertThat(dataStreams, aMapWithSize(numBackingIndices)); + for (Index backingIndex : result.backingIndices) { + assertThat(dataStreams, hasKey(backingIndex.getName())); + assertThat(dataStreams.get(backingIndex.getName()).getName(), equalTo(dataStreamName)); + } + } + + public void testFindAliasWithExclusionAndOverride() { + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put( + IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("aa").build()) + .putAlias(AliasMetadata.builder("ab").build()) + .putAlias(AliasMetadata.builder("bb").build()) + ) + .build(); + GetAliasesRequest request = new GetAliasesRequest(TEST_REQUEST_TIMEOUT).aliases("a*", "-*b", "b*"); + List aliases = project.findAliases(request.aliases(), new String[] { "index" }).get("index"); + assertThat(aliases, transformedItemsMatch(AliasMetadata::alias, contains("aa", "bb"))); + } + + public void testAliasCollidingWithAnExistingIndex() { + int indexCount = randomIntBetween(10, 100); + Set indices = Sets.newHashSetWithExpectedSize(indexCount); + for (int i = 0; i < indexCount; i++) { + indices.add(randomAlphaOfLength(10)); + } + Map> aliasToIndices = new HashMap<>(); + for (String alias : randomSubsetOf(randomIntBetween(1, 10), indices)) { + Set indicesInAlias; + do { + indicesInAlias = new HashSet<>(randomSubsetOf(randomIntBetween(1, 3), indices)); + indicesInAlias.remove(alias); + } while (indicesInAlias.isEmpty()); + aliasToIndices.put(alias, indicesInAlias); + } + int properAliases = randomIntBetween(0, 3); + for (int i = 0; i < properAliases; i++) { + aliasToIndices.put(randomAlphaOfLength(5), new HashSet<>(randomSubsetOf(randomIntBetween(1, 3), indices))); + } + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + for 
(String index : indices) { + IndexMetadata.Builder indexBuilder = IndexMetadata.builder(index) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0); + aliasToIndices.forEach((key, value) -> { + if (value.contains(index)) { + indexBuilder.putAlias(AliasMetadata.builder(key).build()); + } + }); + projectBuilder.put(indexBuilder); + } + + Exception e = expectThrows(IllegalStateException.class, projectBuilder::build); + assertThat(e.getMessage(), startsWith("index, alias, and data stream names need to be unique")); + } + + public void testValidateAliasWriteOnly() { + String alias = randomAlphaOfLength(5); + String indexA = randomAlphaOfLength(6); + String indexB = randomAlphaOfLength(7); + Boolean aWriteIndex = randomBoolean() ? null : randomBoolean(); + Boolean bWriteIndex; + if (Boolean.TRUE.equals(aWriteIndex)) { + bWriteIndex = randomFrom(Boolean.FALSE, null); + } else { + bWriteIndex = randomFrom(Boolean.TRUE, Boolean.FALSE, null); + } + // when only one index/alias pair exist + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(buildIndexMetadata(indexA, alias, aWriteIndex)) + .build(); + + // when alias points to two indices, but valid + // one of the following combinations: [(null, null), (null, true), (null, false), (false, false)] + ProjectMetadata.builder(project).put(buildIndexMetadata(indexB, alias, bWriteIndex)).build(); + + // when too many write indices + Exception exception = expectThrows(IllegalStateException.class, () -> { + IndexMetadata.Builder metaA = buildIndexMetadata(indexA, alias, true); + IndexMetadata.Builder metaB = buildIndexMetadata(indexB, alias, true); + ProjectMetadata.builder(randomProjectIdOrDefault()).put(metaA).put(metaB).build(); + }); + assertThat(exception.getMessage(), startsWith("alias [" + alias + "] has more than one write index [")); + } + + public void testValidateHiddenAliasConsistency() { + String 
alias = randomAlphaOfLength(5); + String indexA = randomAlphaOfLength(6); + String indexB = randomAlphaOfLength(7); + + { + Exception ex = expectThrows( + IllegalStateException.class, + () -> buildMetadataWithHiddenIndexMix(alias, indexA, true, indexB, randomFrom(false, null)).build() + ); + assertThat(ex.getMessage(), containsString("has is_hidden set to true on indices")); + } + + { + Exception ex = expectThrows( + IllegalStateException.class, + () -> buildMetadataWithHiddenIndexMix(alias, indexA, randomFrom(false, null), indexB, true).build() + ); + assertThat(ex.getMessage(), containsString("has is_hidden set to true on indices")); + } + } + + private ProjectMetadata.Builder buildMetadataWithHiddenIndexMix( + String aliasName, + String indexAName, + Boolean indexAHidden, + String indexBName, + Boolean indexBHidden + ) { + IndexMetadata.Builder indexAMeta = IndexMetadata.builder(indexAName) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder(aliasName).isHidden(indexAHidden).build()); + IndexMetadata.Builder indexBMeta = IndexMetadata.builder(indexBName) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder(aliasName).isHidden(indexBHidden).build()); + return ProjectMetadata.builder(randomProjectIdOrDefault()).put(indexAMeta).put(indexBMeta); + } + + public void testResolveIndexRouting() { + IndexMetadata.Builder builder = IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias0").build()) + .putAlias(AliasMetadata.builder("alias1").routing("1").build()) + .putAlias(AliasMetadata.builder("alias2").routing("1,2").build()); + ProjectMetadata project = 
ProjectMetadata.builder(randomProjectIdOrDefault()).put(builder).build(); + + // no alias, no index + assertNull(project.resolveIndexRouting(null, null)); + assertEquals(project.resolveIndexRouting("0", null), "0"); + + // index, no alias + assertNull(project.resolveIndexRouting(null, "index")); + assertEquals(project.resolveIndexRouting("0", "index"), "0"); + + // alias with no index routing + assertNull(project.resolveIndexRouting(null, "alias0")); + assertEquals(project.resolveIndexRouting("0", "alias0"), "0"); + + // alias with index routing. + assertEquals(project.resolveIndexRouting(null, "alias1"), "1"); + Exception ex = expectThrows(IllegalArgumentException.class, () -> project.resolveIndexRouting("0", "alias1")); + assertThat( + ex.getMessage(), + is("Alias [alias1] has index routing associated with it [1], and was provided with routing value [0], rejecting operation") + ); + + // alias with invalid index routing. + ex = expectThrows(IllegalArgumentException.class, () -> project.resolveIndexRouting(null, "alias2")); + assertThat( + ex.getMessage(), + is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") + ); + + ex = expectThrows(IllegalArgumentException.class, () -> project.resolveIndexRouting("1", "alias2")); + assertThat( + ex.getMessage(), + is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") + ); + + IndexMetadata.Builder builder2 = IndexMetadata.builder("index2") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(AliasMetadata.builder("alias0").build()); + ProjectMetadata projectTwoIndices = ProjectMetadata.builder(project).put(builder2).build(); + + // alias with multiple indices + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> 
projectTwoIndices.resolveIndexRouting("1", "alias0") + ); + assertThat(exception.getMessage(), startsWith("Alias [alias0] has more than one index associated with it")); + } + + public void testResolveWriteIndexRouting() { + AliasMetadata.Builder aliasZeroBuilder = AliasMetadata.builder("alias0"); + if (randomBoolean()) { + aliasZeroBuilder.writeIndex(true); + } + IndexMetadata.Builder builder = IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(aliasZeroBuilder.build()) + .putAlias(AliasMetadata.builder("alias1").routing("1").build()) + .putAlias(AliasMetadata.builder("alias2").routing("1,2").build()) + .putAlias(AliasMetadata.builder("alias3").writeIndex(false).build()) + .putAlias(AliasMetadata.builder("alias4").routing("1,2").writeIndex(true).build()); + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()).put(builder).build(); + + // no alias, no index + assertNull(project.resolveWriteIndexRouting(null, null)); + assertEquals(project.resolveWriteIndexRouting("0", null), "0"); + + // index, no alias + assertNull(project.resolveWriteIndexRouting(null, "index")); + assertEquals(project.resolveWriteIndexRouting("0", "index"), "0"); + + // alias with no index routing + assertNull(project.resolveWriteIndexRouting(null, "alias0")); + assertEquals(project.resolveWriteIndexRouting("0", "alias0"), "0"); + + // alias with index routing. + assertEquals(project.resolveWriteIndexRouting(null, "alias1"), "1"); + Exception exception = expectThrows(IllegalArgumentException.class, () -> project.resolveWriteIndexRouting("0", "alias1")); + assertThat( + exception.getMessage(), + is("Alias [alias1] has index routing associated with it [1], and was provided with routing value [0], rejecting operation") + ); + + // alias with invalid index routing. 
+ exception = expectThrows(IllegalArgumentException.class, () -> project.resolveWriteIndexRouting(null, "alias2")); + assertThat( + exception.getMessage(), + is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") + ); + exception = expectThrows(IllegalArgumentException.class, () -> project.resolveWriteIndexRouting("1", "alias2")); + assertThat( + exception.getMessage(), + is("index/alias [alias2] provided with routing value [1,2] that resolved to several routing values, rejecting operation") + ); + exception = expectThrows(IllegalArgumentException.class, () -> project.resolveWriteIndexRouting(randomFrom("1", null), "alias4")); + assertThat( + exception.getMessage(), + is("index/alias [alias4] provided with routing value [1,2] that resolved to several routing values, rejecting operation") + ); + + // alias with no write index + exception = expectThrows(IllegalArgumentException.class, () -> project.resolveWriteIndexRouting("1", "alias3")); + assertThat(exception.getMessage(), is("alias [alias3] does not have a write index")); + + // aliases with multiple indices + AliasMetadata.Builder aliasZeroBuilderTwo = AliasMetadata.builder("alias0"); + if (randomBoolean()) { + aliasZeroBuilder.writeIndex(false); + } + IndexMetadata.Builder builder2 = IndexMetadata.builder("index2") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(aliasZeroBuilderTwo.build()) + .putAlias(AliasMetadata.builder("alias1").routing("0").writeIndex(true).build()) + .putAlias(AliasMetadata.builder("alias2").writeIndex(true).build()); + ProjectMetadata projectTwoIndices = ProjectMetadata.builder(project).put(builder2).build(); + + // verify that new write index is used + assertThat("0", equalTo(projectTwoIndices.resolveWriteIndexRouting("0", "alias1"))); + } + + public void testFindMappings() throws IOException { + ProjectMetadata 
project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(FIND_MAPPINGS_TEST_ITEM)) + .put(IndexMetadata.builder("index2").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(FIND_MAPPINGS_TEST_ITEM)) + .build(); + + { + AtomicInteger onNextIndexCalls = new AtomicInteger(0); + Map mappings = project.findMappings( + Strings.EMPTY_ARRAY, + MapperPlugin.NOOP_FIELD_FILTER, + onNextIndexCalls::incrementAndGet + ); + assertThat(mappings, anEmptyMap()); + assertThat(onNextIndexCalls.get(), equalTo(0)); + } + { + AtomicInteger onNextIndexCalls = new AtomicInteger(0); + Map mappings = project.findMappings( + new String[] { "index1" }, + MapperPlugin.NOOP_FIELD_FILTER, + onNextIndexCalls::incrementAndGet + ); + assertThat(mappings, aMapWithSize(1)); + assertIndexMappingsNotFiltered(mappings, "index1"); + assertThat(onNextIndexCalls.get(), equalTo(1)); + } + { + AtomicInteger onNextIndexCalls = new AtomicInteger(0); + Map mappings = project.findMappings( + new String[] { "index1", "index2" }, + MapperPlugin.NOOP_FIELD_FILTER, + onNextIndexCalls::incrementAndGet + ); + assertThat(mappings, aMapWithSize(2)); + assertIndexMappingsNotFiltered(mappings, "index1"); + assertIndexMappingsNotFiltered(mappings, "index2"); + assertThat(onNextIndexCalls.get(), equalTo(2)); + } + } + + public void testFindMappingsNoOpFilters() throws IOException { + MappingMetadata originalMappingMetadata = new MappingMetadata( + "_doc", + XContentHelper.convertToMap(JsonXContent.jsonXContent, FIND_MAPPINGS_TEST_ITEM, true) + ); + + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(originalMappingMetadata)) + .build(); + + { + Map mappings = project.findMappings( + new String[] { "index1" }, + MapperPlugin.NOOP_FIELD_FILTER, + 
Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP + ); + MappingMetadata mappingMetadata = mappings.get("index1"); + assertSame(originalMappingMetadata, mappingMetadata); + } + { + Map mappings = project.findMappings( + new String[] { "index1" }, + index -> field -> randomBoolean(), + Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP + ); + MappingMetadata mappingMetadata = mappings.get("index1"); + assertNotSame(originalMappingMetadata, mappingMetadata); + } + } + + @SuppressWarnings("unchecked") + public void testFindMappingsWithFilters() throws IOException { + String mapping = FIND_MAPPINGS_TEST_ITEM; + if (randomBoolean()) { + Map stringObjectMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, FIND_MAPPINGS_TEST_ITEM, false); + Map doc = (Map) stringObjectMap.get("_doc"); + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.map(doc); + mapping = Strings.toString(builder); + } + } + + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(IndexMetadata.builder("index1").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) + .put(IndexMetadata.builder("index2").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) + .put(IndexMetadata.builder("index3").settings(indexSettings(IndexVersion.current(), 1, 0)).putMapping(mapping)) + .build(); + + { + Map mappings = project.findMappings(new String[] { "index1", "index2", "index3" }, index -> { + if (index.equals("index1")) { + return field -> field.startsWith("name.") == false + && field.startsWith("properties.key.") == false + && field.equals("age") == false + && field.equals("address.location") == false; + } + if (index.equals("index2")) { + return Predicates.never(); + } + return FieldPredicate.ACCEPT_ALL; + }, Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP); + + assertIndexMappingsNoFields(mappings, "index2"); + assertIndexMappingsNotFiltered(mappings, "index3"); + + MappingMetadata docMapping = mappings.get("index1"); + 
assertNotNull(docMapping); + + Map sourceAsMap = docMapping.getSourceAsMap(); + assertThat(sourceAsMap.keySet(), containsInAnyOrder("properties", "_routing", "_source")); + + Map typeProperties = (Map) sourceAsMap.get("properties"); + assertThat(typeProperties.keySet(), containsInAnyOrder("name", "address", "birth", "ip", "suggest", "properties")); + + Map name = (Map) typeProperties.get("name"); + assertThat(name.keySet(), containsInAnyOrder("properties")); + Map nameProperties = (Map) name.get("properties"); + assertThat(nameProperties, anEmptyMap()); + + Map address = (Map) typeProperties.get("address"); + assertThat(address.keySet(), containsInAnyOrder("type", "properties")); + Map addressProperties = (Map) address.get("properties"); + assertThat(addressProperties.keySet(), containsInAnyOrder("street", "area")); + assertLeafs(addressProperties, "street", "area"); + + Map properties = (Map) typeProperties.get("properties"); + assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); + Map propertiesProperties = (Map) properties.get("properties"); + assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); + assertLeafs(propertiesProperties, "key"); + assertMultiField(propertiesProperties, "value", "keyword"); + } + + { + Map mappings = project.findMappings( + new String[] { "index1", "index2", "index3" }, + index -> field -> (index.equals("index3") && field.endsWith("keyword")), + Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP + ); + + assertIndexMappingsNoFields(mappings, "index1"); + assertIndexMappingsNoFields(mappings, "index2"); + MappingMetadata mappingMetadata = mappings.get("index3"); + Map sourceAsMap = mappingMetadata.getSourceAsMap(); + assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", "properties")); + Map typeProperties = (Map) sourceAsMap.get("properties"); + assertThat(typeProperties.keySet(), containsInAnyOrder("properties")); + Map properties = (Map) typeProperties.get("properties"); + 
assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); + Map propertiesProperties = (Map) properties.get("properties"); + assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); + Map key = (Map) propertiesProperties.get("key"); + assertThat(key.keySet(), containsInAnyOrder("properties")); + Map keyProperties = (Map) key.get("properties"); + assertThat(keyProperties.keySet(), containsInAnyOrder("keyword")); + assertLeafs(keyProperties, "keyword"); + Map value = (Map) propertiesProperties.get("value"); + assertThat(value.keySet(), containsInAnyOrder("properties")); + Map valueProperties = (Map) value.get("properties"); + assertThat(valueProperties.keySet(), containsInAnyOrder("keyword")); + assertLeafs(valueProperties, "keyword"); + } + + { + Map mappings = project.findMappings( + new String[] { "index1", "index2", "index3" }, + index -> field -> (index.equals("index2")), + Metadata.ON_NEXT_INDEX_FIND_MAPPINGS_NOOP + ); + + assertIndexMappingsNoFields(mappings, "index1"); + assertIndexMappingsNoFields(mappings, "index3"); + assertIndexMappingsNotFiltered(mappings, "index2"); + } + } + + public void testOldestIndexComputation() { + ProjectMetadata project = buildIndicesWithVersions( + IndexVersions.MINIMUM_COMPATIBLE, + IndexVersion.current(), + IndexVersion.fromId(IndexVersion.current().id() + 1) + ).build(); + + assertEquals(IndexVersions.MINIMUM_COMPATIBLE, project.oldestIndexVersion()); + + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + assertEquals(IndexVersion.current(), b.build().oldestIndexVersion()); + + Throwable ex = expectThrows( + IllegalArgumentException.class, + () -> buildIndicesWithVersions( + IndexVersions.MINIMUM_COMPATIBLE, + IndexVersions.ZERO, + IndexVersion.fromId(IndexVersion.current().id() + 1) + ).build() + ); + + assertEquals("[index.version.created] is not present in the index settings for index with UUID [null]", ex.getMessage()); + } + + private 
ProjectMetadata.Builder buildIndicesWithVersions(IndexVersion... indexVersions) { + int lastIndexNum = randomIntBetween(9, 50); + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (IndexVersion indexVersion : indexVersions) { + IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName("index", lastIndexNum)) + .settings(settings(indexVersion)) + .numberOfShards(1) + .numberOfReplicas(1) + .build(); + b.put(im, false); + lastIndexNum = randomIntBetween(lastIndexNum + 1, lastIndexNum + 50); + } + + return b; + } + + private static IndexMetadata.Builder buildIndexMetadata(String name, String alias, Boolean writeIndex) { + return IndexMetadata.builder(name) + .settings(settings(IndexVersion.current())) + .creationDate(randomNonNegativeLong()) + .putAlias(AliasMetadata.builder(alias).writeIndex(writeIndex)) + .numberOfShards(1) + .numberOfReplicas(0); + } + + @SuppressWarnings("unchecked") + private static void assertIndexMappingsNoFields(Map mappings, String index) { + MappingMetadata docMapping = mappings.get(index); + assertNotNull(docMapping); + Map sourceAsMap = docMapping.getSourceAsMap(); + assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", "properties")); + Map typeProperties = (Map) sourceAsMap.get("properties"); + assertThat(typeProperties, anEmptyMap()); + } + + @SuppressWarnings("unchecked") + private static void assertIndexMappingsNotFiltered(Map mappings, String index) { + MappingMetadata docMapping = mappings.get(index); + assertNotNull(docMapping); + + Map sourceAsMap = docMapping.getSourceAsMap(); + assertThat(sourceAsMap.keySet(), containsInAnyOrder("_routing", "_source", "properties")); + + Map typeProperties = (Map) sourceAsMap.get("properties"); + assertThat(typeProperties.keySet(), containsInAnyOrder("name", "address", "birth", "age", "ip", "suggest", "properties")); + + Map name = (Map) typeProperties.get("name"); + assertThat(name.keySet(), 
containsInAnyOrder("properties")); + Map nameProperties = (Map) name.get("properties"); + assertThat(nameProperties.keySet(), containsInAnyOrder("first", "last")); + assertLeafs(nameProperties, "first", "last"); + + Map address = (Map) typeProperties.get("address"); + assertThat(address.keySet(), containsInAnyOrder("type", "properties")); + Map addressProperties = (Map) address.get("properties"); + assertThat(addressProperties.keySet(), containsInAnyOrder("street", "location", "area")); + assertLeafs(addressProperties, "street", "location", "area"); + + Map properties = (Map) typeProperties.get("properties"); + assertThat(properties.keySet(), containsInAnyOrder("type", "properties")); + Map propertiesProperties = (Map) properties.get("properties"); + assertThat(propertiesProperties.keySet(), containsInAnyOrder("key", "value")); + assertMultiField(propertiesProperties, "key", "keyword"); + assertMultiField(propertiesProperties, "value", "keyword"); + } + + @SuppressWarnings("unchecked") + public static void assertLeafs(Map properties, String... fields) { + assertThat(properties.keySet(), hasItems(fields)); + for (String field : fields) { + Map fieldProp = (Map) properties.get(field); + assertThat(fieldProp, not(hasKey("properties"))); + assertThat(fieldProp, not(hasKey("fields"))); + } + } + + public static void assertMultiField(Map properties, String field, String... 
subFields) { + assertThat(properties, hasKey(field)); + @SuppressWarnings("unchecked") + Map fieldProp = (Map) properties.get(field); + assertThat(fieldProp, hasKey("fields")); + @SuppressWarnings("unchecked") + Map subFieldsDef = (Map) fieldProp.get("fields"); + assertLeafs(subFieldsDef, subFields); + } + + private static final String FIND_MAPPINGS_TEST_ITEM = """ + { + "_doc": { + "_routing": { + "required":true + }, "_source": { + "enabled":false + }, "properties": { + "name": { + "properties": { + "first": { + "type": "keyword" + }, + "last": { + "type": "keyword" + } + } + }, + "birth": { + "type": "date" + }, + "age": { + "type": "integer" + }, + "ip": { + "type": "ip" + }, + "suggest" : { + "type": "completion" + }, + "address": { + "type": "object", + "properties": { + "street": { + "type": "keyword" + }, + "location": { + "type": "geo_point" + }, + "area": { + "type": "geo_shape", \s + "tree": "quadtree", + "precision": "1m" + } + } + }, + "properties": { + "type": "nested", + "properties": { + "key" : { + "type": "text", + "fields": { + "keyword" : { + "type" : "keyword" + } + } + }, + "value" : { + "type": "text", + "fields": { + "keyword" : { + "type" : "keyword" + } + } + } + } + } + } + } + } + }"""; + + public void testBuilderRejectsNullCustom() { + final ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + final String key = randomAlphaOfLength(10); + assertThat(expectThrows(NullPointerException.class, () -> builder.putCustom(key, null)).getMessage(), containsString(key)); + } + + public void testBuilderRejectsNullInCustoms() { + final ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + final String key = randomAlphaOfLength(10); + { + final Map map = new HashMap<>(); + map.put(key, null); + assertThat(expectThrows(NullPointerException.class, () -> builder.customs(map)).getMessage(), containsString(key)); + } + } + + public void testCopyAndUpdate() { + var initialIndexUUID = 
randomUUID(); + final String indexName = randomAlphaOfLengthBetween(4, 12); + final ProjectMetadata before = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(IndexMetadata.builder(indexName).settings(indexSettings(IndexVersion.current(), initialIndexUUID, 1, 1))) + .build(); + + var alteredIndexUUID = randomUUID(); + assertThat(alteredIndexUUID, not(equalTo(initialIndexUUID))); + final ProjectMetadata after = before.copyAndUpdate( + builder -> builder.put(IndexMetadata.builder(indexName).settings(indexSettings(IndexVersion.current(), alteredIndexUUID, 1, 1))) + ); + + assertThat(after, not(sameInstance(before))); + assertThat(after.index(indexName).getIndexUUID(), equalTo(alteredIndexUUID)); + } + + public void testBuilderRemoveCustomIf() { + var custom1 = new TestProjectCustomMetadata(); + var custom2 = new TestProjectCustomMetadata(); + var builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + builder.putCustom("custom1", custom1); + builder.putCustom("custom2", custom2); + + builder.removeCustomIf((key, value) -> Objects.equals(key, "custom1")); + + var project = builder.build(); + assertThat(project.custom("custom1"), nullValue()); + assertThat(project.custom("custom2"), sameInstance(custom2)); + } + + public void testBuilderRejectsDataStreamThatConflictsWithIndex() { + final String dataStreamName = "my-data-stream"; + IndexMetadata idx = createFirstBackingIndex(dataStreamName).build(); + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(idx, false) + .put( + IndexMetadata.builder(dataStreamName) + .settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + .build(), + false + ) + .put(newInstance(dataStreamName, List.of(idx.getIndex()))); + + IllegalStateException e = expectThrows(IllegalStateException.class, b::build); + assertThat( + e.getMessage(), + containsString( + "index, alias, and data stream names need to be unique, but the following duplicates were found [data " 
+ + "stream [" + + dataStreamName + + "] conflicts with index]" + ) + ); + } + + public void testBuilderRejectsDataStreamThatConflictsWithAlias() { + final String dataStreamName = "my-data-stream"; + IndexMetadata idx = createFirstBackingIndex(dataStreamName).putAlias(AliasMetadata.builder(dataStreamName).build()).build(); + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(idx, false) + .put(newInstance(dataStreamName, List.of(idx.getIndex()))); + + IllegalStateException e = expectThrows(IllegalStateException.class, b::build); + assertThat( + e.getMessage(), + containsString( + "index, alias, and data stream names need to be unique, but the following duplicates were found [" + + dataStreamName + + " (alias of [" + + idx.getIndex().getName() + + "]) conflicts with data stream]" + ) + ); + } + + public void testBuilderRejectsAliasThatRefersToDataStreamBackingIndex() { + final String dataStreamName = "my-data-stream"; + final String conflictingName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + IndexMetadata idx = createFirstBackingIndex(dataStreamName).putAlias(new AliasMetadata.Builder(conflictingName)).build(); + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(idx, false) + .put(newInstance(dataStreamName, List.of(idx.getIndex()))); + + AssertionError e = expectThrows(AssertionError.class, b::build); + assertThat(e.getMessage(), containsString("aliases [" + conflictingName + "] cannot refer to backing indices of data streams")); + } + + public void testBuilderForDataStreamWithRandomlyNumberedBackingIndices() { + final String dataStreamName = "my-data-stream"; + final List backingIndices = new ArrayList<>(); + final int numBackingIndices = randomIntBetween(2, 5); + int lastBackingIndexNum = 0; + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (int k = 1; k <= numBackingIndices; k++) { + lastBackingIndexNum = 
randomIntBetween(lastBackingIndexNum + 1, lastBackingIndexNum + 50); + IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, lastBackingIndexNum)) + .settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + .build(); + b.put(im, false); + backingIndices.add(im.getIndex()); + } + + b.put(newInstance(dataStreamName, backingIndices, lastBackingIndexNum, null)); + ProjectMetadata project = b.build(); + assertThat(project.dataStreams().keySet(), containsInAnyOrder(dataStreamName)); + assertThat(project.dataStreams().get(dataStreamName).getName(), equalTo(dataStreamName)); + } + + public void testBuildIndicesLookupForDataStreams() { + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + int numDataStreams = randomIntBetween(2, 8); + for (int i = 0; i < numDataStreams; i++) { + String name = "data-stream-" + i; + addDataStream(name, b); + } + + ProjectMetadata project = b.build(); + assertThat(project.dataStreams().size(), equalTo(numDataStreams)); + for (int i = 0; i < numDataStreams; i++) { + String name = "data-stream-" + i; + IndexAbstraction value = project.getIndicesLookup().get(name); + assertThat(value, notNullValue()); + DataStream ds = project.dataStreams().get(name); + assertThat(ds, notNullValue()); + + assertThat(value.isHidden(), is(false)); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); + assertThat(value.getIndices(), hasSize(ds.getIndices().size())); + assertThat(value.getWriteIndex().getName(), DataStreamTestHelper.backingIndexEqualTo(name, (int) ds.getGeneration())); + } + } + + public void testBuildIndicesLookupForDataStreamAliases() { + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + + addDataStream("d1", b); + addDataStream("d2", b); + addDataStream("d3", b); + addDataStream("d4", b); + + b.put("a1", "d1", null, null); + b.put("a1", "d2", null, null); + b.put("a2", "d3", null, null); + 
b.put("a3", "d1", null, null); + + ProjectMetadata project = b.build(); + assertThat(project.dataStreams(), aMapWithSize(4)); + IndexAbstraction value = project.getIndicesLookup().get("d1"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); + + value = project.getIndicesLookup().get("d2"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); + + value = project.getIndicesLookup().get("d3"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); + + value = project.getIndicesLookup().get("d4"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.DATA_STREAM)); + + value = project.getIndicesLookup().get("a1"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); + + value = project.getIndicesLookup().get("a2"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); + + value = project.getIndicesLookup().get("a3"); + assertThat(value, notNullValue()); + assertThat(value.getType(), equalTo(IndexAbstraction.Type.ALIAS)); + } + + public void testDataStreamAliasValidation() { + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + addDataStream("my-alias", b); + b.put("my-alias", "my-alias", null, null); + var e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); + + b = ProjectMetadata.builder(randomProjectIdOrDefault()); + addDataStream("d1", b); + addDataStream("my-alias", b); + b.put("my-alias", "d1", null, null); + e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); + + b = 
ProjectMetadata.builder(randomProjectIdOrDefault()); + b.put( + IndexMetadata.builder("index1") + .settings(indexSettings(IndexVersion.current(), 1, 0)) + .putAlias(new AliasMetadata.Builder("my-alias")) + ); + + addDataStream("d1", b); + b.put("my-alias", "d1", null, null); + e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and indices alias have the same name (my-alias)")); + } + + public void testDataStreamAliasValidationRestoreScenario() { + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + b.dataStreams( + Map.of("my-alias", createDataStream("my-alias")), + Map.of("my-alias", new DataStreamAlias("my-alias", List.of("my-alias"), null, null)) + ); + var e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); + + b = ProjectMetadata.builder(randomProjectIdOrDefault()); + b.dataStreams( + Map.of("d1", createDataStream("d1"), "my-alias", createDataStream("my-alias")), + Map.of("my-alias", new DataStreamAlias("my-alias", List.of("d1"), null, null)) + ); + e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and data stream have the same name (my-alias)")); + + b = ProjectMetadata.builder(randomProjectIdOrDefault()); + b.put( + IndexMetadata.builder("index1") + .settings(indexSettings(IndexVersion.current(), 1, 0)) + .putAlias(new AliasMetadata.Builder("my-alias")) + ); + b.dataStreams(Map.of("d1", createDataStream("d1")), Map.of("my-alias", new DataStreamAlias("my-alias", List.of("d1"), null, null))); + e = expectThrows(IllegalStateException.class, b::build); + assertThat(e.getMessage(), containsString("data stream alias and indices alias have the same name (my-alias)")); + } + + private void addDataStream(String name, ProjectMetadata.Builder b) { + int numBackingIndices = 
randomIntBetween(1, 4); + List indices = new ArrayList<>(numBackingIndices); + for (int j = 1; j <= numBackingIndices; j++) { + IndexMetadata idx = createBackingIndex(name, j).build(); + indices.add(idx.getIndex()); + b.put(idx, true); + } + b.put(newInstance(name, indices)); + } + + private DataStream createDataStream(String name) { + int numBackingIndices = randomIntBetween(1, 4); + List indices = new ArrayList<>(numBackingIndices); + for (int j = 1; j <= numBackingIndices; j++) { + IndexMetadata idx = createBackingIndex(name, j).build(); + indices.add(idx.getIndex()); + } + return newInstance(name, indices); + } + + public void testIndicesLookupRecordsDataStreamForBackingIndices() { + final int numIndices = randomIntBetween(2, 5); + final int numBackingIndices = randomIntBetween(2, 5); + final String dataStreamName = "my-data-stream"; + CreateIndexResult result = createIndices(numIndices, numBackingIndices, dataStreamName); + + SortedMap indicesLookup = result.project.getIndicesLookup(); + assertThat(indicesLookup, aMapWithSize(result.indices.size() + result.backingIndices.size() + 1)); + for (Index index : result.indices) { + assertThat(indicesLookup, hasKey(index.getName())); + assertNull(indicesLookup.get(index.getName()).getParentDataStream()); + } + for (Index index : result.backingIndices) { + assertThat(indicesLookup, hasKey(index.getName())); + assertNotNull(indicesLookup.get(index.getName()).getParentDataStream()); + assertThat(indicesLookup.get(index.getName()).getParentDataStream().getName(), equalTo(dataStreamName)); + } + } + + public void testValidateDataStreamsNoConflicts() { + ProjectMetadata project = createIndices(5, 10, "foo-datastream").project; + // don't expect any exception when validating a system without indices that would conflict with future backing indices + assertDataStreams(project.indices(), (DataStreamMetadata) project.customs().get(DataStreamMetadata.TYPE)); + } + + public void 
testValidateDataStreamsIgnoresIndicesWithoutCounter() { + String dataStreamName = "foo-datastream"; + ProjectMetadata project = ProjectMetadata.builder(createIndices(10, 10, dataStreamName).project) + .put( + new IndexMetadata.Builder(dataStreamName + "-index-without-counter").settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + ) + .put( + new IndexMetadata.Builder(dataStreamName + randomAlphaOfLength(10)).settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + + ) + .put( + new IndexMetadata.Builder(randomAlphaOfLength(10)).settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + + ) + .build(); + // don't expect any exception when validating against non-backing indices that don't conform to the backing indices naming + // convention + assertDataStreams(project.indices(), (DataStreamMetadata) project.customs().get(DataStreamMetadata.TYPE)); + } + + public void testValidateDataStreamsAllowsNamesThatStartsWithPrefix() { + String dataStreamName = "foo-datastream"; + ProjectMetadata project = ProjectMetadata.builder(createIndices(10, 10, dataStreamName).project) + .put( + new IndexMetadata.Builder(DataStream.BACKING_INDEX_PREFIX + dataStreamName + "-something-100012").settings( + settings(IndexVersion.current()) + ).numberOfShards(1).numberOfReplicas(1) + ) + .build(); + // don't expect any exception when validating against (potentially backing) indices that can't create conflict because of + // additional text before number + assertDataStreams(project.indices(), (DataStreamMetadata) project.customs().get(DataStreamMetadata.TYPE)); + } + + public void testValidateDataStreamsForNullDataStreamMetadata() { + ProjectMetadata project = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put(IndexMetadata.builder("foo-index").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1)) + .build(); + + try { + assertDataStreams(project.indices(), 
DataStreamMetadata.EMPTY); + } catch (Exception e) { + fail("did not expect exception when validating a system without any data streams but got " + e.getMessage()); + } + } + + public void testDataStreamAliases() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null), is(true)); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-us", null, null), is(true)); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-au", null, null), is(true)); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-au", null, null), is(false)); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") + ); + } + + public void testDataStreamReferToNonExistingDataStream() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + + Exception e = expectThrows(IllegalArgumentException.class, () -> mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null)); + assertThat(e.getMessage(), equalTo("alias [logs-postgres] refers to a non existing data stream [logs-postgres-eu]")); + + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); + mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); + } + + public 
void testDeleteDataStreamShouldUpdateAlias() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); + mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); + mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); + mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") + ); + + mdBuilder = ProjectMetadata.builder(project); + mdBuilder.removeDataStream("logs-postgres-us"); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-au") + ); + + mdBuilder = ProjectMetadata.builder(project); + mdBuilder.removeDataStream("logs-postgres-au"); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); + + mdBuilder = ProjectMetadata.builder(project); + mdBuilder.removeDataStream("logs-postgres-eu"); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), nullValue()); + } + + public void testDeleteDataStreamAlias() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); + 
mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); + mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); + mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") + ); + + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-us", true), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-au") + ); + + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-au", true), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-eu")); + + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-eu", true), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), nullValue()); + } + + public void testDeleteDataStreamAliasMustExists() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-eu")); + mdBuilder.put("logs-postgres", "logs-postgres-eu", null, null); + 
mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-us")); + mdBuilder.put("logs-postgres", "logs-postgres-us", null, null); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-au")); + mdBuilder.put("logs-postgres", "logs-postgres-au", null, null); + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-eu", "logs-postgres-us", "logs-postgres-au") + ); + + ProjectMetadata.Builder mdBuilder2 = ProjectMetadata.builder(project); + expectThrows(ResourceNotFoundException.class, () -> mdBuilder2.removeDataStreamAlias("logs-mysql", "logs-postgres-us", true)); + assertThat(mdBuilder2.removeDataStreamAlias("logs-mysql", "logs-postgres-us", false), is(false)); + } + + public void testDataStreamWriteAlias() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); + mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-replicated")); + + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null), is(true)); + + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-replicated")); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), 
containsInAnyOrder("logs-postgres-replicated")); + } + + public void testDataStreamMultipleWriteAlias() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-foobar")); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-barbaz")); + mdBuilder.put("logs", "logs-foobar", true, null); + mdBuilder.put("logs", "logs-barbaz", true, null); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs").getWriteDataStream(), equalTo("logs-barbaz")); + assertThat(project.dataStreamAliases().get("logs").getDataStreams(), containsInAnyOrder("logs-foobar", "logs-barbaz")); + } + + public void testDataStreamWriteAliasUnset() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); + mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-replicated")); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-replicated")); + + mdBuilder = ProjectMetadata.builder(project); + // Side check: null value isn't changing anything: + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(false)); + // Unset write flag + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", false, null), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); + 
assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-replicated")); + } + + public void testDataStreamWriteAliasChange() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") + ); + + // change write flag: + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", false, null), is(true)); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", true, null), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-replicated")); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") + ); + } + + public void testDataStreamWriteRemoveAlias() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); 
+ assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-primary", "logs-postgres-replicated") + ); + + mdBuilder = ProjectMetadata.builder(project); + assertThat(mdBuilder.removeDataStreamAlias("logs-postgres", "logs-postgres-primary", randomBoolean()), is(true)); + project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-replicated")); + } + + public void testDataStreamWriteRemoveDataStream() { + ProjectMetadata.Builder mdBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-primary")); + mdBuilder.put(DataStreamTestHelper.randomInstance("logs-postgres-replicated")); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-primary", true, null), is(true)); + assertThat(mdBuilder.put("logs-postgres", "logs-postgres-replicated", null, null), is(true)); + + ProjectMetadata project = mdBuilder.build(); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), equalTo("logs-postgres-primary")); + assertThat( + project.dataStreamAliases().get("logs-postgres").getDataStreams(), + containsInAnyOrder("logs-postgres-primary", 
"logs-postgres-replicated") + ); + + mdBuilder = ProjectMetadata.builder(project); + mdBuilder.removeDataStream("logs-postgres-primary"); + project = mdBuilder.build(); + assertThat(project.dataStreams().keySet(), contains("logs-postgres-replicated")); + assertThat(project.dataStreamAliases().get("logs-postgres"), notNullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getWriteDataStream(), nullValue()); + assertThat(project.dataStreamAliases().get("logs-postgres").getDataStreams(), containsInAnyOrder("logs-postgres-replicated")); + } + + public void testReuseIndicesLookup() { + String indexName = "my-index"; + String aliasName = "my-alias"; + String dataStreamName = "logs-mysql-prod"; + String dataStreamAliasName = "logs-mysql"; + ProjectMetadata previous = ProjectMetadata.builder(randomProjectIdOrDefault()).build(); + + // Things that should change indices lookup + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata idx = DataStreamTestHelper.createFirstBackingIndex(dataStreamName).build(); + builder.put(idx, true); + DataStream dataStream = newInstance(dataStreamName, List.of(idx.getIndex())); + builder.put(dataStream); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + builder.put(dataStreamAliasName, dataStreamName, false, null); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + builder.put(dataStreamAliasName, dataStreamName, true, null); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder 
builder = ProjectMetadata.builder(previous); + builder.put( + IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current())) + .creationDate(randomNonNegativeLong()) + .numberOfShards(1) + .numberOfReplicas(0) + ); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); + imBuilder.putAlias(AliasMetadata.builder(aliasName).build()); + builder.put(imBuilder); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); + imBuilder.putAlias(AliasMetadata.builder(aliasName).writeIndex(true).build()); + builder.put(imBuilder); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); + Settings.Builder sBuilder = Settings.builder() + .put(builder.get(indexName).getSettings()) + .put(IndexMetadata.INDEX_HIDDEN_SETTING.getKey(), true); + imBuilder.settings(sBuilder.build()); + builder.put(imBuilder); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), not(sameInstance(project.getIndicesLookup()))); + previous = project; + } + + // Things that shouldn't change indices lookup + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); + 
imBuilder.numberOfReplicas(2); + builder.put(imBuilder); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), sameInstance(project.getIndicesLookup())); + previous = project; + } + { + ProjectMetadata.Builder builder = ProjectMetadata.builder(previous); + IndexMetadata.Builder imBuilder = IndexMetadata.builder(builder.get(indexName)); + Settings.Builder sBuilder = Settings.builder() + .put(builder.get(indexName).getSettings()) + .put(IndexSettings.DEFAULT_FIELD_SETTING.getKey(), "val"); + imBuilder.settings(sBuilder.build()); + builder.put(imBuilder); + ProjectMetadata project = builder.build(); + assertThat(previous.getIndicesLookup(), sameInstance(project.getIndicesLookup())); + previous = project; + } + } + + public void testAliasedIndices() { + int numAliases = randomIntBetween(32, 64); + int numIndicesPerAlias = randomIntBetween(8, 16); + + ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (int i = 0; i < numAliases; i++) { + String aliasName = "alias-" + i; + for (int j = 0; j < numIndicesPerAlias; j++) { + AliasMetadata.Builder alias = new AliasMetadata.Builder(aliasName); + if (j == 0) { + alias.writeIndex(true); + } + + String indexName = aliasName + "-" + j; + builder.put( + IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current())) + .creationDate(randomNonNegativeLong()) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(alias) + ); + } + } + + ProjectMetadata project = builder.build(); + for (int i = 0; i < numAliases; i++) { + String aliasName = "alias-" + i; + Set result = project.aliasedIndices(aliasName); + Index[] expected = IntStream.range(0, numIndicesPerAlias) + .mapToObj(j -> aliasName + "-" + j) + .map(name -> new Index(name, ClusterState.UNKNOWN_UUID)) + .toArray(Index[]::new); + assertThat(result, containsInAnyOrder(expected)); + } + + // Add a new alias and index + builder = ProjectMetadata.builder(project); + String newAliasName = 
"alias-new"; + { + builder.put( + IndexMetadata.builder(newAliasName + "-1") + .settings(settings(IndexVersion.current())) + .creationDate(randomNonNegativeLong()) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(new AliasMetadata.Builder(newAliasName).writeIndex(true)) + ); + } + project = builder.build(); + assertThat(project.aliasedIndices(), hasSize(numAliases + 1)); + assertThat(project.aliasedIndices(newAliasName), contains(new Index(newAliasName + "-1", ClusterState.UNKNOWN_UUID))); + + // Remove the new alias/index + builder = ProjectMetadata.builder(project); + { + builder.remove(newAliasName + "-1"); + } + project = builder.build(); + assertThat(project.aliasedIndices(), hasSize(numAliases)); + assertThat(project.aliasedIndices(newAliasName), empty()); + + // Add a new alias that points to existing indices + builder = ProjectMetadata.builder(project); + { + IndexMetadata.Builder imBuilder = new IndexMetadata.Builder(project.index("alias-1-0")); + imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); + builder.put(imBuilder); + + imBuilder = new IndexMetadata.Builder(project.index("alias-2-1")); + imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); + builder.put(imBuilder); + + imBuilder = new IndexMetadata.Builder(project.index("alias-3-2")); + imBuilder.putAlias(new AliasMetadata.Builder(newAliasName)); + builder.put(imBuilder); + } + project = builder.build(); + assertThat(project.aliasedIndices(), hasSize(numAliases + 1)); + assertThat( + project.aliasedIndices(newAliasName), + containsInAnyOrder( + new Index("alias-1-0", ClusterState.UNKNOWN_UUID), + new Index("alias-2-1", ClusterState.UNKNOWN_UUID), + new Index("alias-3-2", ClusterState.UNKNOWN_UUID) + ) + ); + + // Remove the new alias that points to existing indices + builder = ProjectMetadata.builder(project); + { + IndexMetadata.Builder imBuilder = new IndexMetadata.Builder(project.index("alias-1-0")); + imBuilder.removeAlias(newAliasName); + builder.put(imBuilder); + + 
imBuilder = new IndexMetadata.Builder(project.index("alias-2-1")); + imBuilder.removeAlias(newAliasName); + builder.put(imBuilder); + + imBuilder = new IndexMetadata.Builder(project.index("alias-3-2")); + imBuilder.removeAlias(newAliasName); + builder.put(imBuilder); + } + project = builder.build(); + assertThat(project.aliasedIndices(), hasSize(numAliases)); + assertThat(project.aliasedIndices(newAliasName), empty()); + } + + public void testHiddenAliasValidation() { + final String hiddenAliasName = "hidden_alias"; + + IndexMetadata hidden1 = buildIndexWithAlias("hidden1", hiddenAliasName, true, IndexVersion.current(), false); + IndexMetadata hidden2 = buildIndexWithAlias("hidden2", hiddenAliasName, true, IndexVersion.current(), false); + IndexMetadata hidden3 = buildIndexWithAlias("hidden3", hiddenAliasName, true, IndexVersion.current(), false); + + IndexMetadata nonHidden = buildIndexWithAlias("nonhidden1", hiddenAliasName, false, IndexVersion.current(), false); + IndexMetadata unspecified = buildIndexWithAlias("nonhidden2", hiddenAliasName, null, IndexVersion.current(), false); + + { + // Should be ok: + projectWithIndices(hidden1, hidden2, hidden3); + } + + { + // Should be ok: + if (randomBoolean()) { + projectWithIndices(nonHidden, unspecified); + } else { + projectWithIndices(unspecified, nonHidden); + } + } + + { + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> projectWithIndices(hidden1, hidden2, hidden3, nonHidden) + ); + assertThat(exception.getMessage(), containsString("alias [" + hiddenAliasName + "] has is_hidden set to true on indices [")); + assertThat( + exception.getMessage(), + allOf( + containsString(hidden1.getIndex().getName()), + containsString(hidden2.getIndex().getName()), + containsString(hidden3.getIndex().getName()) + ) + ); + assertThat( + exception.getMessage(), + containsString( + "but does not have is_hidden set to true on indices [" + + nonHidden.getIndex().getName() + + "]; alias must have 
the same is_hidden setting on all indices" + ) + ); + } + + { + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> projectWithIndices(hidden1, hidden2, hidden3, unspecified) + ); + assertThat(exception.getMessage(), containsString("alias [" + hiddenAliasName + "] has is_hidden set to true on indices [")); + assertThat( + exception.getMessage(), + allOf( + containsString(hidden1.getIndex().getName()), + containsString(hidden2.getIndex().getName()), + containsString(hidden3.getIndex().getName()) + ) + ); + assertThat( + exception.getMessage(), + containsString( + "but does not have is_hidden set to true on indices [" + + unspecified.getIndex().getName() + + "]; alias must have the same is_hidden setting on all indices" + ) + ); + } + + { + final IndexMetadata hiddenIndex = randomFrom(hidden1, hidden2, hidden3); + IllegalStateException exception = expectThrows(IllegalStateException.class, () -> { + if (randomBoolean()) { + projectWithIndices(nonHidden, unspecified, hiddenIndex); + } else { + projectWithIndices(unspecified, nonHidden, hiddenIndex); + } + }); + assertThat( + exception.getMessage(), + containsString( + "alias [" + + hiddenAliasName + + "] has is_hidden set to true on " + + "indices [" + + hiddenIndex.getIndex().getName() + + "] but does not have is_hidden set to true on indices [" + ) + ); + assertThat( + exception.getMessage(), + allOf(containsString(unspecified.getIndex().getName()), containsString(nonHidden.getIndex().getName())) + ); + assertThat(exception.getMessage(), containsString("but does not have is_hidden set to true on indices [")); + } + } + + public static final String SYSTEM_ALIAS_NAME = "system_alias"; + + public void testSystemAliasValidationMixedVersionSystemAndRegularFails() { + final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.V_7_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + ); + final IndexMetadata currentVersionSystem 
= buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); + + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> projectWithIndices(currentVersionSystem, oldVersionSystem, regularIndex) + ); + assertThat( + exception.getMessage(), + containsString( + "alias [" + + SYSTEM_ALIAS_NAME + + "] refers to both system indices [" + + currentVersionSystem.getIndex().getName() + + "] and non-system indices: [" + + regularIndex.getIndex().getName() + + "], but aliases must refer to either system or non-system indices, not both" + ) + ); + } + + public void testSystemAliasValidationNewSystemAndRegularFails() { + final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); + + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> projectWithIndices(currentVersionSystem, regularIndex) + ); + assertThat( + exception.getMessage(), + containsString( + "alias [" + + SYSTEM_ALIAS_NAME + + "] refers to both system indices [" + + currentVersionSystem.getIndex().getName() + + "] and non-system indices: [" + + regularIndex.getIndex().getName() + + "], but aliases must refer to either system or non-system indices, not both" + ) + ); + } + + public void testSystemAliasOldSystemAndNewRegular() { + final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.V_7_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + ); + final IndexMetadata oldVersionSystem = 
buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + final IndexMetadata regularIndex = buildIndexWithAlias("regular1", SYSTEM_ALIAS_NAME, false, IndexVersion.current(), false); + + // Should be ok: + projectWithIndices(oldVersionSystem, regularIndex); + } + + public void testSystemIndexValidationAllRegular() { + final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.V_7_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + ); + final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata currentVersionSystem2 = buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + + // Should be ok + projectWithIndices(currentVersionSystem, currentVersionSystem2, oldVersionSystem); + } + + public void testSystemAliasValidationAllSystemSomeOld() { + final IndexVersion random7xVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.V_7_0_0, + IndexVersionUtils.getPreviousVersion(IndexVersions.V_8_0_0) + ); + final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata currentVersionSystem2 = buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final IndexMetadata oldVersionSystem = buildIndexWithAlias(".oldVersionSystem", SYSTEM_ALIAS_NAME, null, random7xVersion, true); + + // Should be ok: + projectWithIndices(currentVersionSystem, currentVersionSystem2, oldVersionSystem); + } + + public void testSystemAliasValidationAll8x() { + final IndexMetadata currentVersionSystem = buildIndexWithAlias(".system1", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + final 
IndexMetadata currentVersionSystem2 = buildIndexWithAlias(".system2", SYSTEM_ALIAS_NAME, null, IndexVersion.current(), true); + + // Should be ok + projectWithIndices(currentVersionSystem, currentVersionSystem2); + } + + private void projectWithIndices(IndexMetadata... indices) { + ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (var cursor : indices) { + builder.put(cursor, false); + } + builder.build(); + } + + private IndexMetadata buildIndexWithAlias( + String indexName, + String aliasName, + @Nullable Boolean aliasIsHidden, + IndexVersion indexCreationVersion, + boolean isSystem + ) { + final AliasMetadata.Builder aliasMetadata = new AliasMetadata.Builder(aliasName); + if (aliasIsHidden != null || randomBoolean()) { + aliasMetadata.isHidden(aliasIsHidden); + } + return new IndexMetadata.Builder(indexName).settings(settings(indexCreationVersion)) + .system(isSystem) + .numberOfShards(1) + .numberOfReplicas(0) + .putAlias(aliasMetadata) + .build(); + } + + public void testMappingDuplication() { + final Set randomMappingDefinitions; + { + int numEntries = randomIntBetween(4, 8); + randomMappingDefinitions = Sets.newHashSetWithExpectedSize(numEntries); + for (int i = 0; i < numEntries; i++) { + Map mapping = RandomAliasActionsGenerator.randomMap(2); + String mappingAsString = Strings.toString((builder, params) -> builder.mapContents(mapping)); + randomMappingDefinitions.add(mappingAsString); + } + } + + ProjectMetadata project; + int numIndices = randomIntBetween(16, 32); + { + String[] definitions = randomMappingDefinitions.toArray(String[]::new); + ProjectMetadata.Builder mb = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (int i = 0; i < numIndices; i++) { + IndexMetadata.Builder indexBuilder = IndexMetadata.builder("index-" + i) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping(definitions[i % randomMappingDefinitions.size()]) + 
.numberOfShards(1) + .numberOfReplicas(0); + if (randomBoolean()) { + mb.put(indexBuilder); + } else { + mb.put(indexBuilder.build(), true); + } + } + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); + assertThat( + project.indices().values().stream().map(IndexMetadata::mapping).collect(Collectors.toSet()), + hasSize(project.getMappingsByHash().size()) + ); + + // Add a new index with a new index with known mapping: + MappingMetadata mapping = project.indices().get("index-" + randomInt(numIndices - 1)).mapping(); + MappingMetadata entry = project.getMappingsByHash().get(mapping.getSha256()); + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + mb.put( + IndexMetadata.builder("index-" + numIndices) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping(mapping) + .numberOfShards(1) + .numberOfReplicas(0) + ); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); + assertThat(project.getMappingsByHash().get(mapping.getSha256()), equalTo(entry)); + + // Remove index and ensure mapping cache stays the same + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + mb.remove("index-" + numIndices); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); + assertThat(project.getMappingsByHash().get(mapping.getSha256()), equalTo(entry)); + + // Update a mapping of an index: + IndexMetadata luckyIndex = project.index("index-" + randomInt(numIndices - 1)); + entry = project.getMappingsByHash().get(luckyIndex.mapping().getSha256()); + MappingMetadata updatedMapping = new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, Map.of("mapping", "updated")); + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + 
mb.put(IndexMetadata.builder(luckyIndex).putMapping(updatedMapping)); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size() + 1)); + assertThat(project.getMappingsByHash().get(luckyIndex.mapping().getSha256()), equalTo(entry)); + assertThat(project.getMappingsByHash().get(updatedMapping.getSha256()), equalTo(updatedMapping)); + + // Remove the index with updated mapping + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + mb.remove(luckyIndex.getIndex().getName()); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); + assertThat(project.getMappingsByHash().get(updatedMapping.getSha256()), nullValue()); + + // Add an index with new mapping and then later remove it: + MappingMetadata newMapping = new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, Map.of("new", "mapping")); + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + mb.put( + IndexMetadata.builder("index-" + numIndices) + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping(newMapping) + .numberOfShards(1) + .numberOfReplicas(0) + ); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size() + 1)); + assertThat(project.getMappingsByHash().get(newMapping.getSha256()), equalTo(newMapping)); + + { + ProjectMetadata.Builder mb = new ProjectMetadata.Builder(project); + mb.remove("index-" + numIndices); + project = mb.build(); + } + assertThat(project.getMappingsByHash(), aMapWithSize(randomMappingDefinitions.size())); + assertThat(project.getMappingsByHash().get(newMapping.getSha256()), nullValue()); + } + + public void testWithLifecycleState() { + String indexName = "my-index"; + String indexUUID = randomAlphaOfLength(10); + ProjectMetadata project1 = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put( + 
IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(IndexMetadata.SETTING_INDEX_UUID, indexUUID)) + .creationDate(randomNonNegativeLong()) + .numberOfShards(1) + .numberOfReplicas(0) + ) + .build(); + IndexMetadata index1 = project1.index(indexName); + assertThat(project1.getIndicesLookup(), notNullValue()); + assertThat(index1.getLifecycleExecutionState(), sameInstance(LifecycleExecutionState.EMPTY_STATE)); + + LifecycleExecutionState state = LifecycleExecutionState.builder().setPhase("phase").setAction("action").setStep("step").build(); + ProjectMetadata project2 = project1.withLifecycleState(index1.getIndex(), state); + IndexMetadata index2 = project2.index(indexName); + + // the indices lookups are the same object + assertThat(project2.getIndicesLookup(), sameInstance(project1.getIndicesLookup())); + + // the lifecycle state and version were changed + assertThat(index2.getLifecycleExecutionState().asMap(), is(state.asMap())); + assertThat(index2.getVersion(), is(index1.getVersion() + 1)); + + // but those are the only differences between the two + IndexMetadata.Builder builder = IndexMetadata.builder(index2); + builder.version(builder.version() - 1); + builder.removeCustom(LifecycleExecutionState.ILM_CUSTOM_METADATA_KEY); + assertThat(index1, equalTo(builder.build())); + + // withLifecycleState returns the same reference if nothing changed + ProjectMetadata project3 = project2.withLifecycleState(index2.getIndex(), state); + assertThat(project3, sameInstance(project2)); + + // withLifecycleState rejects a nonsense Index + String randomUUID = randomValueOtherThan(indexUUID, () -> randomAlphaOfLength(10)); + expectThrows(IndexNotFoundException.class, () -> project1.withLifecycleState(new Index(indexName, randomUUID), state)); + } + + public void testRetrieveIndexModeFromTemplateTsdb() throws IOException { + // tsdb: + var tsdbTemplate = new Template(Settings.builder().put("index.mode", "time_series").build(), new 
CompressedXContent("{}"), null); + // Settings in component template: + { + var componentTemplate = new ComponentTemplate(tsdbTemplate, null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.TIME_SERIES)); + } + // Settings in composable index template: + { + var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .template(tsdbTemplate) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.TIME_SERIES)); + } + } + + public void testRetrieveIndexModeFromTemplateLogsdb() throws IOException { + // logsdb: + var logsdbTemplate = new Template(Settings.builder().put("index.mode", "logsdb").build(), new CompressedXContent("{}"), null); + // Settings in component template: + { + var componentTemplate = new ComponentTemplate(logsdbTemplate, null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", 
componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.LOGSDB)); + } + // Settings in composable index template: + { + var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .template(logsdbTemplate) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), is(IndexMode.LOGSDB)); + } + } + + public void testRetrieveIndexModeFromTemplateEmpty() throws IOException { + // no index mode: + var emptyTemplate = new Template(Settings.EMPTY, new CompressedXContent("{}"), null); + // Settings in component template: + { + var componentTemplate = new ComponentTemplate(emptyTemplate, null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), nullValue()); + } + // Settings in composable index template: + { + var componentTemplate = new ComponentTemplate(new Template(null, null, null), null, null); + var indexTemplate = ComposableIndexTemplate.builder() + .indexPatterns(List.of("test-*")) + .template(emptyTemplate) + .componentTemplates(List.of("component_template_1")) + .dataStreamTemplate(new 
ComposableIndexTemplate.DataStreamTemplate()) + .build(); + ProjectMetadata p = ProjectMetadata.builder(randomProjectIdOrDefault()) + .put("component_template_1", componentTemplate) + .put("index_template_1", indexTemplate) + .build(); + assertThat(p.retrieveIndexModeFromTemplate(indexTemplate), nullValue()); + } + } + + private static CreateIndexResult createIndices(int numIndices, int numBackingIndices, String dataStreamName) { + // create some indices that do not back a data stream + final List indices = new ArrayList<>(); + int lastIndexNum = randomIntBetween(9, 50); + ProjectMetadata.Builder b = ProjectMetadata.builder(randomProjectIdOrDefault()); + for (int k = 1; k <= numIndices; k++) { + IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName("index", lastIndexNum)) + .settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + .build(); + b.put(im, false); + indices.add(im.getIndex()); + lastIndexNum = randomIntBetween(lastIndexNum + 1, lastIndexNum + 50); + } + + // create some backing indices for a data stream + final List backingIndices = new ArrayList<>(); + int lastBackingIndexNum = 0; + for (int k = 1; k <= numBackingIndices; k++) { + lastBackingIndexNum = randomIntBetween(lastBackingIndexNum + 1, lastBackingIndexNum + 50); + IndexMetadata im = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, lastBackingIndexNum)) + .settings(settings(IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(1) + .build(); + b.put(im, false); + backingIndices.add(im.getIndex()); + } + b.put(newInstance(dataStreamName, backingIndices, lastBackingIndexNum, null)); + return new CreateIndexResult(indices, backingIndices, b.build()); + } + + private record CreateIndexResult(List indices, List backingIndices, ProjectMetadata project) {}; public void testToXContent() throws IOException { final ProjectMetadata projectMetadata = prepareProjectMetadata(); @@ -670,21 +2857,37 @@ static int 
expectedChunkCount(ToXContent.Params params, ProjectMetadata project) return Math.toIntExact(chunkCount); } - public void testCopyAndUpdate() { - var initialIndexUUID = randomUUID(); - final String indexName = randomAlphaOfLengthBetween(4, 12); - final ProjectMetadata before = ProjectMetadata.builder(randomProjectIdOrDefault()) - .put(IndexMetadata.builder(indexName).settings(indexSettings(IndexVersion.current(), initialIndexUUID, 1, 1))) - .build(); + private static class TestProjectCustomMetadata implements Metadata.ProjectCustom { - var alteredIndexUUID = randomUUID(); - assertThat(alteredIndexUUID, not(equalTo(initialIndexUUID))); - final ProjectMetadata after = before.copyAndUpdate( - builder -> builder.put(IndexMetadata.builder(indexName).settings(indexSettings(IndexVersion.current(), alteredIndexUUID, 1, 1))) - ); + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Collections.emptyIterator(); + } - assertThat(after, not(sameInstance(before))); - assertThat(after.index(indexName).getIndexUUID(), equalTo(alteredIndexUUID)); + @Override + public Diff diff(Metadata.ProjectCustom previousState) { + return null; + } + + @Override + public EnumSet context() { + return null; + } + + @Override + public String getWriteableName() { + return null; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java index f3c0a2a4a5141..e10012133ffcd 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java @@ -63,7 +63,7 @@ public void testSimpleJsonFromAndTo() throws IOException { ReservedStateMetadata 
reservedStateMetadata1 = ReservedStateMetadata.builder("namespace_two").putHandler(hmTwo).build(); - Metadata metadata = Metadata.builder() + ProjectMetadata project = ProjectMetadata.builder(ProjectId.DEFAULT) .put( IndexTemplateMetadata.builder("foo") .patterns(Collections.singletonList("bar")) @@ -133,7 +133,7 @@ public void testSimpleJsonFromAndTo() throws IOException { XContentBuilder builder = JsonXContent.contentBuilder(); builder.startObject(); - ChunkedToXContent.wrapAsToXContent(metadata) + ChunkedToXContent.wrapAsToXContent(Metadata.builder().put(project).build()) .toXContent( builder, new ToXContent.MapParams(Map.of("binary", "true", Metadata.CONTEXT_MODE_PARAM, Metadata.CONTEXT_MODE_GATEWAY)) @@ -146,30 +146,31 @@ public void testSimpleJsonFromAndTo() throws IOException { } // templates - assertThat(parsedMetadata.getProject().templates().get("foo").name(), is("foo")); - assertThat(parsedMetadata.getProject().templates().get("foo").patterns(), is(Collections.singletonList("bar"))); - assertThat(parsedMetadata.getProject().templates().get("foo").settings().get("index.setting1"), is("value1")); - assertThat(parsedMetadata.getProject().templates().get("foo").settings().getByPrefix("index.").get("setting2"), is("value2")); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().size(), equalTo(3)); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar1").alias(), equalTo("alias-bar1")); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar2").alias(), equalTo("alias-bar2")); + final var parsedProject = parsedMetadata.getProject(ProjectId.DEFAULT); + assertThat(parsedProject.templates().get("foo").name(), is("foo")); + assertThat(parsedProject.templates().get("foo").patterns(), is(Collections.singletonList("bar"))); + assertThat(parsedProject.templates().get("foo").settings().get("index.setting1"), is("value1")); + 
assertThat(parsedProject.templates().get("foo").settings().getByPrefix("index.").get("setting2"), is("value2")); + assertThat(parsedProject.templates().get("foo").aliases().size(), equalTo(3)); + assertThat(parsedProject.templates().get("foo").aliases().get("alias-bar1").alias(), equalTo("alias-bar1")); + assertThat(parsedProject.templates().get("foo").aliases().get("alias-bar2").alias(), equalTo("alias-bar2")); assertThat( - parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar2").filter().string(), + parsedProject.templates().get("foo").aliases().get("alias-bar2").filter().string(), equalTo("{\"term\":{\"user\":\"kimchy\"}}") ); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar3").alias(), equalTo("alias-bar3")); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar3").indexRouting(), equalTo("routing-bar")); - assertThat(parsedMetadata.getProject().templates().get("foo").aliases().get("alias-bar3").searchRouting(), equalTo("routing-bar")); + assertThat(parsedProject.templates().get("foo").aliases().get("alias-bar3").alias(), equalTo("alias-bar3")); + assertThat(parsedProject.templates().get("foo").aliases().get("alias-bar3").indexRouting(), equalTo("routing-bar")); + assertThat(parsedProject.templates().get("foo").aliases().get("alias-bar3").searchRouting(), equalTo("routing-bar")); // component template - assertNotNull(parsedMetadata.getProject().componentTemplates().get("component_template")); - assertThat(parsedMetadata.getProject().componentTemplates().get("component_template").version(), is(5L)); + assertNotNull(parsedProject.componentTemplates().get("component_template")); + assertThat(parsedProject.componentTemplates().get("component_template").version(), is(5L)); assertThat( - parsedMetadata.getProject().componentTemplates().get("component_template").metadata(), + parsedProject.componentTemplates().get("component_template").metadata(), 
equalTo(Collections.singletonMap("my_meta", Collections.singletonMap("foo", "bar"))) ); assertThat( - parsedMetadata.getProject().componentTemplates().get("component_template").template(), + parsedProject.componentTemplates().get("component_template").template(), equalTo( new Template( Settings.builder().put("setting", "value").build(), @@ -180,20 +181,17 @@ public void testSimpleJsonFromAndTo() throws IOException { ); // index template v2 - assertNotNull(parsedMetadata.getProject().templatesV2().get("index_templatev2")); - assertThat(parsedMetadata.getProject().templatesV2().get("index_templatev2").priority(), is(5L)); - assertThat(parsedMetadata.getProject().templatesV2().get("index_templatev2").version(), is(4L)); - assertThat(parsedMetadata.getProject().templatesV2().get("index_templatev2").indexPatterns(), is(Arrays.asList("foo", "bar*"))); - assertThat( - parsedMetadata.getProject().templatesV2().get("index_templatev2").composedOf(), - is(Collections.singletonList("component_template")) - ); + assertNotNull(parsedProject.templatesV2().get("index_templatev2")); + assertThat(parsedProject.templatesV2().get("index_templatev2").priority(), is(5L)); + assertThat(parsedProject.templatesV2().get("index_templatev2").version(), is(4L)); + assertThat(parsedProject.templatesV2().get("index_templatev2").indexPatterns(), is(Arrays.asList("foo", "bar*"))); + assertThat(parsedProject.templatesV2().get("index_templatev2").composedOf(), is(Collections.singletonList("component_template"))); assertThat( - parsedMetadata.getProject().templatesV2().get("index_templatev2").metadata(), + parsedProject.templatesV2().get("index_templatev2").metadata(), equalTo(Collections.singletonMap("my_meta", Collections.singletonMap("potato", "chicken"))) ); assertThat( - parsedMetadata.getProject().templatesV2().get("index_templatev2").template(), + parsedProject.templatesV2().get("index_templatev2").template(), equalTo( new Template( Settings.builder().put("setting", "value").build(), @@ 
-204,12 +202,12 @@ public void testSimpleJsonFromAndTo() throws IOException { ); // data streams - assertNotNull(parsedMetadata.getProject().dataStreams().get("data-stream1")); - assertThat(parsedMetadata.getProject().dataStreams().get("data-stream1").getName(), is("data-stream1")); - assertThat(parsedMetadata.getProject().dataStreams().get("data-stream1").getIndices(), contains(idx1.getIndex())); - assertNotNull(parsedMetadata.getProject().dataStreams().get("data-stream2")); - assertThat(parsedMetadata.getProject().dataStreams().get("data-stream2").getName(), is("data-stream2")); - assertThat(parsedMetadata.getProject().dataStreams().get("data-stream2").getIndices(), contains(idx2.getIndex())); + assertNotNull(parsedProject.dataStreams().get("data-stream1")); + assertThat(parsedProject.dataStreams().get("data-stream1").getName(), is("data-stream1")); + assertThat(parsedProject.dataStreams().get("data-stream1").getIndices(), contains(idx1.getIndex())); + assertNotNull(parsedProject.dataStreams().get("data-stream2")); + assertThat(parsedProject.dataStreams().get("data-stream2").getName(), is("data-stream2")); + assertThat(parsedProject.dataStreams().get("data-stream2").getIndices(), contains(idx2.getIndex())); // reserved 'operator' metadata assertEquals(reservedStateMetadata, parsedMetadata.reservedStateMetadata().get(reservedStateMetadata.namespace())); @@ -364,20 +362,23 @@ public void testToXContentAPI_SameTypeName() throws IOException { .clusterUUID("clusterUUID") .coordinationMetadata(CoordinationMetadata.builder().build()) .put( - IndexMetadata.builder("index") - .state(IndexMetadata.State.OPEN) - .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping( - new MappingMetadata( - "type", - // the type name is the root value, - // the original logic in ClusterState.toXContent will reduce - Map.of("type", Map.of("key", "value")) - ) + ProjectMetadata.builder(ProjectId.DEFAULT) + .put( + IndexMetadata.builder("index") + 
.state(IndexMetadata.State.OPEN) + .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping( + new MappingMetadata( + "type", + // the type name is the root value, + // the original logic in ClusterState.toXContent will reduce + Map.of("type", Map.of("key", "value")) + ) + ) + .numberOfShards(1) + .primaryTerm(0, 1L) + .numberOfReplicas(2) ) - .numberOfShards(1) - .primaryTerm(0, 1L) - .numberOfReplicas(2) ) .build(); XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint(); @@ -932,24 +933,27 @@ private Metadata buildMetadata() throws IOException { .persistentSettings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current()).build()) .transientSettings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current()).build()) .put( - IndexMetadata.builder("index") - .state(IndexMetadata.State.OPEN) - .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping(new MappingMetadata("type", Map.of("type1", Map.of("key", "value")))) - .putAlias(AliasMetadata.builder("alias").indexRouting("indexRouting").build()) - .numberOfShards(1) - .primaryTerm(0, 1L) - .putInSyncAllocationIds(0, Set.of("allocationId")) - .numberOfReplicas(2) - .putRolloverInfo(new RolloverInfo("rolloveAlias", List.of(), 1L)) - ) - .put( - IndexTemplateMetadata.builder("template") - .patterns(List.of("pattern1", "pattern2")) - .order(0) - .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping("type", "{ \"key1\": {} }") - .build() + ProjectMetadata.builder(ProjectId.DEFAULT) + .put( + IndexMetadata.builder("index") + .state(IndexMetadata.State.OPEN) + .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping(new MappingMetadata("type", Map.of("type1", Map.of("key", "value")))) + .putAlias(AliasMetadata.builder("alias").indexRouting("indexRouting").build()) + .numberOfShards(1) + .primaryTerm(0, 1L) + 
.putInSyncAllocationIds(0, Set.of("allocationId")) + .numberOfReplicas(2) + .putRolloverInfo(new RolloverInfo("rolloveAlias", List.of(), 1L)) + ) + .put( + IndexTemplateMetadata.builder("template") + .patterns(List.of("pattern1", "pattern2")) + .order(0) + .settings(Settings.builder().put(SETTING_VERSION_CREATED, IndexVersion.current())) + .putMapping("type", "{ \"key1\": {} }") + .build() + ) ) .build(); } diff --git a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java index 97b8b86ea6a85..7a474f528897c 100644 --- a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java @@ -9,6 +9,8 @@ package org.elasticsearch.cluster.project; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; @@ -17,10 +19,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.SimpleDiffableWireSerializationTestCase; +import org.elasticsearch.test.TransportVersionUtils; import java.io.IOException; import java.util.stream.IntStream; +import static org.hamcrest.Matchers.equalTo; + public class ProjectStateRegistrySerializationTests extends SimpleDiffableWireSerializationTestCase { @Override @@ -56,7 +61,7 @@ protected ClusterState.Custom mutateInstance(ClusterState.Custom instance) throw private ProjectStateRegistry mutate(ProjectStateRegistry instance) { if (randomBoolean() && instance.size() > 0) { // Remove or mutate a project's settings or deletion flag - var projectId = randomFrom(instance.getProjectsSettings().keySet()); + var projectId = 
randomFrom(instance.knownProjects()); var builder = ProjectStateRegistry.builder(instance); builder.putProjectSettings(projectId, randomSettings()); if (randomBoolean()) { @@ -86,4 +91,11 @@ public static Settings randomSettings() { IntStream.range(0, randomIntBetween(1, 5)).forEach(i -> builder.put(randomIdentifier(), randomIdentifier())); return builder.build(); } + + public void testProjectStateRegistryBwcSerialization() throws IOException { + ProjectStateRegistry projectStateRegistry = randomProjectStateRegistry(); + TransportVersion oldVersion = TransportVersionUtils.getPreviousVersion(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY); + ClusterState.Custom serialized = copyInstance(projectStateRegistry, oldVersion); + assertThat(serialized, equalTo(projectStateRegistry)); + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java index a4d4dd6f2b154..5fdc73e9a6cf8 100644 --- a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java @@ -9,10 +9,13 @@ package org.elasticsearch.cluster.project; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; -import org.hamcrest.Matchers; import static org.elasticsearch.cluster.project.ProjectStateRegistrySerializationTests.randomSettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; public class ProjectStateRegistryTests extends ESTestCase { @@ -26,22 +29,63 @@ public void testBuilder() { ); var projectStateRegistry = builder.build(); var gen1 = projectStateRegistry.getProjectsMarkedForDeletionGeneration(); - assertThat(gen1, Matchers.equalTo(projectsUnderDeletion.isEmpty() ? 
0L : 1L)); + assertThat(gen1, equalTo(projectsUnderDeletion.isEmpty() ? 0L : 1L)); projectStateRegistry = ProjectStateRegistry.builder(projectStateRegistry).markProjectForDeletion(randomFrom(projects)).build(); var gen2 = projectStateRegistry.getProjectsMarkedForDeletionGeneration(); - assertThat(gen2, Matchers.equalTo(gen1 + 1)); + assertThat(gen2, equalTo(gen1 + 1)); if (projectsUnderDeletion.isEmpty() == false) { // re-adding the same projectId should not change the generation projectStateRegistry = ProjectStateRegistry.builder(projectStateRegistry) .markProjectForDeletion(randomFrom(projectsUnderDeletion)) .build(); - assertThat(projectStateRegistry.getProjectsMarkedForDeletionGeneration(), Matchers.equalTo(gen2)); + assertThat(projectStateRegistry.getProjectsMarkedForDeletionGeneration(), equalTo(gen2)); } var unknownProjectId = randomUniqueProjectId(); var throwingBuilder = ProjectStateRegistry.builder(projectStateRegistry).markProjectForDeletion(unknownProjectId); assertThrows(IllegalArgumentException.class, throwingBuilder::build); } + + public void testDiff() { + ProjectStateRegistry originalRegistry = ProjectStateRegistry.builder() + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .build(); + + ProjectId newProjectId = randomUniqueProjectId(); + Settings newSettings = randomSettings(); + ProjectId projectToMarkForDeletion = randomFrom(originalRegistry.knownProjects()); + ProjectId projectToModifyId = randomFrom(originalRegistry.knownProjects()); + Settings modifiedSettings = randomSettings(); + + ProjectStateRegistry modifiedRegistry = ProjectStateRegistry.builder(originalRegistry) + .putProjectSettings(newProjectId, newSettings) + .markProjectForDeletion(projectToMarkForDeletion) + .putProjectSettings(projectToModifyId, modifiedSettings) + .build(); + + var diff = modifiedRegistry.diff(originalRegistry); + var 
appliedRegistry = (ProjectStateRegistry) diff.apply(originalRegistry); + + assertThat(appliedRegistry, equalTo(modifiedRegistry)); + assertThat(appliedRegistry.size(), equalTo(originalRegistry.size() + 1)); + assertTrue(appliedRegistry.knownProjects().contains(newProjectId)); + assertTrue(appliedRegistry.isProjectMarkedForDeletion(projectToMarkForDeletion)); + assertThat(appliedRegistry.getProjectSettings(newProjectId), equalTo(newSettings)); + assertThat(appliedRegistry.getProjectSettings(projectToModifyId), equalTo(modifiedSettings)); + } + + public void testDiffNoChanges() { + ProjectStateRegistry originalRegistry = ProjectStateRegistry.builder() + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .build(); + + var diff = originalRegistry.diff(originalRegistry); + var appliedRegistry = (ProjectStateRegistry) diff.apply(originalRegistry); + + assertThat(appliedRegistry, sameInstance(originalRegistry)); + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java index 754b4d2b22d0d..fcc372a53f517 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java @@ -199,15 +199,7 @@ private static Metadata metadata(IndexMetadata.Builder... 
indices) { } private static ClusterInfo createClusterInfo(ShardRouting shard, Long size) { - return new ClusterInfo( - Map.of(), - Map.of(), - Map.of(ClusterInfo.shardIdentifierFromRouting(shard), size), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + return ClusterInfo.builder().shardSizes(Map.of(ClusterInfo.shardIdentifierFromRouting(shard), size)).build(); } private ClusterState buildRoutingTable(ClusterState state) { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java index 14e0aaa253749..4ce195721b228 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java @@ -72,15 +72,9 @@ public void testShardStats() { ) .build(); - var clusterInfo = new ClusterInfo( - Map.of(), - Map.of(), - Map.of(ClusterInfo.shardIdentifierFromRouting(shardId, true), currentShardSize), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + var clusterInfo = ClusterInfo.builder() + .shardSizes(Map.of(ClusterInfo.shardIdentifierFromRouting(shardId, true), currentShardSize)) + .build(); var queue = new DeterministicTaskQueue(); try (var clusterService = ClusterServiceUtils.createClusterService(state, queue.getThreadPool())) { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java index df0fa875a7249..c896d8a8f20fd 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java @@ -1580,7 +1580,7 @@ private static ClusterInfo clusterInfo( Map diskUsages, Map reservedSpace ) 
{ - return new ClusterInfo(diskUsages, Map.of(), Map.of(), Map.of(), Map.of(), reservedSpace, Map.of()); + return ClusterInfo.builder().leastAvailableSpaceUsage(diskUsages).reservedSpace(reservedSpace).build(); } private static DiscoveryNode newFrozenOnlyNode(String nodeId) { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ExpectedShardSizeAllocationTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ExpectedShardSizeAllocationTests.java index f1a2b4b1358fe..4a79e76b944ab 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ExpectedShardSizeAllocationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/ExpectedShardSizeAllocationTests.java @@ -249,21 +249,17 @@ public void testExpectedSizeOnMove() { } private static ClusterInfo createClusterInfoWith(ShardId shardId, long size) { - return new ClusterInfo( - Map.of(), - Map.of(), - Map.ofEntries( - Map.entry(ClusterInfo.shardIdentifierFromRouting(shardId, true), size), - Map.entry(ClusterInfo.shardIdentifierFromRouting(shardId, false), size) - ), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + return ClusterInfo.builder() + .shardSizes( + Map.ofEntries( + Map.entry(ClusterInfo.shardIdentifierFromRouting(shardId, true), size), + Map.entry(ClusterInfo.shardIdentifierFromRouting(shardId, false), size) + ) + ) + .build(); } private static ClusterInfo createClusterInfo(Map diskUsage, Map shardSizes) { - return new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), Map.of(), Map.of(), Map.of()); + return ClusterInfo.builder().leastAvailableSpaceUsage(diskUsage).mostAvailableSpaceUsage(diskUsage).shardSizes(shardSizes).build(); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocatorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocatorTests.java index 8ab031aa53fe1..3667de9c65e4e 
100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocatorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocatorTests.java @@ -597,20 +597,16 @@ public void testShardSizeDiscrepancyWithinIndex() { var allocationService = createAllocationService( Settings.EMPTY, - () -> new ClusterInfo( - Map.of(), - Map.of(), - Map.of( - ClusterInfo.shardIdentifierFromRouting(new ShardId(index, 0), true), - 0L, - ClusterInfo.shardIdentifierFromRouting(new ShardId(index, 1), true), - ByteSizeUnit.GB.toBytes(500) - ), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ) + () -> ClusterInfo.builder() + .shardSizes( + Map.of( + ClusterInfo.shardIdentifierFromRouting(new ShardId(index, 0), true), + 0L, + ClusterInfo.shardIdentifierFromRouting(new ShardId(index, 1), true), + ByteSizeUnit.GB.toBytes(500) + ) + ) + .build() ); assertSame(clusterState, reroute(allocationService, clusterState)); @@ -705,7 +701,7 @@ private RoutingAllocation createRoutingAllocation(ClusterState clusterState) { } private static ClusterInfo createClusterInfo(Map indexSizes) { - return new ClusterInfo(Map.of(), Map.of(), indexSizes, Map.of(), Map.of(), Map.of(), Map.of()); + return ClusterInfo.builder().shardSizes(indexSizes).build(); } private static IndexMetadata.Builder anIndex(String name) { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterAllocationSimulationTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterAllocationSimulationTests.java index 277521c5832a1..6ef622948f5c5 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterAllocationSimulationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterAllocationSimulationTests.java @@ -561,7 +561,12 @@ public ClusterInfo getClusterInfo() { dataPath.put(new 
ClusterInfo.NodeAndShard(shardRouting.currentNodeId(), shardRouting.shardId()), "/data"); } - return new ClusterInfo(diskSpaceUsage, diskSpaceUsage, shardSizes, Map.of(), dataPath, Map.of(), Map.of()); + return ClusterInfo.builder() + .leastAvailableSpaceUsage(diskSpaceUsage) + .mostAvailableSpaceUsage(diskSpaceUsage) + .shardSizes(shardSizes) + .dataPath(dataPath) + .build(); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStatsTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStatsTests.java index 80fe603488fd3..00cf85609bf38 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStatsTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStatsTests.java @@ -329,25 +329,21 @@ private static Tuple startedIndex( } private ClusterInfo createClusterInfo(List> shardSizes) { - return new ClusterInfo( - Map.of(), - Map.of(), - shardSizes.stream() - .flatMap( - entry -> IntStream.range(0, entry.v2().length) - .mapToObj( - index -> Map.entry( - ClusterInfo.shardIdentifierFromRouting(new ShardId(entry.v1(), "_na_", index), true), - entry.v2()[index] + return ClusterInfo.builder() + .shardSizes( + shardSizes.stream() + .flatMap( + entry -> IntStream.range(0, entry.v2().length) + .mapToObj( + index -> Map.entry( + ClusterInfo.shardIdentifierFromRouting(new ShardId(entry.v1(), "_na_", index), true), + entry.v2()[index] + ) ) - ) - ) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + .build(); } private static Tuple indexSizes(String name, long... 
sizes) { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterInfoSimulatorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterInfoSimulatorTests.java index b67e248999ced..ea6a1522ec141 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterInfoSimulatorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterInfoSimulatorTests.java @@ -690,15 +690,12 @@ public ClusterInfoTestBuilder withReservedSpace(String nodeId, String path, long } public ClusterInfo build() { - return new ClusterInfo( - leastAvailableSpaceUsage, - mostAvailableSpaceUsage, - shardSizes, - Map.of(), - Map.of(), - reservedSpace, - Map.of() - ); + return ClusterInfo.builder() + .leastAvailableSpaceUsage(leastAvailableSpaceUsage) + .mostAvailableSpaceUsage(mostAvailableSpaceUsage) + .shardSizes(shardSizes) + .reservedSpace(reservedSpace) + .build(); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java index a0d28ce124584..d204c1c925d40 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java @@ -690,7 +690,12 @@ public void testDesiredBalanceShouldConvergeInABigCluster() { .stream() .collect(toMap(Map.Entry::getKey, it -> new DiskUsage(it.getKey(), it.getKey(), "/data", diskSize, diskSize - it.getValue()))); - var clusterInfo = new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), dataPath, Map.of(), Map.of()); + var clusterInfo = ClusterInfo.builder() + .leastAvailableSpaceUsage(diskUsage) + .mostAvailableSpaceUsage(diskUsage) + .shardSizes(shardSizes) + 
.dataPath(dataPath) + .build(); var settings = Settings.EMPTY; @@ -1196,7 +1201,12 @@ public ClusterInfoTestBuilder withReservedSpace(String nodeId, long size, ShardI } public ClusterInfo build() { - return new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), Map.of(), reservedSpace, Map.of()); + return ClusterInfo.builder() + .leastAvailableSpaceUsage(diskUsage) + .mostAvailableSpaceUsage(diskUsage) + .shardSizes(shardSizes) + .reservedSpace(reservedSpace) + .build(); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java index 844912cba4c17..1f8d59a958bfe 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java @@ -620,6 +620,8 @@ public void testUnassignedAllocationPredictsDiskUsage() { ImmutableOpenMap.of(), ImmutableOpenMap.of(), ImmutableOpenMap.of(), + ImmutableOpenMap.of(), + ImmutableOpenMap.of(), ImmutableOpenMap.of() ); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java index 5467d313834b8..c4ca84e6e977f 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java @@ -1406,7 +1406,17 @@ static class DevNullClusterInfo extends ClusterInfo { Map shardSizes, Map reservedSpace ) { - super(leastAvailableSpaceUsage, mostAvailableSpaceUsage, shardSizes, Map.of(), Map.of(), reservedSpace, Map.of()); + super( + leastAvailableSpaceUsage, + 
mostAvailableSpaceUsage, + shardSizes, + Map.of(), + Map.of(), + reservedSpace, + Map.of(), + Map.of(), + Map.of() + ); } @Override diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java index 7da75f61da801..debb4343931d7 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java @@ -109,6 +109,8 @@ public void testCanAllocateUsesMaxAvailableSpace() { Map.of(), Map.of(), Map.of(), + Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( @@ -181,6 +183,8 @@ private void doTestCannotAllocateDueToLackOfDiskResources(boolean testMaxHeadroo Map.of(), Map.of(), Map.of(), + Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( @@ -327,6 +331,8 @@ private void doTestCanRemainUsesLeastAvailableSpace(boolean testMaxHeadroom) { Map.of(), shardRoutingMap, Map.of(), + Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( @@ -840,15 +846,11 @@ public void testDecidesYesIfWatermarksIgnored() { allFullUsages.put("node_0", new DiskUsage("node_0", "node_0", "_na_", 100, 0)); // all full allFullUsages.put("node_1", new DiskUsage("node_1", "node_1", "_na_", 100, 0)); // all full - final ClusterInfo clusterInfo = new ClusterInfo( - allFullUsages, - allFullUsages, - Map.of("[test][0][p]", 10L), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + final ClusterInfo clusterInfo = ClusterInfo.builder() + .leastAvailableSpaceUsage(allFullUsages) + .mostAvailableSpaceUsage(allFullUsages) + .shardSizes(Map.of("[test][0][p]", 10L)) + .build(); RoutingAllocation allocation = new RoutingAllocation( new 
AllocationDeciders(Collections.singleton(decider)), clusterState, @@ -908,15 +910,11 @@ public void testCannotForceAllocateOver100PercentUsage() { // bigger than available space final long shardSize = randomIntBetween(1, 10); shardSizes.put("[test][0][p]", shardSize); - ClusterInfo clusterInfo = new ClusterInfo( - leastAvailableUsages, - mostAvailableUsage, - shardSizes, - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + ClusterInfo clusterInfo = ClusterInfo.builder() + .leastAvailableSpaceUsage(leastAvailableUsages) + .mostAvailableSpaceUsage(mostAvailableUsage) + .shardSizes(shardSizes) + .build(); RoutingAllocation allocation = new RoutingAllocation( new AllocationDeciders(Collections.singleton(decider)), clusterState, diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedActionTests.java b/server/src/test/java/org/elasticsearch/common/FrequencyCappedActionTests.java similarity index 97% rename from server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedActionTests.java rename to server/src/test/java/org/elasticsearch/common/FrequencyCappedActionTests.java index 10b0f998ed046..1ecfd66a7d5d6 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/FrequencyCappedActionTests.java +++ b/server/src/test/java/org/elasticsearch/common/FrequencyCappedActionTests.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.cluster.routing.allocation.allocator; +package org.elasticsearch.common; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractAsyncTaskTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractAsyncTaskTests.java index ec8e0f69d89d8..a193bcd8a89ed 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractAsyncTaskTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractAsyncTaskTests.java @@ -59,12 +59,14 @@ protected void runInternal() { try { barrier1.await(); } catch (Exception e) { + logger.error("barrier1 interrupted", e); fail("interrupted"); } count.incrementAndGet(); try { barrier2.await(); } catch (Exception e) { + logger.error("barrier2 interrupted", e); fail("interrupted"); } if (shouldRunThrowException) { @@ -112,6 +114,7 @@ protected void runInternal() { try { barrier.await(); } catch (Exception e) { + logger.error("barrier interrupted", e); fail("interrupted"); } if (shouldRunThrowException) { diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java index 38b14b09cd82e..62d4d6d9cbc15 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java @@ -675,6 +675,7 @@ public void testScalingWithTaskTimeTracking() { final int max = between(min + 1, 6); { + var executionTimeEwma = randomDoubleBetween(0.01, 0.1, true); ThreadPoolExecutor pool = EsExecutors.newScaling( getClass().getName() + "/" + getTestName(), min, @@ -684,7 +685,9 @@ public void testScalingWithTaskTimeTracking() { randomBoolean(), EsExecutors.daemonThreadFactory("test"), threadContext, - new 
EsExecutors.TaskTrackingConfig(randomBoolean(), randomDoubleBetween(0.01, 0.1, true)) + randomBoolean() + ? EsExecutors.TaskTrackingConfig.builder().trackOngoingTasks().trackExecutionTime(executionTimeEwma).build() + : EsExecutors.TaskTrackingConfig.builder().trackExecutionTime(executionTimeEwma).build() ); assertThat(pool, instanceOf(TaskExecutionTimeTrackingEsThreadPoolExecutor.class)); } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java index 7f720721aebf2..505c26409a702 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.metrics.ExponentialBucketHistogram; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig; import org.elasticsearch.telemetry.InstrumentType; import org.elasticsearch.telemetry.Measurement; import org.elasticsearch.telemetry.RecordingMeterRegistry; @@ -25,7 +24,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.Function; -import static org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig.DEFAULT_EWMA_ALPHA; +import static org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig.DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -51,7 +50,12 @@ public void testExecutionEWMACalculation() throws Exception { EsExecutors.daemonThreadFactory("queuetest"), new EsAbortPolicy(), context, - new TaskTrackingConfig(randomBoolean(), 
DEFAULT_EWMA_ALPHA) + randomBoolean() + ? EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() + : EsExecutors.TaskTrackingConfig.builder().trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST).build() ); executor.prestartAllCoreThreads(); logger.info("--> executor: {}", executor); @@ -89,6 +93,64 @@ public void testExecutionEWMACalculation() throws Exception { executor.awaitTermination(10, TimeUnit.SECONDS); } + public void testMaxQueueLatency() throws Exception { + ThreadContext context = new ThreadContext(Settings.EMPTY); + RecordingMeterRegistry meterRegistry = new RecordingMeterRegistry(); + final var threadPoolName = randomIdentifier(); + final var barrier = new CyclicBarrier(2); + var adjustableTimedRunnable = new AdjustableQueueTimeWithExecutionBarrierTimedRunnable( + barrier, + TimeUnit.NANOSECONDS.toNanos(1000000) + ); + TaskExecutionTimeTrackingEsThreadPoolExecutor executor = new TaskExecutionTimeTrackingEsThreadPoolExecutor( + "test-threadpool", + 1, + 1, + 1000, + TimeUnit.MILLISECONDS, + ConcurrentCollections.newBlockingQueue(), + (runnable) -> adjustableTimedRunnable, + EsExecutors.daemonThreadFactory("queue-latency-test"), + new EsAbortPolicy(), + context, + randomBoolean() + ? EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackMaxQueueLatency() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() + : EsExecutors.TaskTrackingConfig.builder() + .trackMaxQueueLatency() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() + ); + try { + executor.prestartAllCoreThreads(); + logger.info("--> executor: {}", executor); + + // Check that the max is zero initially and after a reset. 
+ assertEquals("The queue latency should be initialized zero", 0, executor.getMaxQueueLatencyMillisSinceLastPollAndReset()); + executor.execute(() -> {}); + safeAwait(barrier); // Wait for the task to start, which implies it has finished the queuing stage. + assertEquals("Ran one task of 1ms, should be the max", 1, executor.getMaxQueueLatencyMillisSinceLastPollAndReset()); + assertEquals("The max was just reset, should be zero", 0, executor.getMaxQueueLatencyMillisSinceLastPollAndReset()); + + // Check that the max is kept across multiple calls, where the last is not the max. + adjustableTimedRunnable.setQueuedTimeTakenNanos(5000000); + executeTask(executor, 1); + safeAwait(barrier); // Wait for the task to start, which implies it has finished the queuing stage. + adjustableTimedRunnable.setQueuedTimeTakenNanos(1000000); + executeTask(executor, 1); + safeAwait(barrier); + assertEquals("Max should not be the last task", 5, executor.getMaxQueueLatencyMillisSinceLastPollAndReset()); + assertEquals("The max was just reset, should be zero", 0, executor.getMaxQueueLatencyMillisSinceLastPollAndReset()); + } finally { + // Clean up. + executor.shutdown(); + executor.awaitTermination(10, TimeUnit.SECONDS); + } + } + /** Use a runnable wrapper that simulates a task with unknown failures. */ public void testExceptionThrowingTask() throws Exception { ThreadContext context = new ThreadContext(Settings.EMPTY); @@ -103,7 +165,12 @@ public void testExceptionThrowingTask() throws Exception { EsExecutors.daemonThreadFactory("queuetest"), new EsAbortPolicy(), context, - new TaskTrackingConfig(randomBoolean(), DEFAULT_EWMA_ALPHA) + randomBoolean() + ?
EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() + : EsExecutors.TaskTrackingConfig.builder().trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST).build() ); executor.prestartAllCoreThreads(); logger.info("--> executor: {}", executor); @@ -135,7 +202,10 @@ public void testGetOngoingTasks() throws Exception { EsExecutors.daemonThreadFactory("queuetest"), new EsAbortPolicy(), context, - new TaskTrackingConfig(true, DEFAULT_EWMA_ALPHA) + EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() ); var taskRunningLatch = new CountDownLatch(1); var exitTaskLatch = new CountDownLatch(1); @@ -156,7 +226,7 @@ public void testGetOngoingTasks() throws Exception { executor.awaitTermination(10, TimeUnit.SECONDS); } - public void testQueueLatencyMetrics() { + public void testQueueLatencyHistogramMetrics() { RecordingMeterRegistry meterRegistry = new RecordingMeterRegistry(); final var threadPoolName = randomIdentifier(); var executor = new TaskExecutionTimeTrackingEsThreadPoolExecutor( @@ -170,7 +240,10 @@ public void testQueueLatencyMetrics() { EsExecutors.daemonThreadFactory("queuetest"), new EsAbortPolicy(), new ThreadContext(Settings.EMPTY), - new TaskTrackingConfig(true, DEFAULT_EWMA_ALPHA) + EsExecutors.TaskTrackingConfig.builder() + .trackOngoingTasks() + .trackExecutionTime(DEFAULT_EXECUTION_TIME_EWMA_ALPHA_FOR_TEST) + .build() ); executor.setupMetrics(meterRegistry, threadPoolName); @@ -261,18 +334,18 @@ private void executeTask(TaskExecutionTimeTrackingEsThreadPoolExecutor executor, } public class SettableTimedRunnable extends TimedRunnable { - private final long timeTaken; + private final long executionTimeTakenNanos; private final boolean testFailedOrRejected; - public SettableTimedRunnable(long timeTaken, boolean failedOrRejected) { + public SettableTimedRunnable(long 
executionTimeTakenNanos, boolean failedOrRejected) { super(() -> {}); - this.timeTaken = timeTaken; + this.executionTimeTakenNanos = executionTimeTakenNanos; this.testFailedOrRejected = failedOrRejected; } @Override public long getTotalExecutionNanos() { - return timeTaken; + return executionTimeTakenNanos; } @Override @@ -280,4 +353,38 @@ public boolean getFailedOrRejected() { return testFailedOrRejected; } } + + /** + * This TimedRunnable override provides the following: + *
+ * <ul>
+ * <li>Overrides {@link TimedRunnable#getQueueTimeNanos()} so that arbitrary queue latencies can be set for the thread pool.</li>
+ * <li>Replaces any submitted Runnable task to the thread pool with a Runnable that only waits on a {@link CyclicBarrier}.</li>
+ * </ul>
+ * This allows dynamically manipulating the queue time with {@link #setQueuedTimeTakenNanos}, and provides a means of waiting for a task + * to start by calling {@code safeAwait(barrier)} after submitting a task. + *

+ * Look at {@link TaskExecutionTimeTrackingEsThreadPoolExecutor#wrapRunnable} for how the ThreadPool uses this as a wrapper around all + * submitted tasks. + */ + public class AdjustableQueueTimeWithExecutionBarrierTimedRunnable extends TimedRunnable { + private long queuedTimeTakenNanos; + + /** + * @param barrier A barrier that the caller can wait upon to ensure a task starts. + * @param queuedTimeTakenNanos The default queue time reported for all tasks. + */ + public AdjustableQueueTimeWithExecutionBarrierTimedRunnable(CyclicBarrier barrier, long queuedTimeTakenNanos) { + super(() -> { safeAwait(barrier); }); + this.queuedTimeTakenNanos = queuedTimeTakenNanos; + } + + public void setQueuedTimeTakenNanos(long timeTakenNanos) { + this.queuedTimeTakenNanos = timeTakenNanos; + } + + @Override + long getQueueTimeNanos() { + return queuedTimeTakenNanos; + } + } } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java index 289e1a730db90..59fce66926fe2 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java @@ -34,12 +34,14 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.randomAsciiLettersOfLengthBetween; import static org.elasticsearch.tasks.Task.HEADERS_TO_COPY; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; @@ -1102,7 +1104,8 
@@ public void testSanitizeHeaders() { final String authorizationHeader = randomCase("authorization"); final String authorizationHeader2 = randomCase("es-secondary-authorization"); final String authorizationHeader3 = randomCase("ES-Client-Authentication"); - Set possibleHeaders = Set.of(authorizationHeader, authorizationHeader2, authorizationHeader3); + final String authorizationHeader4 = randomCase("X-Client-Authentication"); + Set possibleHeaders = Set.of(authorizationHeader, authorizationHeader2, authorizationHeader3, authorizationHeader4); Set headers = randomizeHeaders ? randomSet(0, possibleHeaders.size(), () -> randomFrom(possibleHeaders)) : possibleHeaders; @@ -1161,6 +1164,75 @@ public void testNewEmptySystemContext() { assertNotNull(threadContext.getHeader(header)); } + public void testNewTraceContext() { + final var threadContext = new ThreadContext(Settings.EMPTY); + + var rootTraceContext = Map.of(Task.TRACE_PARENT_HTTP_HEADER, randomIdentifier(), Task.TRACE_STATE, randomIdentifier()); + var apmTraceContext = new Object(); + var responseKey = randomIdentifier(); + var responseValue = randomAlphaOfLength(10); + + threadContext.putHeader(rootTraceContext); + threadContext.putTransient(Task.APM_TRACE_CONTEXT, apmTraceContext); + + assertThat(threadContext.hasTraceContext(), equalTo(true)); + assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + + try (var ignored = threadContext.newTraceContext()) { + assertThat(threadContext.hasTraceContext(), equalTo(false)); // no trace started yet + assertThat(threadContext.hasParentTraceContext(), equalTo(true)); + + assertThat(threadContext.getHeaders(), is(anEmptyMap())); + assertThat( + threadContext.getTransientHeaders(), + equalTo( + Map.of( + Task.PARENT_TRACE_PARENT_HEADER, + rootTraceContext.get(Task.TRACE_PARENT_HTTP_HEADER), + Task.PARENT_TRACE_STATE, + rootTraceContext.get(Task.TRACE_STATE), + Task.PARENT_APM_TRACE_CONTEXT, + apmTraceContext + ) + ) + ); + // response headers shall be 
propagated + threadContext.addResponseHeader(responseKey, responseValue); + } + + assertThat(threadContext.hasTraceContext(), equalTo(true)); + assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + + assertThat(threadContext.getHeaders(), equalTo(rootTraceContext)); + assertThat(threadContext.getTransientHeaders(), equalTo(Map.of(Task.APM_TRACE_CONTEXT, apmTraceContext))); + assertThat(threadContext.getResponseHeaders(), equalTo(Map.of(responseKey, List.of(responseValue)))); + } + + public void testNewTraceContextWithoutParentTrace() { + final var threadContext = new ThreadContext(Settings.EMPTY); + + var responseKey = randomIdentifier(); + var responseValue = randomAlphaOfLength(10); + + assertThat(threadContext.hasTraceContext(), equalTo(false)); + assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + + try (var ignored = threadContext.newTraceContext()) { + assertTrue(threadContext.isDefaultContext()); + assertThat(threadContext.hasTraceContext(), equalTo(false)); + assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + + // discarded, just making sure the context is isolated + threadContext.putTransient(randomIdentifier(), randomAlphaOfLength(10)); + // response headers shall be propagated + threadContext.addResponseHeader(responseKey, responseValue); + } + + assertThat(threadContext.getHeaders(), is(anEmptyMap())); + assertThat(threadContext.getTransientHeaders(), is(anEmptyMap())); + assertThat(threadContext.getResponseHeaders(), equalTo(Map.of(responseKey, List.of(responseValue)))); + } + public void testRestoreExistingContext() { final var threadContext = new ThreadContext(Settings.EMPTY); final var header = randomIdentifier(); diff --git a/server/src/test/java/org/elasticsearch/index/engine/ThreadPoolMergeExecutorServiceDiskSpaceTests.java b/server/src/test/java/org/elasticsearch/index/engine/ThreadPoolMergeExecutorServiceDiskSpaceTests.java index 33a86ef5709ad..9f34c56cf041e 100644 ---
a/server/src/test/java/org/elasticsearch/index/engine/ThreadPoolMergeExecutorServiceDiskSpaceTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/ThreadPoolMergeExecutorServiceDiskSpaceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.PathUtils; import org.elasticsearch.core.PathUtilsForTesting; @@ -40,6 +41,7 @@ import java.util.IdentityHashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -898,8 +900,8 @@ public void testEnqueuedMergeTasksAreUnblockedWhenEstimatedMergeSizeChanges() th assertBusy( () -> assertThat(threadPoolMergeExecutorService.getDiskSpaceAvailableForNewMergeTasks(), is(expectedAvailableBudget.get())) ); - List tasksRunList = new ArrayList<>(); - List tasksAbortList = new ArrayList<>(); + Set tasksRunList = ConcurrentCollections.newConcurrentSet(); + Set tasksAbortList = ConcurrentCollections.newConcurrentSet(); int submittedMergesCount = randomIntBetween(1, 5); long[] mergeSizeEstimates = new long[submittedMergesCount]; for (int i = 0; i < submittedMergesCount; i++) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java index 73d76ad48c955..6a81a93923abc 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java @@ -125,7 +125,7 @@ private static void testBoundsBlockLoaderAux( for 
(int j : array) { expected.add(visitor.apply(geometries.get(j + currentIndex)).get()); } - try (var block = (TestBlock) loader.reader(leaf).read(TestBlock.factory(leafReader.numDocs()), TestBlock.docs(array))) { + try (var block = (TestBlock) loader.reader(leaf).read(TestBlock.factory(), TestBlock.docs(array), 0)) { for (int i = 0; i < block.size(); i++) { intArrayResults.add(block.get(i)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java b/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java index 357ada3ad656d..1fa9c85a5c738 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java @@ -59,7 +59,7 @@ private void loadBlock(LeafReaderContext ctx, Consumer test) throws I StoredFieldLoader.fromSpec(loader.rowStrideStoredFieldSpec()).getLoader(ctx, null), loader.rowStrideStoredFieldSpec().requiresSource() ? 
SourceLoader.FROM_STORED_SOURCE.leaf(ctx.reader(), null) : null ); - BlockLoader.Builder builder = loader.builder(TestBlock.factory(ctx.reader().numDocs()), 1); + BlockLoader.Builder builder = loader.builder(TestBlock.factory(), 1); storedFields.advanceTo(0); reader.read(0, storedFields, builder); TestBlock block = (TestBlock) builder.build(); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java index ce9a9bc0688f3..54656ab1af3ee 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java @@ -446,7 +446,8 @@ public void testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { BooleanScriptFieldType fieldType = build("xor_param", Map.of("param", false), OnScriptError.FAIL); List expected = List.of(false, true); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(expected.subList(1, 2))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(expected)); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java index 3d8ed5ea60262..1eb0ba07d58e2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java @@ -493,9 +493,10 @@ public void testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { DateScriptFieldType fieldType = build("add_days", 
Map.of("days", 1), OnScriptError.FAIL); assertThat( - blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), + blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(1595518581354L, 1595518581355L)) ); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(1595518581355L))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(1595518581354L, 1595518581355L))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java index 140137015d98a..b1cda53876993 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java @@ -262,7 +262,8 @@ public void testBlockLoader() throws IOException { ); try (DirectoryReader reader = iw.getReader()) { DoubleScriptFieldType fieldType = build("add_param", Map.of("param", 1), OnScriptError.FAIL); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(List.of(2d, 3d))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(2d, 3d))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(3d))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(2d, 3d))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldFilterMapperPluginTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldFilterMapperPluginTests.java index c17c0d10410fa..9fa0760e2ae2a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldFilterMapperPluginTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldFilterMapperPluginTests.java @@ -35,8 +35,8 @@ import java.util.Set; import 
java.util.function.Function; -import static org.elasticsearch.cluster.metadata.MetadataTests.assertLeafs; -import static org.elasticsearch.cluster.metadata.MetadataTests.assertMultiField; +import static org.elasticsearch.cluster.metadata.ProjectMetadataTests.assertLeafs; +import static org.elasticsearch.cluster.metadata.ProjectMetadataTests.assertMultiField; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; public class FieldFilterMapperPluginTests extends ESSingleNodeTestCase { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java index 281d2993fa29c..7e9a236f6cc74 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java @@ -273,7 +273,8 @@ public void testBlockLoader() throws IOException { new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.1"))), new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.1.1"))) ); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(expected.subList(1, 2))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(expected)); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java index 57d52991a6442..ccc8ccac4deb4 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java @@ -409,9 +409,10 @@ public void 
testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { KeywordScriptFieldType fieldType = build("append_param", Map.of("param", "-Suffix"), OnScriptError.FAIL); assertThat( - blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), + blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(new BytesRef("1-Suffix"), new BytesRef("2-Suffix"))) ); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(new BytesRef("2-Suffix")))); assertThat( blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(new BytesRef("1-Suffix"), new BytesRef("2-Suffix"))) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java index a8cb4d51c5efa..01f96a1a4b1be 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java @@ -295,7 +295,8 @@ public void testBlockLoader() throws IOException { ); try (DirectoryReader reader = iw.getReader()) { LongScriptFieldType fieldType = build("add_param", Map.of("param", 1), OnScriptError.FAIL); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(List.of(2L, 3L))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(2L, 3L))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(3L))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(2L, 3L))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index f79c14c831f86..bb7de99b72249 100644 --- 
a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1367,7 +1367,7 @@ public void testIVFParsing() throws IOException { b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); - b.field("type", "bbq_ivf"); + b.field("type", "bbq_disk"); b.endObject(); })); @@ -1386,7 +1386,7 @@ public void testIVFParsing() throws IOException { b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); - b.field("type", "bbq_ivf"); + b.field("type", "bbq_disk"); b.field("cluster_size", 1000); b.field("default_n_probe", 10); b.field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 2.0f)); @@ -1413,7 +1413,7 @@ public void testIVFParsingFailureInRelease() { b -> b.field("type", "dense_vector") .field("dims", dims) .startObject("index_options") - .field("type", "bbq_ivf") + .field("type", "bbq_disk") .endObject() ) ) @@ -2815,7 +2815,7 @@ public void testKnnBBQIVFVectorsFormat() throws IOException { b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); - b.field("type", "bbq_ivf"); + b.field("type", "bbq_disk"); b.endObject(); })); CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index cc682901876b6..1d59d44c3def7 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -78,6 +78,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.codec.CodecService; +import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec; import 
org.elasticsearch.index.engine.CommitStats; import org.elasticsearch.index.engine.DocIdSeqNoAndSource; import org.elasticsearch.index.engine.Engine; @@ -1882,8 +1883,12 @@ public void testShardFieldStats() throws IOException { assertThat(stats.numSegments(), equalTo(0)); assertThat(stats.totalFields(), equalTo(0)); assertThat(stats.fieldUsages(), equalTo(0L)); + assertThat(stats.postingsInMemoryBytes(), equalTo(0L)); + + boolean postingsBytesTrackingEnabled = TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled(); + // index some documents - int numDocs = between(1, 10); + int numDocs = between(2, 10); for (int i = 0; i < numDocs; i++) { indexDoc(shard, "_doc", "first_" + i, """ { @@ -1901,6 +1906,9 @@ public void testShardFieldStats() throws IOException { // _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms), // f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv), assertThat(stats.fieldUsages(), equalTo(13L)); + // _id: (5,8), f1: 3, f1.keyword: 3, f2: 3, f2.keyword: 3 + // 5 + 8 + 3 + 3 + 3 + 3 = 25 + assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 
25L : 0L)); // don't re-compute on refresh without change if (randomBoolean()) { shard.refresh("test"); @@ -1919,11 +1927,18 @@ public void testShardFieldStats() throws IOException { assertThat(shard.getShardFieldStats(), sameInstance(stats)); // index more docs numDocs = between(1, 10); + indexDoc(shard, "_doc", "first_0", """ + { + "f1": "lorem", + "f2": "bar", + "f3": "sit amet" + } + """); for (int i = 0; i < numDocs; i++) { - indexDoc(shard, "_doc", "first_" + i, """ + indexDoc(shard, "_doc", "first_" + i + 1, """ { "f1": "foo", - "f2": "bar", + "f2": "ipsum", "f3": "foobar" } """); @@ -1948,6 +1963,11 @@ public void testShardFieldStats() throws IOException { assertThat(stats.totalFields(), equalTo(21)); // first segment: 13, second segment: 13 + f3(postings,norms) + f3.keyword(term,dv), and __soft_deletes to previous segment assertThat(stats.fieldUsages(), equalTo(31L)); + // segment 1: 25 (see above) + // segment 2: _id: (5,6), f1: (3,5), f1.keyword: (3,5), f2: (3,5), f2.keyword: (3,5), f3: (4,3), f3.keyword: (6,8) + // (5+6) + (3+5) + (3+5) + (3+5) + (3+5) + (4+3) + (6+8) = 64 + // 25 + 64 = 89 + assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 89L : 0L)); shard.forceMerge(new ForceMergeRequest().maxNumSegments(1).flush(true)); stats = shard.getShardFieldStats(); assertThat(stats.numSegments(), equalTo(1)); @@ -1955,6 +1975,8 @@ public void testShardFieldStats() throws IOException { // _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms), // f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv), f3(postings,norms), f3.keyword(term,dv), __soft_deletes assertThat(stats.fieldUsages(), equalTo(18L)); + // _id: (5,8), f1: (3,5), f1.keyword: (3,5), f2: (3,5), f2.keyword: (3,5), f3: (4,3), f3.keyword: (6,8) + assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 
66L : 0L)); closeShards(shard); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java similarity index 87% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java rename to server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java index f3800f91d9a54..0d71165823e89 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java +++ b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.inference.Model; @@ -22,9 +24,9 @@ import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.create; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.create; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.assertArg; @@ -35,9 +37,13 @@ public class InferenceStatsTests extends ESTestCase { + public static InferenceStats mockInferenceStats() { + return new InferenceStats(mock(), mock(), mock()); + } + public void testRecordWithModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -49,7 +55,7 @@ public void testRecordWithModel() { public void testRecordWithoutModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null))); @@ -63,7 +69,7 @@ public void testCreation() { public void testRecordDurationWithoutError() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats 
= new InferenceStats(mock(), histogramCounter, mock()); Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -88,7 +94,7 @@ public void testRecordDurationWithoutError() { public void testRecordDurationWithElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -116,7 +122,7 @@ public void testRecordDurationWithElasticsearchStatusException() { public void testRecordDurationWithOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -138,7 +144,7 @@ public void testRecordDurationWithOtherException() { public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -163,7 +169,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() public void testRecordDurationWithUnparsedModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), 
histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -187,7 +193,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() { public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -206,7 +212,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() public void testRecordDurationWithUnknownModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java index b79f2f6517189..c354f7a7d1991 100644 --- a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.NodesShutdownMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import 
org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -1087,7 +1088,12 @@ public Scope scope() { } @Override - public Assignment getAssignment(P params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + P params, + Collection candidateNodes, + ClusterState clusterState, + ProjectId projectId + ) { return fn.apply(params, candidateNodes, clusterState); } diff --git a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java index e3189de94b1a6..a6e059444e4da 100644 --- a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java +++ b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.internal.ElasticsearchClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; @@ -326,12 +327,17 @@ public static void setNonClusterStateCondition(boolean nonClusterStateCondition) } @Override - public Assignment getAssignment(TestParams params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + TestParams params, + Collection candidateNodes, + ClusterState clusterState, + ProjectId projectId + ) { if (nonClusterStateCondition == false) { return new Assignment(null, "non cluster state condition prevents assignment"); } if (params == null || params.getExecutorNodeAttr() == null) { - return super.getAssignment(params, candidateNodes, clusterState); + return super.doGetAssignment(params, candidateNodes, clusterState, projectId); } else { DiscoveryNode 
executorNode = selectLeastLoadedNode( clusterState, diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index ac9eae8190619..ae62957e33feb 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -118,6 +118,58 @@ public void testApplyContentParser() throws IOException { assertEquals(emptyMap(), source.get()); } + public void testParamAsIntWithNoParameters() { + RestRequest restRequest = contentRestRequest("", emptyMap()); + int defaultValue = randomInt(); + String parameterKey = randomIdentifier(); + + int value = restRequest.paramAsInt(parameterKey, defaultValue); + assertEquals(defaultValue, value); + } + + public void testParamAsIntWithIntegerParameter() { + String parameterKey = randomIdentifier(); + RestRequest restRequest = contentRestRequest("", singletonMap(parameterKey, "123")); + int defaultValue = randomInt(); + + int value = restRequest.paramAsInt(parameterKey, defaultValue); + assertEquals(123, value); + } + + public void testParamAsIntWithNonIntegerParameter() { + String parameterKey = randomIdentifier(); + RestRequest restRequest = contentRestRequest("", singletonMap(parameterKey, "123T")); + int defaultValue = randomInt(); + + assertThrows(IllegalArgumentException.class, () -> restRequest.paramAsInt(parameterKey, defaultValue)); + } + + public void testParamAsIntegerWithNoParameters() { + RestRequest restRequest = contentRestRequest("", emptyMap()); + int defaultValue = randomInt(); + String parameterKey = randomIdentifier(); + + Integer value2 = restRequest.paramAsInteger(parameterKey, defaultValue); + assertEquals(defaultValue, value2.intValue()); + } + + public void testParamAsIntegerWithIntegerParameter() { + String parameterKey = randomIdentifier(); + RestRequest restRequest = contentRestRequest("", singletonMap(parameterKey, "123")); + int defaultValue = 
randomInt(); + + Integer value2 = restRequest.paramAsInteger(parameterKey, defaultValue); + assertEquals(123, value2.intValue()); + } + + public void testParamAsIntegerWithNonIntegerParameter() { + String parameterKey = randomIdentifier(); + RestRequest restRequest = contentRestRequest("", singletonMap(parameterKey, "123T")); + int defaultValue = randomInt(); + + assertThrows(IllegalArgumentException.class, () -> restRequest.paramAsInteger(parameterKey, defaultValue)); + } + public void testContentOrSourceParam() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> contentRestRequest("", emptyMap()).contentOrSourceParam()); assertEquals("request body or source parameter is required", e.getMessage()); diff --git a/server/src/test/java/org/elasticsearch/search/sort/GeoDistanceSortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/sort/GeoDistanceSortBuilderTests.java index 17a9fb5974176..5bef1f4769cff 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/GeoDistanceSortBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/sort/GeoDistanceSortBuilderTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.mapper.GeoPointFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NestedPathFieldMapper; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.query.GeoValidationMethod; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; @@ -99,6 +100,9 @@ public static GeoDistanceSortBuilder randomGeoDistanceSortBuilder() { @Override protected MappedFieldType provideMappedFieldType(String name) { + if (name.equals("double")) { + return new NumberFieldMapper.NumberFieldType(name, NumberFieldMapper.NumberType.DOUBLE); + } return new GeoPointFieldMapper.GeoPointFieldType(name); } @@ -531,6 +535,12 @@ public void testBuildInvalidPoints() throws 
IOException { ); assertEquals("illegal longitude value [-360.0] for [GeoDistanceSort] for field [fieldName].", ex.getMessage()); } + { + GeoDistanceSortBuilder sortBuilder = new GeoDistanceSortBuilder("double", 0.0, 180.0); + sortBuilder.validation(GeoValidationMethod.STRICT); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> sortBuilder.build(searchExecutionContext)); + assertEquals("unable to apply geo distance sort to field [double] of type [double]", ex.getMessage()); + } } /** diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index be7ffdc60d2ea..86436a0852f58 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -68,6 +68,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class TaskManagerTests extends ESTestCase { @@ -281,7 +282,69 @@ public void testTaskAccounting() { /** * Check that registering a task also causes tracing to be started on that task. */ - public void testRegisterTaskStartsTracing() { + public void testRegisterTaskStartsTracingIfTraceParentExists() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + + // fake a trace parent + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + final boolean hasParentTask = randomBoolean(); + final TaskId parentTask = hasParentTask ? 
new TaskId("parentNode", 1) : TaskId.EMPTY_TASK_ID; + + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + + final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() { + + @Override + public void setParentTask(TaskId taskId) {} + + @Override + public void setRequestId(long requestId) {} + + @Override + public TaskId getParentTask() { + return parentTask; + } + }); + + Map attributes = hasParentTask + ? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString()) + : Map.of(Tracer.AttributeKeys.TASK_ID, task.getId()); + verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), eq(attributes)); + } + } + + /** + * Check that registering a task also causes tracing to be started on that task. + */ + public void testRegisterTaskSkipsTracingIfTraceParentMissing() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + + // no trace parent + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() { + + @Override + public void setParentTask(TaskId taskId) {} + + @Override + public void setRequestId(long requestId) {} + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + }); + } + + verifyNoInteractions(mockTracer); + } + + /** + * Check that unregistering a task also causes tracing to be stopped on that task. 
+ */ + public void testUnregisterTaskStopsTracingIfTraceContextExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); @@ -299,13 +362,17 @@ public TaskId getParentTask() { } }); - verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), anyMap()); + // fake a trace context (trace parent) + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + + taskManager.unregister(task); + verify(mockTracer).stopTrace(task); } /** * Check that unregistering a task also causes tracing to be stopped on that task. */ - public void testUnregisterTaskStopsTracing() { + public void testUnregisterTaskStopsTracingIfTraceContextMissing() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); @@ -323,18 +390,22 @@ public TaskId getParentTask() { } }); - taskManager.unregister(task); + // no trace context - verify(mockTracer).stopTrace(task); + taskManager.unregister(task); + verifyNoInteractions(mockTracer); } /** - * Check that registering and executing a task also causes tracing to be started and stopped on that task. + * Check that registering and executing a task also causes tracing to be started if a trace parent exists. 
*/ - public void testRegisterAndExecuteStartsAndStopsTracing() { + public void testRegisterAndExecuteStartsTracingIfTraceParentExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + // fake a trace parent + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + final Task task = taskManager.registerAndExecute( "testType", new TransportAction( @@ -369,25 +440,68 @@ public TaskId getParentTask() { verify(mockTracer).startTrace(any(), eq(task), eq("actionName"), anyMap()); } + /** + * Check that registering and executing a task skips tracing if trace parent does not exists. + */ + public void testRegisterAndExecuteSkipsTracingIfTraceParentMissing() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + + // clean thread context without trace parent + + final Task task = taskManager.registerAndExecute( + "testType", + new TransportAction( + "actionName", + new ActionFilters(Set.of()), + taskManager, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ) { + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + listener.onResponse(new ActionResponse() { + @Override + public void writeTo(StreamOutput out) {} + }); + } + }, + new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + }, + null, + ActionTestUtils.assertNoFailureListener(r -> {}) + ); + + verifyNoInteractions(mockTracer); + } + public void testRegisterWithEnabledDisabledTracing() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = spy(new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer)); taskManager.register("type", "action", makeTaskRequest(true, 123), false); - 
verify(taskManager, times(0)).startTrace(any(), any()); + verify(taskManager, times(0)).maybeStartTrace(any(), any()); taskManager.register("type", "action", makeTaskRequest(false, 234), false); - verify(taskManager, times(0)).startTrace(any(), any()); + verify(taskManager, times(0)).maybeStartTrace(any(), any()); clearInvocations(taskManager); taskManager.register("type", "action", makeTaskRequest(true, 345), true); - verify(taskManager, times(1)).startTrace(any(), any()); + verify(taskManager, times(1)).maybeStartTrace(any(), any()); clearInvocations(taskManager); taskManager.register("type", "action", makeTaskRequest(false, 456), true); - verify(taskManager, times(1)).startTrace(any(), any()); + verify(taskManager, times(1)).maybeStartTrace(any(), any()); } static class CancellableRequest extends AbstractTransportRequest { diff --git a/server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java b/server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java index ad86c1159f426..2cd166e002637 100644 --- a/server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java +++ b/server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.common.util.concurrent.TaskExecutionTimeTrackingEsThreadPoolExecutor; +import org.elasticsearch.common.util.concurrent.TaskExecutionTimeTrackingEsThreadPoolExecutor.UtilizationTrackingPurpose; import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.Node; import org.elasticsearch.telemetry.InstrumentType; @@ -509,13 +510,17 @@ public void testDetailedUtilizationMetric() throws Exception { final long beforePreviousCollectNanos = System.nanoTime(); meterRegistry.getRecorder().collect(); + double allocationUtilization = executor.pollUtilization(UtilizationTrackingPurpose.ALLOCATION); final long 
afterPreviousCollectNanos = System.nanoTime(); - metricAsserter.assertLatestMetricValueMatches( + + var metricValue = metricAsserter.assertLatestMetricValueMatches( InstrumentType.DOUBLE_GAUGE, ThreadPool.THREAD_POOL_METRIC_NAME_UTILIZATION, Measurement::getDouble, equalTo(0.0d) ); + logger.info("---> Utilization metric data points, APM: " + metricValue + ", Allocation: " + allocationUtilization); + assertThat(allocationUtilization, equalTo(0.0d)); final AtomicLong minimumDurationNanos = new AtomicLong(Long.MAX_VALUE); final long beforeStartNanos = System.nanoTime(); @@ -535,6 +540,7 @@ public void testDetailedUtilizationMetric() throws Exception { final long beforeMetricsCollectedNanos = System.nanoTime(); meterRegistry.getRecorder().collect(); + allocationUtilization = executor.pollUtilization(UtilizationTrackingPurpose.ALLOCATION); final long afterMetricsCollectedNanos = System.nanoTime(); // Calculate upper bound on utilisation metric @@ -549,12 +555,14 @@ public void testDetailedUtilizationMetric() throws Exception { logger.info("Utilization must be in [{}, {}]", minimumUtilization, maximumUtilization); Matcher matcher = allOf(greaterThan(minimumUtilization), lessThan(maximumUtilization)); - metricAsserter.assertLatestMetricValueMatches( + metricValue = metricAsserter.assertLatestMetricValueMatches( InstrumentType.DOUBLE_GAUGE, ThreadPool.THREAD_POOL_METRIC_NAME_UTILIZATION, Measurement::getDouble, matcher ); + logger.info("---> Utilization metric data points, APM: " + metricValue + ", Allocation: " + allocationUtilization); + assertThat(allocationUtilization, matcher); } finally { ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); } @@ -665,7 +673,7 @@ void assertLatestLongValueMatches(String metricName, InstrumentType instrumentTy assertLatestMetricValueMatches(instrumentType, metricName, Measurement::getLong, matcher); } - void assertLatestMetricValueMatches( + T assertLatestMetricValueMatches( InstrumentType instrumentType, String name, Function 
valueExtractor, @@ -674,7 +682,9 @@ void assertLatestMetricValueMatches( List measurements = meterRegistry.getRecorder() .getMeasurements(instrumentType, ThreadPool.THREAD_POOL_METRIC_PREFIX + threadPoolName + name); assertFalse(name + " has no measurements", measurements.isEmpty()); - assertThat(valueExtractor.apply(measurements.getLast()), matcher); + var latestMetric = valueExtractor.apply(measurements.getLast()); + assertThat(latestMetric, matcher); + return latestMetric; } } diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java index b8de02293f734..76e280c987ae1 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java @@ -170,7 +170,7 @@ public void testGroupClusterIndices() throws IOException { assertFalse(service.isRemoteClusterRegistered("foo")); { Map> perClusterIndices = service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "cluster_1:bar", "cluster_2:foo:bar", @@ -191,7 +191,7 @@ public void testGroupClusterIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "foo:bar", "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "foo" } ) ); @@ -199,7 +199,7 @@ public void testGroupClusterIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "does_not_exist:*" } ) ); @@ -208,7 +208,10 @@ public void testGroupClusterIndices() throws IOException { { 
String[] indices = shuffledList(List.of("cluster*:foo*", "foo", "-cluster_1:*", "*:boo")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(2, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ -223,7 +226,10 @@ public void testGroupClusterIndices() throws IOException { { String[] indices = shuffledList(List.of("*:*", "-clu*_1:*", "*:boo")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List cluster2 = perClusterIndices.get("cluster_2"); @@ -236,7 +242,10 @@ public void testGroupClusterIndices() throws IOException { new String[0] ); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ -246,7 +255,10 @@ public void testGroupClusterIndices() throws IOException { { String[] indices = shuffledList(List.of("cluster*:*", "foo", "-*:*")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ 
-257,7 +269,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), // -cluster_1:foo* is not allowed, only -cluster_1:* new String[] { "cluster_1:bar", "-cluster_2:foo*", "cluster_1:test", "cluster_2:foo*", "foo" } ) @@ -271,7 +283,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), // -cluster_1:* will fail since cluster_1 was never included in order to qualify to be excluded new String[] { "-cluster_1:*", "cluster_2:foo*", "foo" } ) @@ -287,7 +299,7 @@ public void testGroupClusterIndices() throws IOException { { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), new String[] { "-cluster_1:*" }) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), new String[] { "-cluster_1:*" }) ); assertThat( e.getMessage(), @@ -300,7 +312,7 @@ public void testGroupClusterIndices() throws IOException { { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), new String[] { "-*:*" }) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), new String[] { "-*:*" }) ); assertThat( e.getMessage(), @@ -315,7 +327,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), indices) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), indices) ); assertThat( e.getMessage(), @@ -394,7 +406,7 @@ 
public void testGroupIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "foo:bar", "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "foo" } ) ); diff --git a/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java b/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java index 1a4299b0ce938..5f85dc8f3bec1 100644 --- a/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java +++ b/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java @@ -28,8 +28,8 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.lessThanOrEqualTo; public class EsqlPartialResultsIT extends ESRestTestCase { @ClassRule @@ -106,7 +106,11 @@ public void testPartialResult() throws Exception { Set okIds = populateIndices(); String query = """ { - "query": "FROM ok-index,failing-index | LIMIT 100 | KEEP fail_me,v" + "query": "FROM ok-index,failing-index | LIMIT 100 | KEEP fail_me,v", + "pragma": { + "max_concurrent_shards_per_node": 1 + }, + "accept_pragma_risks": true } """; // allow_partial_results = true @@ -123,7 +127,7 @@ public void testPartialResult() throws Exception { List columns = (List) results.get("columns"); assertThat(columns, equalTo(List.of(Map.of("name", "fail_me", "type", "long"), Map.of("name", "v", "type", "long")))); List values = (List) results.get("values"); - assertThat(values.size(), lessThanOrEqualTo(okIds.size())); + assertThat(values.size(), equalTo(okIds.size())); Map localInfo = (Map) 
XContentMapValues.extractValue( results, "_clusters", @@ -131,11 +135,10 @@ public void testPartialResult() throws Exception { "(local)" ); assertNotNull(localInfo); - assertThat(XContentMapValues.extractValue(localInfo, "_shards", "successful"), equalTo(0)); - assertThat( - XContentMapValues.extractValue(localInfo, "_shards", "failed"), - equalTo(XContentMapValues.extractValue(localInfo, "_shards", "total")) - ); + Integer successfulShards = (Integer) XContentMapValues.extractValue(localInfo, "_shards", "successful"); + Integer failedShards = (Integer) XContentMapValues.extractValue(localInfo, "_shards", "failed"); + assertThat(successfulShards, greaterThan(0)); + assertThat(failedShards, greaterThan(0)); List> failures = (List>) XContentMapValues.extractValue(localInfo, "failures"); assertThat(failures, hasSize(1)); assertThat( @@ -167,7 +170,11 @@ public void testFailureFromRemote() throws Exception { Set okIds = populateIndices(); String query = """ { - "query": "FROM *:ok-index,*:failing-index | LIMIT 100 | KEEP fail_me,v" + "query": "FROM *:ok-index,*:failing-index | LIMIT 100 | KEEP fail_me,v", + "pragma": { + "max_concurrent_shards_per_node": 1 + }, + "accept_pragma_risks": true } """; // allow_partial_results = true @@ -183,7 +190,7 @@ public void testFailureFromRemote() throws Exception { List columns = (List) results.get("columns"); assertThat(columns, equalTo(List.of(Map.of("name", "fail_me", "type", "long"), Map.of("name", "v", "type", "long")))); List values = (List) results.get("values"); - assertThat(values.size(), lessThanOrEqualTo(okIds.size())); + assertThat(values.size(), equalTo(okIds.size())); Map remoteCluster = (Map) XContentMapValues.extractValue( results, "_clusters", @@ -191,11 +198,10 @@ public void testFailureFromRemote() throws Exception { "cluster_one" ); assertNotNull(remoteCluster); - assertThat(XContentMapValues.extractValue(remoteCluster, "_shards", "successful"), equalTo(0)); - assertThat( - 
XContentMapValues.extractValue(remoteCluster, "_shards", "failed"), - equalTo(XContentMapValues.extractValue(remoteCluster, "_shards", "total")) - ); + Integer successfulShards = (Integer) XContentMapValues.extractValue(remoteCluster, "_shards", "successful"); + Integer failedShards = (Integer) XContentMapValues.extractValue(remoteCluster, "_shards", "failed"); + assertThat(successfulShards, greaterThan(0)); + assertThat(failedShards, greaterThan(0)); List> failures = (List>) XContentMapValues.extractValue(remoteCluster, "failures"); assertThat(failures, hasSize(1)); assertThat( @@ -207,6 +213,25 @@ public void testFailureFromRemote() throws Exception { } } + public void testAllShardsFailed() throws Exception { + setupRemoteClusters(); + populateIndices(); + try { + for (boolean allowPartialResults : List.of(Boolean.TRUE, Boolean.FALSE)) { + for (String index : List.of("failing*", "*:failing*", "*:failing*,failing*")) { + Request request = new Request("POST", "/_query"); + request.setJsonEntity("{\"query\": \"FROM " + index + " | LIMIT 100\"}"); + request.addParameter("allow_partial_results", Boolean.toString(allowPartialResults)); + var error = expectThrows(ResponseException.class, () -> client().performRequest(request)); + Response resp = error.getResponse(); + assertThat(EntityUtils.toString(resp.getEntity()), containsString("Accessing failing field")); + } + } + } finally { + removeRemoteCluster(); + } + } + private void setupRemoteClusters() throws IOException { String settings = String.format(Locale.ROOT, """ { diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java index 1c237404a78cc..bc5e1f123fe81 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java +++ 
b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java @@ -22,6 +22,7 @@ static ElasticsearchCluster buildCluster() { .setting("xpack.security.enabled", "false") .setting("xpack.license.self_generated.type", "trial") .setting("esql.query.allow_partial_results", "false") + .setting("logger.org.elasticsearch.compute.lucene.read", "DEBUG") .jvmArg("-Xmx512m"); String javaVersion = JvmInfo.jvmInfo().version(); if (javaVersion.equals("20") || javaVersion.equals("21")) { diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 3912a63ef1514..893acbd22cc23 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -570,7 +570,7 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE } public void testFetchManyBigFields() throws IOException { - initManyBigFieldsIndex(100); + initManyBigFieldsIndex(100, "keyword"); Map response = fetchManyBigFields(100); ListMatcher columns = matchesList(); for (int f = 0; f < 1000; f++) { @@ -580,7 +580,7 @@ public void testFetchManyBigFields() throws IOException { } public void testFetchTooManyBigFields() throws IOException { - initManyBigFieldsIndex(500); + initManyBigFieldsIndex(500, "keyword"); // 500 docs is plenty to circuit break on most nodes assertCircuitBreaks(attempt -> fetchManyBigFields(attempt * 500)); } @@ -594,6 +594,58 @@ private Map fetchManyBigFields(int docs) throws IOException { return responseAsMap(query(query.toString(), "columns")); } + public void testAggManyBigTextFields() throws IOException { + int docs = 100; + int fields = 
100; + initManyBigFieldsIndex(docs, "text"); + Map response = aggManyBigFields(fields); + ListMatcher columns = matchesList().item(matchesMap().entry("name", "sum").entry("type", "long")); + assertMap( + response, + matchesMap().entry("columns", columns).entry("values", matchesList().item(matchesList().item(1024 * fields * docs))) + ); + } + + /** + * Aggregates documents containing many fields which are {@code 1kb} each. + */ + private Map aggManyBigFields(int fields) throws IOException { + StringBuilder query = startQuery(); + query.append("FROM manybigfields | STATS sum = SUM("); + query.append("LENGTH(f").append(String.format(Locale.ROOT, "%03d", 0)).append(")"); + for (int f = 1; f < fields; f++) { + query.append(" + LENGTH(f").append(String.format(Locale.ROOT, "%03d", f)).append(")"); + } + query.append(")\"}"); + return responseAsMap(query(query.toString(), "columns,values")); + } + + /** + * Aggregates on the {@code LENGTH} of a giant text field. Without + * splitting pages on load (#131053) this throws a {@link CircuitBreakingException} + * when it tries to load a giant field. With that change it finishes + * after loading many single-row pages. + */ + public void testAggGiantTextField() throws IOException { + int docs = 100; + initGiantTextField(docs); + Map response = aggGiantTextField(); + ListMatcher columns = matchesList().item(matchesMap().entry("name", "sum").entry("type", "long")); + assertMap( + response, + matchesMap().entry("columns", columns).entry("values", matchesList().item(matchesList().item(1024 * 1024 * 5 * docs))) + ); + } + + /** + * Aggregates documents containing a text field that is {@code 5mb} each. 
+ */ + private Map aggGiantTextField() throws IOException { + StringBuilder query = startQuery(); + query.append("FROM bigtext | STATS sum = SUM(LENGTH(f))\"}"); + return responseAsMap(query(query.toString(), "columns,values")); + } + public void testAggMvLongs() throws IOException { int fieldValues = 100; initMvLongsIndex(1, 3, fieldValues); @@ -788,7 +840,7 @@ private void initSingleDocIndex() throws IOException { """); } - private void initManyBigFieldsIndex(int docs) throws IOException { + private void initManyBigFieldsIndex(int docs, String type) throws IOException { logger.info("loading many documents with many big fields"); int docsPerBulk = 5; int fields = 1000; @@ -799,7 +851,7 @@ private void initManyBigFieldsIndex(int docs) throws IOException { config.startObject("settings").field("index.mapping.total_fields.limit", 10000).endObject(); config.startObject("mappings").startObject("properties"); for (int f = 0; f < fields; f++) { - config.startObject("f" + String.format(Locale.ROOT, "%03d", f)).field("type", "keyword").endObject(); + config.startObject("f" + String.format(Locale.ROOT, "%03d", f)).field("type", type).endObject(); } config.endObject().endObject(); request.setJsonEntity(Strings.toString(config.endObject())); @@ -831,6 +883,37 @@ private void initManyBigFieldsIndex(int docs) throws IOException { initIndex("manybigfields", bulk.toString()); } + private void initGiantTextField(int docs) throws IOException { + logger.info("loading many documents with one big text field"); + int docsPerBulk = 3; + int fieldSize = Math.toIntExact(ByteSizeValue.ofMb(5).getBytes()); + + Request request = new Request("PUT", "/bigtext"); + XContentBuilder config = JsonXContent.contentBuilder().startObject(); + config.startObject("mappings").startObject("properties"); + config.startObject("f").field("type", "text").endObject(); + config.endObject().endObject(); + request.setJsonEntity(Strings.toString(config.endObject())); + Response response = 
client().performRequest(request); + assertThat( + EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), + equalTo("{\"acknowledged\":true,\"shards_acknowledged\":true,\"index\":\"bigtext\"}") + ); + + StringBuilder bulk = new StringBuilder(); + for (int d = 0; d < docs; d++) { + bulk.append("{\"create\":{}}\n"); + bulk.append("{\"f\":\""); + bulk.append(Integer.toString(d % 10).repeat(fieldSize)); + bulk.append("\"}\n"); + if (d % docsPerBulk == docsPerBulk - 1 && d != docs - 1) { + bulk("bigtext", bulk.toString()); + bulk.setLength(0); + } + } + initIndex("bigtext", bulk.toString()); + } + private void initMvLongsIndex(int docs, int fields, int fieldValues) throws IOException { logger.info("loading documents with many multivalued longs"); int docsPerBulk = 100; diff --git a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java index be709eaf5f43c..bf53f14bc9e46 100644 --- a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java +++ b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.network.IfConfig; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Booleans; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.PathUtils; import org.elasticsearch.entitlement.bootstrap.TestEntitlementBootstrap; import org.elasticsearch.jdk.JarHell; @@ -76,20 +75,12 @@ public class BootstrapForTesting { // Fire up entitlements try { - TestEntitlementBootstrap.bootstrap(javaTmpDir, maybePath(System.getProperty("tests.config"))); + TestEntitlementBootstrap.bootstrap(javaTmpDir); } catch (IOException e) { throw new IllegalStateException(e.getClass().getSimpleName() + " while initializing entitlements for tests", e); } } - private static @Nullable Path maybePath(String str) { - if (str == null) { - return 
null; - } else { - return PathUtils.get(str); - } - } - // does nothing, just easy way to make sure the class is loaded. public static void ensureInitialized() {} } diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java b/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java index 01b11ce97460a..6b6136c6c861b 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java @@ -43,7 +43,7 @@ public static class TestPlugin extends Plugin {} private volatile BiFunction diskUsageFunction; public MockInternalClusterInfoService(Settings settings, ClusterService clusterService, ThreadPool threadPool, NodeClient client) { - super(settings, clusterService, threadPool, client, EstimatedHeapUsageCollector.EMPTY); + super(settings, clusterService, threadPool, client, EstimatedHeapUsageCollector.EMPTY, NodeUsageStatsForThreadPoolsCollector.EMPTY); } public void setDiskUsageFunctionAndRefresh(BiFunction diskUsageFn) { diff --git a/test/framework/src/main/java/org/elasticsearch/entitlement/bootstrap/TestEntitlementBootstrap.java b/test/framework/src/main/java/org/elasticsearch/entitlement/bootstrap/TestEntitlementBootstrap.java index 160d601e3e585..3e6f09915358b 100644 --- a/test/framework/src/main/java/org/elasticsearch/entitlement/bootstrap/TestEntitlementBootstrap.java +++ b/test/framework/src/main/java/org/elasticsearch/entitlement/bootstrap/TestEntitlementBootstrap.java @@ -12,6 +12,7 @@ import org.elasticsearch.bootstrap.TestBuildInfo; import org.elasticsearch.bootstrap.TestBuildInfoParser; import org.elasticsearch.bootstrap.TestScopeResolver; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.PathUtils; @@ -19,6 +20,7 @@ import 
org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.entitlement.initialization.EntitlementInitialization; import org.elasticsearch.entitlement.runtime.policy.PathLookup; +import org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir; import org.elasticsearch.entitlement.runtime.policy.Policy; import org.elasticsearch.entitlement.runtime.policy.PolicyParser; import org.elasticsearch.entitlement.runtime.policy.TestPathLookup; @@ -32,41 +34,106 @@ import java.net.URI; import java.net.URL; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; +import java.util.function.Consumer; import static java.util.stream.Collectors.toCollection; -import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.CONFIG; import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.TEMP; +import static org.elasticsearch.env.Environment.PATH_DATA_SETTING; +import static org.elasticsearch.env.Environment.PATH_HOME_SETTING; +import static org.elasticsearch.env.Environment.PATH_REPO_SETTING; public class TestEntitlementBootstrap { private static final Logger logger = LogManager.getLogger(TestEntitlementBootstrap.class); + private static Map> baseDirPaths = new ConcurrentHashMap<>(); private static TestPolicyManager policyManager; /** * Activates entitlement checking in tests. 
*/ - public static void bootstrap(@Nullable Path tempDir, @Nullable Path configDir) throws IOException { + public static void bootstrap(@Nullable Path tempDir) throws IOException { if (isEnabledForTest() == false) { return; } - TestPathLookup pathLookup = new TestPathLookup(Map.of(TEMP, zeroOrOne(tempDir), CONFIG, zeroOrOne(configDir))); + var previousTempDir = baseDirPaths.put(TEMP, zeroOrOne(tempDir)); + assert previousTempDir == null : "Test entitlement bootstrap called multiple times"; + TestPathLookup pathLookup = new TestPathLookup(baseDirPaths); policyManager = createPolicyManager(pathLookup); EntitlementInitialization.initializeArgs = new EntitlementInitialization.InitializeArgs(pathLookup, Set.of(), policyManager); logger.debug("Loading entitlement agent"); EntitlementBootstrap.loadAgent(EntitlementBootstrap.findAgentJar(), EntitlementInitialization.class.getName()); } + public static void registerNodeBaseDirs(Settings settings, Path configPath) { + if (policyManager == null) { + return; + } + Path homeDir = absolutePath(PATH_HOME_SETTING.get(settings)); + Path configDir = configPath != null ? configPath : homeDir.resolve("config"); + Collection dataDirs = dataDirs(settings, homeDir); + Collection repoDirs = repoDirs(settings); + logger.debug("Registering node dirs: config [{}], dataDirs [{}], repoDirs [{}]", configDir, dataDirs, repoDirs); + baseDirPaths.compute(BaseDir.CONFIG, baseDirModifier(paths -> paths.add(configDir))); + baseDirPaths.compute(BaseDir.DATA, baseDirModifier(paths -> paths.addAll(dataDirs))); + baseDirPaths.compute(BaseDir.SHARED_REPO, baseDirModifier(paths -> paths.addAll(repoDirs))); + policyManager.reset(); + } + + public static void unregisterNodeBaseDirs(Settings settings, Path configPath) { + if (policyManager == null) { + return; + } + Path homeDir = absolutePath(PATH_HOME_SETTING.get(settings)); + Path configDir = configPath != null ? 
configPath : homeDir.resolve("config"); + Collection dataDirs = dataDirs(settings, homeDir); + Collection repoDirs = repoDirs(settings); + logger.debug("Unregistering node dirs: config [{}], dataDirs [{}], repoDirs [{}]", configDir, dataDirs, repoDirs); + baseDirPaths.compute(BaseDir.CONFIG, baseDirModifier(paths -> paths.remove(configDir))); + baseDirPaths.compute(BaseDir.DATA, baseDirModifier(paths -> paths.removeAll(dataDirs))); + baseDirPaths.compute(BaseDir.SHARED_REPO, baseDirModifier(paths -> paths.removeAll(repoDirs))); + policyManager.reset(); + } + + private static Collection dataDirs(Settings settings, Path homeDir) { + List dataDirs = PATH_DATA_SETTING.get(settings); + return dataDirs.isEmpty() + ? List.of(homeDir.resolve("data")) + : dataDirs.stream().map(TestEntitlementBootstrap::absolutePath).toList(); + } + + private static Collection repoDirs(Settings settings) { + return PATH_REPO_SETTING.get(settings).stream().map(TestEntitlementBootstrap::absolutePath).toList(); + } + + private static BiFunction, Collection> baseDirModifier(Consumer> consumer) { + return (BaseDir baseDir, Collection paths) -> { + if (paths == null) { + paths = new HashSet<>(); + } + consumer.accept(paths); + return paths; + }; + } + + @SuppressForbidden(reason = "must be resolved using the default file system, rather then the mocked test file system") + private static Path absolutePath(String path) { + return Paths.get(path).toAbsolutePath().normalize(); + } + private static List zeroOrOne(T item) { if (item == null) { return List.of(); @@ -128,8 +195,6 @@ private static TestPolicyManager createPolicyManager(PathLookup pathLookup) thro } else { classPathEntries = Arrays.stream(classPathProperty.split(separator)).map(PathUtils::get).collect(toCollection(TreeSet::new)); } - Map> pluginSourcePaths = pluginNames.stream().collect(toMap(n -> n, n -> classPathEntries)); - FilesEntitlementsValidation.validate(pluginPolicies, pathLookup); String testOnlyPathString = 
System.getenv("es.entitlement.testOnlyPath"); @@ -148,8 +213,8 @@ private static TestPolicyManager createPolicyManager(PathLookup pathLookup) thro HardcodedEntitlements.agentEntitlements(), pluginPolicies, scopeResolver, - pluginSourcePaths, pathLookup, + classPathEntries, testOnlyClassPath ); } diff --git a/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPathLookup.java b/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPathLookup.java index be99d8187f95e..458d83590758c 100644 --- a/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPathLookup.java +++ b/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPathLookup.java @@ -9,6 +9,8 @@ package org.elasticsearch.entitlement.runtime.policy; +import org.apache.lucene.tests.mockfile.FilterFileSystem; + import java.nio.file.Path; import java.util.Collection; import java.util.List; @@ -37,4 +39,14 @@ public Stream resolveSettingPaths(BaseDir baseDir, String settingName) { return Stream.empty(); } + @Override + public boolean isPathOnDefaultFilesystem(Path path) { + var fileSystem = path.getFileSystem(); + if (fileSystem.getClass() != DEFAULT_FILESYSTEM_CLASS) { + while (fileSystem instanceof FilterFileSystem ffs) { + fileSystem = ffs.getDelegate(); + } + } + return fileSystem.getClass() == DEFAULT_FILESYSTEM_CLASS; + } } diff --git a/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManager.java b/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManager.java index 3d7387a6a2f3a..b504e51e119f7 100644 --- a/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManager.java +++ b/test/framework/src/main/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManager.java @@ -38,7 +38,7 @@ public class TestPolicyManager extends PolicyManager { * We need this larger map per class instead. 
*/ final Map, ModuleEntitlements> classEntitlementsMap = new ConcurrentHashMap<>(); - + final Collection classpath; final Collection testOnlyClasspath; public TestPolicyManager( @@ -46,11 +46,12 @@ public TestPolicyManager( List apmAgentEntitlements, Map pluginPolicies, Function, PolicyScope> scopeResolver, - Map> pluginSourcePaths, PathLookup pathLookup, + Collection classpath, Collection testOnlyClasspath ) { - super(serverPolicy, apmAgentEntitlements, pluginPolicies, scopeResolver, pluginSourcePaths, pathLookup); + super(serverPolicy, apmAgentEntitlements, pluginPolicies, scopeResolver, name -> classpath, pathLookup); + this.classpath = classpath; this.testOnlyClasspath = testOnlyClasspath; reset(); } @@ -118,6 +119,11 @@ boolean isTriviallyAllowed(Class requestingClass) { return super.isTriviallyAllowed(requestingClass); } + @Override + protected Collection getComponentPathsFromClass(Class requestingClass) { + return classpath; // required to grant read access to the production source and test resources + } + private boolean isEntitlementClass(Class requestingClass) { return requestingClass.getPackageName().startsWith("org.elasticsearch.entitlement") && (requestingClass.getName().contains("Test") == false); @@ -180,6 +186,9 @@ private boolean isTestCode(Class requestingClass) { URI needle; try { needle = codeSource.getLocation().toURI(); + if (needle.getScheme().equals("jrt")) { + return false; // won't be on testOnlyClasspath + } } catch (URISyntaxException e) { throw new IllegalStateException(e); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java index 1c785d58f9804..f099aaac463db 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java @@ -420,13 
+420,12 @@ public final void testCacheable() throws IOException { } } - protected final List blockLoaderReadValuesFromColumnAtATimeReader(DirectoryReader reader, MappedFieldType fieldType) + protected final List blockLoaderReadValuesFromColumnAtATimeReader(DirectoryReader reader, MappedFieldType fieldType, int offset) throws IOException { BlockLoader loader = fieldType.blockLoader(blContext()); List all = new ArrayList<>(); for (LeafReaderContext ctx : reader.leaves()) { - TestBlock block = (TestBlock) loader.columnAtATimeReader(ctx) - .read(TestBlock.factory(ctx.reader().numDocs()), TestBlock.docs(ctx)); + TestBlock block = (TestBlock) loader.columnAtATimeReader(ctx).read(TestBlock.factory(), TestBlock.docs(ctx), offset); for (int i = 0; i < block.size(); i++) { all.add(block.get(i)); } @@ -440,7 +439,7 @@ protected final List blockLoaderReadValuesFromRowStrideReader(DirectoryR List all = new ArrayList<>(); for (LeafReaderContext ctx : reader.leaves()) { BlockLoader.RowStrideReader blockReader = loader.rowStrideReader(ctx); - BlockLoader.Builder builder = loader.builder(TestBlock.factory(ctx.reader().numDocs()), ctx.reader().numDocs()); + BlockLoader.Builder builder = loader.builder(TestBlock.factory(), ctx.reader().numDocs()); for (int i = 0; i < ctx.reader().numDocs(); i++) { blockReader.read(i, null, builder); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java index e35a53c0ecca8..eeb1a349d8bbc 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java @@ -36,6 +36,8 @@ import static org.apache.lucene.tests.util.LuceneTestCase.newDirectory; import static org.apache.lucene.tests.util.LuceneTestCase.random; import static 
org.elasticsearch.index.mapper.BlockLoaderTestRunner.PrettyEqual.prettyEqualTo; +import static org.elasticsearch.test.ESTestCase.between; +import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -69,7 +71,11 @@ private Object setupAndInvokeBlockLoader(MapperService mapperService, XContentBu ); LuceneDocument doc = mapperService.documentMapper().parse(source).rootDoc(); - iw.addDocument(doc); + /* + * Add three documents with doc id 0, 1, 2. The real document is 1. + * The other two are empty documents. + */ + iw.addDocuments(List.of(List.of(), doc, List.of())); iw.close(); try (DirectoryReader reader = DirectoryReader.open(directory)) { @@ -83,9 +89,32 @@ private Object load(BlockLoader blockLoader, LeafReaderContext context, MapperSe // `columnAtATimeReader` is tried first, we mimic `ValuesSourceReaderOperator` var columnAtATimeReader = blockLoader.columnAtATimeReader(context); if (columnAtATimeReader != null) { - BlockLoader.Docs docs = TestBlock.docs(0); - var block = (TestBlock) columnAtATimeReader.read(TestBlock.factory(context.reader().numDocs()), docs); - assertThat(block.size(), equalTo(1)); + int[] docArray; + int offset; + if (randomBoolean()) { + // Half the time we load a single document. Nice and simple. + docArray = new int[] { 1 }; + offset = 0; + } else { + /* + * The other half the time we emulate loading a larger page, + * starting part way through the page. 
+ */ + docArray = new int[between(2, 10)]; + offset = between(0, docArray.length - 1); + for (int i = 0; i < docArray.length; i++) { + if (i < offset) { + docArray[i] = 0; + } else if (i == offset) { + docArray[i] = 1; + } else { + docArray[i] = 2; + } + } + } + BlockLoader.Docs docs = TestBlock.docs(docArray); + var block = (TestBlock) columnAtATimeReader.read(TestBlock.factory(), docs, offset); + assertThat(block.size(), equalTo(docArray.length - offset)); return block.get(0); } @@ -102,10 +131,10 @@ private Object load(BlockLoader blockLoader, LeafReaderContext context, MapperSe StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(context, null), leafSourceLoader ); - storedFieldsLoader.advanceTo(0); + storedFieldsLoader.advanceTo(1); - BlockLoader.Builder builder = blockLoader.builder(TestBlock.factory(context.reader().numDocs()), 1); - blockLoader.rowStrideReader(context).read(0, storedFieldsLoader, builder); + BlockLoader.Builder builder = blockLoader.builder(TestBlock.factory(), 1); + blockLoader.rowStrideReader(context).read(1, storedFieldsLoader, builder); var block = (TestBlock) builder.build(); assertThat(block.size(), equalTo(1)); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java index 779d7a2a976d9..cb73dc96f69b2 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java @@ -11,7 +11,9 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.util.BytesRef; +import org.hamcrest.Matcher; import java.io.IOException; import java.io.UncheckedIOException; @@ -19,11 +21,14 @@ import java.util.HashMap; import java.util.List; +import static org.elasticsearch.test.ESTestCase.assertThat; +import static 
org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; public class TestBlock implements BlockLoader.Block { - public static BlockLoader.BlockFactory factory(int pageSize) { + public static BlockLoader.BlockFactory factory() { return new BlockLoader.BlockFactory() { @Override public BlockLoader.BooleanBuilder booleansFromDocValues(int expectedCount) { @@ -33,6 +38,10 @@ public BlockLoader.BooleanBuilder booleansFromDocValues(int expectedCount) { @Override public BlockLoader.BooleanBuilder booleans(int expectedCount) { class BooleansBuilder extends TestBlock.Builder implements BlockLoader.BooleanBuilder { + private BooleansBuilder() { + super(expectedCount); + } + @Override public BooleansBuilder appendBoolean(boolean value) { add(value); @@ -44,12 +53,41 @@ public BooleansBuilder appendBoolean(boolean value) { @Override public BlockLoader.BytesRefBuilder bytesRefsFromDocValues(int expectedCount) { - return bytesRefs(expectedCount); + class BytesRefsFromDocValuesBuilder extends TestBlock.Builder implements BlockLoader.BytesRefBuilder { + private BytesRefsFromDocValuesBuilder() { + super(1); + } + + @Override + public BytesRefsFromDocValuesBuilder appendBytesRef(BytesRef value) { + add(BytesRef.deepCopyOf(value)); + return this; + } + + @Override + public TestBlock build() { + TestBlock result = super.build(); + List r; + if (result.values.get(0) instanceof List l) { + r = l; + } else { + r = List.of(result.values.get(0)); + } + assertThat(r, hasSize(expectedCount)); + return result; + } + + } + return new BytesRefsFromDocValuesBuilder(); } @Override public BlockLoader.BytesRefBuilder bytesRefs(int expectedCount) { class BytesRefsBuilder extends TestBlock.Builder implements BlockLoader.BytesRefBuilder { + private BytesRefsBuilder() { + super(expectedCount); + } + @Override public BytesRefsBuilder appendBytesRef(BytesRef value) { 
add(BytesRef.deepCopyOf(value)); @@ -67,6 +105,10 @@ public BlockLoader.DoubleBuilder doublesFromDocValues(int expectedCount) { @Override public BlockLoader.DoubleBuilder doubles(int expectedCount) { class DoublesBuilder extends TestBlock.Builder implements BlockLoader.DoubleBuilder { + private DoublesBuilder() { + super(expectedCount); + } + @Override public DoublesBuilder appendDouble(double value) { add(value); @@ -81,6 +123,10 @@ public BlockLoader.FloatBuilder denseVectors(int expectedCount, int dimensions) class FloatsBuilder extends TestBlock.Builder implements BlockLoader.FloatBuilder { int numElements = 0; + private FloatsBuilder() { + super(expectedCount); + } + @Override public BlockLoader.FloatBuilder appendFloat(float value) { add(value); @@ -117,6 +163,10 @@ public BlockLoader.IntBuilder intsFromDocValues(int expectedCount) { @Override public BlockLoader.IntBuilder ints(int expectedCount) { class IntsBuilder extends TestBlock.Builder implements BlockLoader.IntBuilder { + private IntsBuilder() { + super(expectedCount); + } + @Override public IntsBuilder appendInt(int value) { add(value); @@ -134,6 +184,10 @@ public BlockLoader.LongBuilder longsFromDocValues(int expectedCount) { @Override public BlockLoader.LongBuilder longs(int expectedCount) { class LongsBuilder extends TestBlock.Builder implements BlockLoader.LongBuilder { + private LongsBuilder() { + super(expectedCount); + } + @Override public LongsBuilder appendLong(long value) { add(value); @@ -149,26 +203,30 @@ public BlockLoader.Builder nulls(int expectedCount) { } @Override - public BlockLoader.Block constantNulls() { - BlockLoader.LongBuilder builder = longs(pageSize); - for (int i = 0; i < pageSize; i++) { + public BlockLoader.Block constantNulls(int count) { + BlockLoader.LongBuilder builder = longs(count); + for (int i = 0; i < count; i++) { builder.appendNull(); } return builder.build(); } @Override - public BlockLoader.Block constantBytes(BytesRef value) { - BlockLoader.BytesRefBuilder 
builder = bytesRefs(pageSize); - for (int i = 0; i < pageSize; i++) { + public BlockLoader.Block constantBytes(BytesRef value, int count) { + BlockLoader.BytesRefBuilder builder = bytesRefs(count); + for (int i = 0; i < count; i++) { builder.appendBytesRef(value); } return builder.build(); } @Override - public BlockLoader.SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocValues ordinals, int count) { + public BlockLoader.SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocValues ordinals, int expectedCount) { class SingletonOrdsBuilder extends TestBlock.Builder implements BlockLoader.SingletonOrdinalsBuilder { + private SingletonOrdsBuilder() { + super(expectedCount); + } + @Override public SingletonOrdsBuilder appendOrd(int value) { try { @@ -183,8 +241,27 @@ public SingletonOrdsBuilder appendOrd(int value) { } @Override - public BlockLoader.AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int count) { - return new AggregateMetricDoubleBlockBuilder(); + public BlockLoader.SortedSetOrdinalsBuilder sortedSetOrdinalsBuilder(SortedSetDocValues ordinals, int expectedSize) { + class SortedSetOrdinalBuilder extends TestBlock.Builder implements BlockLoader.SortedSetOrdinalsBuilder { + private SortedSetOrdinalBuilder() { + super(expectedSize); + } + + @Override + public SortedSetOrdinalBuilder appendOrd(int value) { + try { + add(BytesRef.deepCopyOf(ordinals.lookupOrd(value))); + return this; + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + return new SortedSetOrdinalBuilder(); + } + + public BlockLoader.AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int expectedSize) { + return new AggregateMetricDoubleBlockBuilder(expectedSize); } }; } @@ -239,8 +316,14 @@ public void close() { private abstract static class Builder implements BlockLoader.Builder { private final List values = new ArrayList<>(); + private Matcher expectedSize; + private List currentPosition = null; + private Builder(int expectedSize) { + 
this.expectedSize = equalTo(expectedSize); + } + @Override public Builder appendNull() { assertNull(currentPosition); @@ -269,6 +352,7 @@ protected void add(Object value) { @Override public TestBlock build() { + assertThat(values, hasSize(expectedSize)); return new TestBlock(values); } @@ -283,12 +367,23 @@ public void close() { * The implementation here is fairly close to the production one. */ private static class AggregateMetricDoubleBlockBuilder implements BlockLoader.AggregateMetricDoubleBuilder { - private final DoubleBuilder min = new DoubleBuilder(); - private final DoubleBuilder max = new DoubleBuilder(); - private final DoubleBuilder sum = new DoubleBuilder(); - private final IntBuilder count = new IntBuilder(); + private final DoubleBuilder min; + private final DoubleBuilder max; + private final DoubleBuilder sum; + private final IntBuilder count; + + private AggregateMetricDoubleBlockBuilder(int expectedSize) { + min = new DoubleBuilder(expectedSize); + max = new DoubleBuilder(expectedSize); + sum = new DoubleBuilder(expectedSize); + count = new IntBuilder(expectedSize); + } private static class DoubleBuilder extends TestBlock.Builder implements BlockLoader.DoubleBuilder { + private DoubleBuilder(int expectedSize) { + super(expectedSize); + } + @Override public BlockLoader.DoubleBuilder appendDouble(double value) { add(value); @@ -297,6 +392,10 @@ public BlockLoader.DoubleBuilder appendDouble(double value) { } private static class IntBuilder extends TestBlock.Builder implements BlockLoader.IntBuilder { + private IntBuilder(int expectedSize) { + super(expectedSize); + } + @Override public BlockLoader.IntBuilder appendInt(int value) { add(value); diff --git a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java index 03d6ac6342b42..0c1e381f69c4e 100644 --- 
a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; @@ -242,6 +243,7 @@ public MockIndexService indexService(Index index) { @Override public void createShard( + final ProjectId projectId, final ShardRouting shardRouting, final PeerRecoveryTargetService recoveryTargetService, final PeerRecoveryTargetService.RecoveryListener recoveryListener, diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index 4c2e0a3c6c047..3d5229435e729 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.entitlement.bootstrap.TestEntitlementBootstrap; import org.elasticsearch.env.Environment; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.indices.ExecutorSelector; @@ -53,6 +54,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportSettings; +import java.io.IOException; import java.nio.file.Path; import java.util.Collection; import java.util.Collections; @@ -254,16 +256,7 @@ public MockNode( final Path configPath, final boolean 
forbidPrivateIndexSettings ) { - this( - InternalSettingsPreparer.prepareEnvironment( - Settings.builder().put(TransportSettings.PORT.getKey(), ESTestCase.getPortRange()).put(settings).build(), - Collections.emptyMap(), - configPath, - () -> "mock_ node" - ), - classpathPlugins, - forbidPrivateIndexSettings - ); + this(prepareEnvironment(settings, configPath), classpathPlugins, forbidPrivateIndexSettings); } private MockNode( @@ -282,6 +275,25 @@ PluginsService newPluginService(Environment environment, PluginsLoader pluginsLo this.classpathPlugins = classpathPlugins; } + private static Environment prepareEnvironment(final Settings settings, final Path configPath) { + TestEntitlementBootstrap.registerNodeBaseDirs(settings, configPath); + return InternalSettingsPreparer.prepareEnvironment( + Settings.builder().put(TransportSettings.PORT.getKey(), ESTestCase.getPortRange()).put(settings).build(), + Collections.emptyMap(), + configPath, + () -> "mock_ node" + ); + } + + @Override + public synchronized void close() throws IOException { + try { + super.close(); + } finally { + TestEntitlementBootstrap.unregisterNodeBaseDirs(getEnvironment().settings(), getEnvironment().configDir()); + } + } + /** * The classpath plugins this node was constructed with. 
*/ diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index f0bdbe5ced329..89099bb7e32ff 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -286,7 +286,6 @@ * */ @LuceneTestCase.SuppressFileSystems("ExtrasFS") // doesn't work with potential multi data path from test cluster yet -@ESTestCase.WithoutEntitlements // ES-12042 public abstract class ESIntegTestCase extends ESTestCase { /** node names of the corresponding clusters will start with these prefixes */ diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java index f7d272e793e0f..7ebc5765bda63 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java @@ -90,7 +90,6 @@ * A test that keep a singleton node started for all tests that can be used to get * references to Guice injectors in unit tests. 
*/ -@ESTestCase.WithoutEntitlements // ES-12042 public abstract class ESSingleNodeTestCase extends ESTestCase { private static Node NODE = null; diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 47ca9e5ee2afc..6ced34ce72759 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -534,7 +534,6 @@ public static void setupEntitlementsForClass() { TestEntitlementBootstrap.setActive(false == withoutEntitlements); TestEntitlementBootstrap.setTriviallyAllowingTestCode(false == withEntitlementsOnTestCode); if (entitledPackages != null) { - assert withEntitlementsOnTestCode == false : "Cannot use @WithEntitlementsOnTestCode together with @EntitledTestPackages"; assert entitledPackages.value().length > 0 : "No test packages specified in @EntitledTestPackages"; TestEntitlementBootstrap.setEntitledTestPackages(entitledPackages.value()); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java index 0a7b4696ff457..70dd46816d5d7 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java @@ -1228,7 +1228,7 @@ protected static void wipeAllIndices(boolean preserveSecurityIndices) throws IOE try { // remove all indices except some history indices which can pop up after deleting all data streams but shouldn't interfere final List indexPatterns = new ArrayList<>( - List.of("*", "-.ds-ilm-history-*", "-.ds-.slm-history-*", "-.ds-.watcher-history-*") + List.of("*", "-.ds-ilm-history-*", "-.ds-.slm-history-*", "-.ds-.watcher-history-*", "-.ds-.triggered_watches-*") ); if (preserveSecurityIndices) { indexPatterns.add("-.security-*"); diff --git 
a/test/framework/src/test/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManagerTests.java b/test/framework/src/test/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManagerTests.java index 366e0bd5505fe..4a62355f398d8 100644 --- a/test/framework/src/test/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManagerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/entitlement/runtime/policy/TestPolicyManagerTests.java @@ -34,8 +34,8 @@ public void setupPolicyManager() { List.of(), Map.of(), c -> new PolicyScope(PLUGIN, "example-plugin" + scopeCounter.incrementAndGet(), "org.example.module"), - Map.of(), new TestPathLookup(Map.of()), + List.of(), List.of() ); policyManager.setActive(true); diff --git a/x-pack/plugin/apm-data/src/main/resources/component-templates/apm@settings.yaml b/x-pack/plugin/apm-data/src/main/resources/component-templates/apm@settings.yaml index 75671948de11a..29eb115efac0b 100644 --- a/x-pack/plugin/apm-data/src/main/resources/component-templates/apm@settings.yaml +++ b/x-pack/plugin/apm-data/src/main/resources/component-templates/apm@settings.yaml @@ -12,3 +12,6 @@ template: ignore_malformed: true total_fields: ignore_dynamic_beyond_limit: true + data_stream_options: + failure_store: + enabled: true diff --git a/x-pack/plugin/apm-data/src/main/resources/resources.yaml b/x-pack/plugin/apm-data/src/main/resources/resources.yaml index 70675f1dd10d6..9704557063bc5 100644 --- a/x-pack/plugin/apm-data/src/main/resources/resources.yaml +++ b/x-pack/plugin/apm-data/src/main/resources/resources.yaml @@ -1,7 +1,7 @@ # "version" holds the version of the templates and ingest pipelines installed # by xpack-plugin apm-data. This must be increased whenever an existing template or # pipeline is changed, in order for it to be updated on Elasticsearch upgrade. -version: 101 +version: 102 component-templates: # Data lifecycle. 
diff --git a/x-pack/plugin/async-search/build.gradle b/x-pack/plugin/async-search/build.gradle index 6117f8bea15dc..edb0847575fde 100644 --- a/x-pack/plugin/async-search/build.gradle +++ b/x-pack/plugin/async-search/build.gradle @@ -1,3 +1,5 @@ +import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask + apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-java-rest-test' @@ -34,3 +36,8 @@ restResources { include '_common', 'indices', 'index', 'async_search' } } + +tasks.withType(StandaloneRestIntegTestTask).configureEach { + def isSnapshot = buildParams.snapshotBuild + it.onlyIf("snapshot build") { isSnapshot } +} diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java index bf8576afc5d70..e2b76658e5246 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.ErrorTraceHelper; import org.elasticsearch.search.SearchService; @@ -72,11 +71,11 @@ private void setupIndexWithDocs() { refresh(); } - public void testAsyncSearchFailingQueryErrorTraceDefault() throws IOException, InterruptedException { + public void testAsyncSearchFailingQueryErrorTraceDefault() throws Exception { setupIndexWithDocs(); - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + 
Request createAsyncRequest = new Request("POST", "/_async_search"); + createAsyncRequest.setJsonEntity(""" { "query": { "simple_query_string" : { @@ -86,23 +85,23 @@ public void testAsyncSearchFailingQueryErrorTraceDefault() throws IOException, I } } """); - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + createAsyncRequest.addParameter("keep_on_completion", "true"); + createAsyncRequest.addParameter("wait_for_completion_timeout", "0ms"); + Map createAsyncResponseEntity = performRequestAndGetResponseEntity(createAsyncRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + awaitAsyncRequestDoneRunning(getAsyncRequest); } // check that the stack trace was not sent from the data node to the coordinating node assertFalse(transportMessageHasStackTrace.getAsBoolean()); } - public void testAsyncSearchFailingQueryErrorTraceTrue() throws IOException, InterruptedException { + public void testAsyncSearchFailingQueryErrorTraceTrue() throws Exception { setupIndexWithDocs(); - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + Request createAsyncRequest = new Request("POST", "/_async_search"); + createAsyncRequest.setJsonEntity(""" { "query": { "simple_query_string" : { @@ -112,25 +111,25 @@ public void testAsyncSearchFailingQueryErrorTraceTrue() 
throws IOException, Inte } } """); - searchRequest.addParameter("error_trace", "true"); - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - request.addParameter("error_trace", "true"); - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + createAsyncRequest.addParameter("error_trace", "true"); + createAsyncRequest.addParameter("keep_on_completion", "true"); + createAsyncRequest.addParameter("wait_for_completion_timeout", "0ms"); + Map createAsyncResponseEntity = performRequestAndGetResponseEntity(createAsyncRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + getAsyncRequest.addParameter("error_trace", "true"); + awaitAsyncRequestDoneRunning(getAsyncRequest); } // check that the stack trace was sent from the data node to the coordinating node assertTrue(transportMessageHasStackTrace.getAsBoolean()); } - public void testAsyncSearchFailingQueryErrorTraceFalse() throws IOException, InterruptedException { + public void testAsyncSearchFailingQueryErrorTraceFalse() throws Exception { setupIndexWithDocs(); - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + Request createAsyncRequest = new Request("POST", "/_async_search"); + createAsyncRequest.setJsonEntity(""" { "query": { "simple_query_string" : { @@ -140,28 +139,25 @@ public void testAsyncSearchFailingQueryErrorTraceFalse() throws 
IOException, Int } } """); - searchRequest.addParameter("error_trace", "false"); - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - request.addParameter("error_trace", "false"); - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + createAsyncRequest.addParameter("error_trace", "false"); + createAsyncRequest.addParameter("keep_on_completion", "true"); + createAsyncRequest.addParameter("wait_for_completion_timeout", "0ms"); + Map createAsyncResponseEntity = performRequestAndGetResponseEntity(createAsyncRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + getAsyncRequest.addParameter("error_trace", "false"); + awaitAsyncRequestDoneRunning(getAsyncRequest); } // check that the stack trace was not sent from the data node to the coordinating node assertFalse(transportMessageHasStackTrace.getAsBoolean()); } - public void testDataNodeLogsStackTrace() throws IOException, InterruptedException { + public void testDataNodeLogsStackTrace() throws Exception { setupIndexWithDocs(); - // error_trace defaults to false so we can test both cases with some randomization - final boolean defineErrorTraceFalse = randomBoolean(); - - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + Request createAsyncRequest = new Request("POST", "/_async_search"); + createAsyncRequest.setJsonEntity(""" { "query": { 
"simple_query_string" : { @@ -175,43 +171,40 @@ public void testDataNodeLogsStackTrace() throws IOException, InterruptedExceptio // No matter the value of error_trace (empty, true, or false) we should see stack traces logged int errorTraceValue = randomIntBetween(0, 2); if (errorTraceValue == 0) { - searchRequest.addParameter("error_trace", "true"); + createAsyncRequest.addParameter("error_trace", "true"); } else if (errorTraceValue == 1) { - searchRequest.addParameter("error_trace", "false"); + createAsyncRequest.addParameter("error_trace", "false"); } // else empty - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); + createAsyncRequest.addParameter("keep_on_completion", "true"); + createAsyncRequest.addParameter("wait_for_completion_timeout", "0ms"); String errorTriggeringIndex = "test2"; int numShards = getNumShards(errorTriggeringIndex).numPrimaries; try (var mockLog = MockLog.capture(SearchService.class)) { ErrorTraceHelper.addSeenLoggingExpectations(numShards, mockLog, errorTriggeringIndex); - - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - - // Use the same value of error_trace as the search request - if (errorTraceValue == 0) { - request.addParameter("error_trace", "true"); - } else if (errorTraceValue == 1) { - request.addParameter("error_trace", "false"); - } // else empty - - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + Map createAsyncResponseEntity = performRequestAndGetResponseEntity(createAsyncRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) 
createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + // Use the same value of error_trace as the search request + if (errorTraceValue == 0) { + getAsyncRequest.addParameter("error_trace", "true"); + } else if (errorTraceValue == 1) { + getAsyncRequest.addParameter("error_trace", "false"); + } // else empty + awaitAsyncRequestDoneRunning(getAsyncRequest); } mockLog.assertAllExpectationsMatched(); } } - public void testAsyncSearchFailingQueryErrorTraceFalseOnSubmitAndTrueOnGet() throws IOException, InterruptedException { + public void testAsyncSearchFailingQueryErrorTraceFalseOnSubmitAndTrueOnGet() throws Exception { setupIndexWithDocs(); - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + Request createAsyncSearchRequest = new Request("POST", "/_async_search"); + createAsyncSearchRequest.setJsonEntity(""" { "query": { "simple_query_string" : { @@ -221,25 +214,25 @@ public void testAsyncSearchFailingQueryErrorTraceFalseOnSubmitAndTrueOnGet() thr } } """); - searchRequest.addParameter("error_trace", "false"); - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - request.addParameter("error_trace", "true"); - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + createAsyncSearchRequest.addParameter("error_trace", "false"); + createAsyncSearchRequest.addParameter("keep_on_completion", "true"); + createAsyncSearchRequest.addParameter("wait_for_completion_timeout", "0ms"); + Map createAsyncResponseEntity 
= performRequestAndGetResponseEntity(createAsyncSearchRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + getAsyncRequest.addParameter("error_trace", "true"); + awaitAsyncRequestDoneRunning(getAsyncRequest); } // check that the stack trace was not sent from the data node to the coordinating node assertFalse(transportMessageHasStackTrace.getAsBoolean()); } - public void testAsyncSearchFailingQueryErrorTraceTrueOnSubmitAndFalseOnGet() throws IOException, InterruptedException { + public void testAsyncSearchFailingQueryErrorTraceTrueOnSubmitAndFalseOnGet() throws Exception { setupIndexWithDocs(); - Request searchRequest = new Request("POST", "/_async_search"); - searchRequest.setJsonEntity(""" + Request createAsyncSearchRequest = new Request("POST", "/_async_search"); + createAsyncSearchRequest.setJsonEntity(""" { "query": { "simple_query_string" : { @@ -249,25 +242,30 @@ public void testAsyncSearchFailingQueryErrorTraceTrueOnSubmitAndFalseOnGet() thr } } """); - searchRequest.addParameter("error_trace", "true"); - searchRequest.addParameter("keep_on_completion", "true"); - searchRequest.addParameter("wait_for_completion_timeout", "0ms"); - Map responseEntity = performRequestAndGetResponseEntityAfterDelay(searchRequest, TimeValue.ZERO); - String asyncExecutionId = (String) responseEntity.get("id"); - Request request = new Request("GET", "/_async_search/" + asyncExecutionId); - request.addParameter("error_trace", "false"); - while (responseEntity.get("is_running") instanceof Boolean isRunning && isRunning) { - responseEntity = performRequestAndGetResponseEntityAfterDelay(request, TimeValue.timeValueSeconds(1L)); + createAsyncSearchRequest.addParameter("error_trace", "true"); + createAsyncSearchRequest.addParameter("keep_on_completion", "true"); + 
createAsyncSearchRequest.addParameter("wait_for_completion_timeout", "0ms"); + Map createAsyncResponseEntity = performRequestAndGetResponseEntity(createAsyncSearchRequest); + if (createAsyncResponseEntity.get("is_running").equals("true")) { + String asyncExecutionId = (String) createAsyncResponseEntity.get("id"); + Request getAsyncRequest = new Request("GET", "/_async_search/" + asyncExecutionId); + getAsyncRequest.addParameter("error_trace", "false"); + awaitAsyncRequestDoneRunning(getAsyncRequest); } // check that the stack trace was sent from the data node to the coordinating node assertTrue(transportMessageHasStackTrace.getAsBoolean()); } - private Map performRequestAndGetResponseEntityAfterDelay(Request r, TimeValue sleep) throws IOException, - InterruptedException { - Thread.sleep(sleep.millis()); + private Map performRequestAndGetResponseEntity(Request r) throws IOException { Response response = getRestClient().performRequest(r); XContentType entityContentType = XContentType.fromMediaType(response.getEntity().getContentType().getValue()); return XContentHelper.convertToMap(entityContentType.xContent(), response.getEntity().getContent(), false); } + + private void awaitAsyncRequestDoneRunning(Request getAsyncRequest) throws Exception { + assertBusy(() -> { + Map getAsyncResponseEntity = performRequestAndGetResponseEntity(getAsyncRequest); + assertFalse((Boolean) getAsyncResponseEntity.get("is_running")); + }); + } } diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java index d951b21ba1380..7cc43d43c6ff5 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java @@ -1778,6 
+1778,7 @@ public void testCancellationViaTimeoutWithAllowPartialResultsSetToFalse() throws } SearchListenerPlugin.waitLocalSearchStarted(); + SearchListenerPlugin.waitRemoteSearchStarted(); // ensure tasks are present on both clusters and not cancelled try { diff --git a/x-pack/plugin/autoscaling/build.gradle b/x-pack/plugin/autoscaling/build.gradle index 24400a0fc418e..22a43654fd602 100644 --- a/x-pack/plugin/autoscaling/build.gradle +++ b/x-pack/plugin/autoscaling/build.gradle @@ -1,5 +1,6 @@ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' esplugin { name = 'x-pack-autoscaling' @@ -15,15 +16,21 @@ base { dependencies { compileOnly project(path: xpackModule('core')) - testImplementation(testArtifact(project(xpackModule('core')))) - testImplementation project(':modules:data-streams') - testImplementation project(path: xpackModule('blob-cache')) - testImplementation project(path: xpackModule('searchable-snapshots')) - testImplementation project(path: xpackModule('ilm')) - testImplementation project(path: xpackModule('slm')) - testImplementation project(path: xpackModule('ccr')) + testImplementation testArtifact(project(':server')) + testImplementation testArtifact(project(xpackModule('core'))) testImplementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" + + internalClusterTestImplementation project(':modules:data-streams') + internalClusterTestImplementation project(xpackModule('blob-cache')) + internalClusterTestImplementation project(xpackModule("searchable-snapshots")) + internalClusterTestImplementation project(xpackModule('ilm')) + internalClusterTestImplementation project(xpackModule('slm')) + internalClusterTestImplementation project(xpackModule('ccr')) } -addQaCheckDependencies(project) +restResources { + restApi { + include '_common', 'autoscaling' + } +} diff --git a/x-pack/plugin/autoscaling/qa/build.gradle 
b/x-pack/plugin/autoscaling/qa/build.gradle deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/x-pack/plugin/autoscaling/qa/rest/build.gradle b/x-pack/plugin/autoscaling/qa/rest/build.gradle deleted file mode 100644 index 903e76fd986cf..0000000000000 --- a/x-pack/plugin/autoscaling/qa/rest/build.gradle +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -apply plugin: 'elasticsearch.legacy-yaml-rest-test' -apply plugin: 'elasticsearch.legacy-yaml-rest-compat-test' - -dependencies { - yamlRestTestImplementation(testArtifact(project(xpackModule('core')))) -} - -restResources { - restApi { - include '_common', 'autoscaling' - } -} - -testClusters.configureEach { - testDistribution = 'DEFAULT' - setting 'xpack.security.enabled', 'true' - setting 'xpack.license.self_generated.type', 'trial' - extraConfigFile 'roles.yml', file('autoscaling-roles.yml') - user username: 'autoscaling-admin', password: 'autoscaling-admin-password', role: 'superuser' - user username: 'autoscaling-user', password: 'autoscaling-user-password', role: 'autoscaling' -} diff --git a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java index e451b1d45817d..c76a88b0da2f9 100644 --- a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java +++ b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java @@ -960,6 +960,8 @@ private ExtendedClusterInfo(Map extraShardSizes, ClusterInfo info) Map.of(), Map.of(), Map.of(), + Map.of(), + 
Map.of(), Map.of() ); this.delegate = info; diff --git a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/AutoscalingCalculateCapacityServiceTests.java b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/AutoscalingCalculateCapacityServiceTests.java index 12f7dde103c9c..dfc44b64cb691 100644 --- a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/AutoscalingCalculateCapacityServiceTests.java +++ b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/capacity/AutoscalingCalculateCapacityServiceTests.java @@ -262,7 +262,7 @@ public void testContext() { } } state = ClusterState.builder(ClusterName.DEFAULT).nodes(nodes).build(); - info = new ClusterInfo(leastUsages, mostUsages, Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); + info = ClusterInfo.builder().leastAvailableSpaceUsage(leastUsages).mostAvailableSpaceUsage(mostUsages).build(); context = new AutoscalingCalculateCapacityService.DefaultAutoscalingDeciderContext( roleNames, state, @@ -311,7 +311,7 @@ public void testContext() { ) ); - info = new ClusterInfo(leastUsages, mostUsages, Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); + info = ClusterInfo.builder().leastAvailableSpaceUsage(leastUsages).mostAvailableSpaceUsage(mostUsages).build(); context = new AutoscalingCalculateCapacityService.DefaultAutoscalingDeciderContext( roleNames, state, diff --git a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/FrozenStorageDeciderServiceTests.java b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/FrozenStorageDeciderServiceTests.java index 37295ebf44208..286b5e48010ea 100644 --- a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/FrozenStorageDeciderServiceTests.java +++ 
b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/FrozenStorageDeciderServiceTests.java @@ -109,7 +109,7 @@ public Tuple sizeAndClusterInfo(IndexMetadata indexMetadata) // add irrelevant shards noise for completeness (should not happen IRL). sizes.put(new ShardId(index, i), randomLongBetween(0, Integer.MAX_VALUE)); } - ClusterInfo info = new ClusterInfo(Map.of(), Map.of(), Map.of(), sizes, Map.of(), Map.of(), Map.of()); + ClusterInfo info = ClusterInfo.builder().shardDataSetSizes(sizes).build(); return Tuple.tuple(totalSize, info); } } diff --git a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ProactiveStorageDeciderServiceTests.java b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ProactiveStorageDeciderServiceTests.java index 8c1f18e84a619..0d054f45367bc 100644 --- a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ProactiveStorageDeciderServiceTests.java +++ b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ProactiveStorageDeciderServiceTests.java @@ -408,7 +408,7 @@ private ClusterInfo randomClusterInfo(ProjectState projectState) { for (var id : projectState.cluster().nodes().getDataNodes().keySet()) { diskUsage.put(id, new DiskUsage(id, id, "/test", Long.MAX_VALUE, Long.MAX_VALUE)); } - return new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), Map.of(), Map.of(), Map.of()); + return ClusterInfo.builder().leastAvailableSpaceUsage(diskUsage).mostAvailableSpaceUsage(diskUsage).shardSizes(shardSizes).build(); } private ProjectMetadata applyCreatedDates(ProjectMetadata project, DataStream ds, long last, long decrement) { diff --git a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderServiceTests.java 
b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderServiceTests.java index 2ee94340f6d2c..18115a35039b2 100644 --- a/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderServiceTests.java +++ b/x-pack/plugin/autoscaling/src/test/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderServiceTests.java @@ -379,7 +379,7 @@ public void validateSizeOf(ClusterState clusterState, ShardRouting subjectShard, } private ReactiveStorageDeciderService.AllocationState createAllocationState(Map shardSize, ClusterState clusterState) { - ClusterInfo info = new ClusterInfo(Map.of(), Map.of(), shardSize, Map.of(), Map.of(), Map.of(), Map.of()); + ClusterInfo info = ClusterInfo.builder().shardSizes(shardSize).build(); ReactiveStorageDeciderService.AllocationState allocationState = new ReactiveStorageDeciderService.AllocationState( clusterState, null, @@ -544,7 +544,11 @@ public void testUnmovableSize() { } var diskUsages = Map.of(nodeId, new DiskUsage(nodeId, null, null, ByteSizeUnit.KB.toBytes(100), ByteSizeUnit.KB.toBytes(5))); - ClusterInfo info = new ClusterInfo(diskUsages, diskUsages, shardSize, Map.of(), Map.of(), Map.of(), Map.of()); + ClusterInfo info = ClusterInfo.builder() + .leastAvailableSpaceUsage(diskUsages) + .mostAvailableSpaceUsage(diskUsages) + .shardSizes(shardSize) + .build(); ReactiveStorageDeciderService.AllocationState allocationState = new ReactiveStorageDeciderService.AllocationState( clusterState, diff --git a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java b/x-pack/plugin/autoscaling/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java similarity index 67% rename from x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java rename to 
x-pack/plugin/autoscaling/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java index 89bc24ecc1ed0..15bef71fe1ea5 100644 --- a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java +++ b/x-pack/plugin/autoscaling/src/yamlRestTest/java/org/elasticsearch/xpack/autoscaling/AutoscalingRestIT.java @@ -12,13 +12,24 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.util.resource.Resource; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; - -import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import org.junit.ClassRule; public class AutoscalingRestIT extends ESClientYamlSuiteTestCase { + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .module("x-pack-autoscaling") + .setting("xpack.security.enabled", "true") + .setting("xpack.license.self_generated.type", "trial") + .rolesFile(Resource.fromClasspath("autoscaling-roles.yml")) + .user("autoscaling-admin", "autoscaling-admin-password", "superuser", false) + .user("autoscaling-user", "autoscaling-user-password", "autoscaling", false) + .build(); + public AutoscalingRestIT(final ClientYamlTestCandidate testCandidate) { super(testCandidate); } @@ -40,4 +51,8 @@ protected Settings restClientSettings() { return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", value).build(); } + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } } diff --git a/x-pack/plugin/autoscaling/qa/rest/autoscaling-roles.yml b/x-pack/plugin/autoscaling/src/yamlRestTest/resources/autoscaling-roles.yml similarity index 
100% rename from x-pack/plugin/autoscaling/qa/rest/autoscaling-roles.yml rename to x-pack/plugin/autoscaling/src/yamlRestTest/resources/autoscaling-roles.yml diff --git a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/delete_autoscaling_policy.yml b/x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/delete_autoscaling_policy.yml similarity index 100% rename from x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/delete_autoscaling_policy.yml rename to x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/delete_autoscaling_policy.yml diff --git a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_capacity.yml b/x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_capacity.yml similarity index 100% rename from x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_capacity.yml rename to x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_capacity.yml diff --git a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_policy.yml b/x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_policy.yml similarity index 100% rename from x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_policy.yml rename to x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/get_autoscaling_policy.yml diff --git a/x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/put_autoscaling_policy.yml b/x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/put_autoscaling_policy.yml similarity index 
100% rename from x-pack/plugin/autoscaling/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/put_autoscaling_policy.yml rename to x-pack/plugin/autoscaling/src/yamlRestTest/resources/rest-api-spec/test/autoscaling/put_autoscaling_policy.yml diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index 4a202562e5e3d..3fc66d504a76e 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -492,8 +492,7 @@ CacheFileRegion get(KeyType cacheKey, long fileLength, int region) { } /** - * Fetch and cache the full blob for the given cache entry from the remote repository if there - * are enough free pages in the cache to do so. + * Fetch and write in cache a region of a blob if there are enough free pages in the cache to do so. *

* This method returns as soon as the download tasks are instantiated, but the tasks themselves * are run on the bulk executor. @@ -502,67 +501,32 @@ CacheFileRegion get(KeyType cacheKey, long fileLength, int region) { * and unlinked * * @param cacheKey the key to fetch data for - * @param length the length of the blob to fetch + * @param region the region of the blob to fetch + * @param blobLength the length of the blob from which the region is fetched (used to compute the size of the ending region) * @param writer a writer that handles writing of newly downloaded data to the shared cache * @param fetchExecutor an executor to use for reading from the blob store - * @param listener listener that is called once all downloading has finished - * @return {@code true} if there were enough free pages to start downloading the full entry + * @param listener a listener that is completed with {@code true} if the current thread triggered the fetching of the region, in + * which case the data is available in cache. The listener is completed with {@code false} in every other cases: if + * the region to write is already available in cache, if the region is pending fetching via another thread or if + * there is not enough free pages to fetch the region. */ - public boolean maybeFetchFullEntry( - KeyType cacheKey, - long length, - RangeMissingHandler writer, - Executor fetchExecutor, - ActionListener listener + public void maybeFetchRegion( + final KeyType cacheKey, + final int region, + final long blobLength, + final RangeMissingHandler writer, + final Executor fetchExecutor, + final ActionListener listener ) { - int finalRegion = getEndingRegion(length); - // TODO freeRegionCount uses freeRegions.size() which is is NOT a constant-time operation. Can we do better? 
- if (freeRegionCount() < finalRegion) { - // Not enough room to download a full file without evicting existing data, so abort - listener.onResponse(null); - return false; - } - long regionLength = regionSize; - try (RefCountingListener refCountingListener = new RefCountingListener(listener)) { - for (int region = 0; region <= finalRegion; region++) { - if (region == finalRegion) { - regionLength = length - getRegionStart(region); - } - ByteRange rangeToWrite = ByteRange.of(0, regionLength); - if (rangeToWrite.isEmpty()) { - return true; - } - final ActionListener regionListener = refCountingListener.acquire(ignored -> {}); - final CacheFileRegion entry; - try { - entry = get(cacheKey, length, region); - } catch (AlreadyClosedException e) { - // failed to grab a cache page because some other operation concurrently acquired some - regionListener.onResponse(0); - return false; - } - // set read range == write range so the listener completes only once all the bytes have been downloaded - entry.populateAndRead( - rangeToWrite, - rangeToWrite, - (channel, pos, relativePos, len) -> Math.toIntExact(len), - writer, - fetchExecutor, - regionListener.delegateResponse((l, e) -> { - if (e instanceof AlreadyClosedException) { - l.onResponse(0); - } else { - l.onFailure(e); - } - }) - ); - } - } - return true; + fetchRegion(cacheKey, region, blobLength, writer, fetchExecutor, false, listener); } /** - * Fetch and write in cache a region of a blob if there are enough free pages in the cache to do so. + * Fetch and write in cache a region of a blob. + *

+ * If {@code force} is {@code true} and no free regions remain, an existing region will be evicted to make room. + *

+ * *

* This method returns as soon as the download tasks are instantiated, but the tasks themselves * are run on the bulk executor. @@ -575,20 +539,23 @@ public boolean maybeFetchFullEntry( * @param blobLength the length of the blob from which the region is fetched (used to compute the size of the ending region) * @param writer a writer that handles writing of newly downloaded data to the shared cache * @param fetchExecutor an executor to use for reading from the blob store + * @param force flag indicating whether the cache should free an occupied region to accommodate the requested + * region when none are free. * @param listener a listener that is completed with {@code true} if the current thread triggered the fetching of the region, in * which case the data is available in cache. The listener is completed with {@code false} in every other cases: if * the region to write is already available in cache, if the region is pending fetching via another thread or if * there is not enough free pages to fetch the region. */ - public void maybeFetchRegion( + public void fetchRegion( final KeyType cacheKey, final int region, final long blobLength, final RangeMissingHandler writer, final Executor fetchExecutor, + final boolean force, final ActionListener listener ) { - if (freeRegions.isEmpty() && maybeEvictLeastUsed() == false) { + if (force == false && freeRegions.isEmpty() && maybeEvictLeastUsed() == false) { // no free page available and no old enough unused region to be evicted logger.info("No free regions, skipping loading region [{}]", region); listener.onResponse(false); @@ -636,7 +603,45 @@ public void maybeFetchRange( final Executor fetchExecutor, final ActionListener listener ) { - if (freeRegions.isEmpty() && maybeEvictLeastUsed() == false) { + fetchRange(cacheKey, region, range, blobLength, writer, fetchExecutor, false, listener); + } + + /** + * Fetch and write in cache a range within a blob region. + *

+ * If {@code force} is {@code true} and no free regions remain, an existing region will be evicted to make room. + *

+ *

+ * This method returns as soon as the download tasks are instantiated, but the tasks themselves + * are run on the bulk executor. + *

+ * If an exception is thrown from the writer then the cache entry being downloaded is freed + * and unlinked + * + * @param cacheKey the key to fetch data for + * @param region the region of the blob + * @param range the range of the blob to fetch + * @param blobLength the length of the blob from which the region is fetched (used to compute the size of the ending region) + * @param writer a writer that handles writing of newly downloaded data to the shared cache + * @param fetchExecutor an executor to use for reading from the blob store + * @param force flag indicating whether the cache should free an occupied region to accommodate the requested + * range when none are free. + * @param listener a listener that is completed with {@code true} if the current thread triggered the fetching of the range, in + * which case the data is available in cache. The listener is completed with {@code false} in every other cases: if + * the range to write is already available in cache, if the range is pending fetching via another thread or if + * there is not enough free pages to fetch the range. + */ + public void fetchRange( + final KeyType cacheKey, + final int region, + final ByteRange range, + final long blobLength, + final RangeMissingHandler writer, + final Executor fetchExecutor, + final boolean force, + final ActionListener listener + ) { + if (force == false && freeRegions.isEmpty() && maybeEvictLeastUsed() == false) { // no free page available and no old enough unused region to be evicted logger.info("No free regions, skipping loading region [{}]", region); listener.onResponse(false); @@ -723,8 +728,6 @@ private static void throwAlreadyClosed(String message) { /** * NOTE: Method is package private mostly to allow checking the number of fee regions in tests. - * However, it is also used by {@link SharedBlobCacheService#maybeFetchFullEntry} but we should try - * to move away from that because calling "size" on a ConcurrentLinkedQueue is not a constant time operation. 
*/ int freeRegionCount() { return freeRegions.size(); diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java index 04658606ce132..1b3335d47b1f0 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java @@ -565,139 +565,6 @@ public void testGetMultiThreaded() throws IOException { } } - public void testFetchFullCacheEntry() throws Exception { - Settings settings = Settings.builder() - .put(NODE_NAME_SETTING.getKey(), "node") - .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(500)).getStringRep()) - .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(100)).getStringRep()) - .put("path.home", createTempDir()) - .build(); - - final var bulkTaskCount = new AtomicInteger(0); - final var threadPool = new TestThreadPool("test"); - final var bulkExecutor = new StoppableExecutorServiceWrapper(threadPool.generic()) { - @Override - public void execute(Runnable command) { - super.execute(command); - bulkTaskCount.incrementAndGet(); - } - }; - - try ( - NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); - var cacheService = new SharedBlobCacheService<>( - environment, - settings, - threadPool, - threadPool.executor(ThreadPool.Names.GENERIC), - BlobCacheMetrics.NOOP - ) - ) { - { - final var cacheKey = generateCacheKey(); - assertEquals(5, cacheService.freeRegionCount()); - final long size = size(250); - AtomicLong bytesRead = new AtomicLong(size); - final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchFullEntry( - cacheKey, - size, - (channel, channelPos, 
streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( - completionListener, - () -> { - assert streamFactory == null : streamFactory; - bytesRead.addAndGet(-length); - progressUpdater.accept(length); - } - ), - bulkExecutor, - future - ); - - future.get(10, TimeUnit.SECONDS); - assertEquals(0L, bytesRead.get()); - assertEquals(2, cacheService.freeRegionCount()); - assertEquals(3, bulkTaskCount.get()); - } - { - // a download that would use up all regions should not run - final var cacheKey = generateCacheKey(); - assertEquals(2, cacheService.freeRegionCount()); - var configured = cacheService.maybeFetchFullEntry( - cacheKey, - size(500), - (ch, chPos, streamFactory, relPos, len, update, completionListener) -> completeWith(completionListener, () -> { - throw new AssertionError("Should never reach here"); - }), - bulkExecutor, - ActionListener.noop() - ); - assertFalse(configured); - assertEquals(2, cacheService.freeRegionCount()); - } - } - - threadPool.shutdown(); - } - - public void testFetchFullCacheEntryConcurrently() throws Exception { - Settings settings = Settings.builder() - .put(NODE_NAME_SETTING.getKey(), "node") - .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(500)).getStringRep()) - .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(100)).getStringRep()) - .put("path.home", createTempDir()) - .build(); - - final var threadPool = new TestThreadPool("test"); - final var bulkExecutor = new StoppableExecutorServiceWrapper(threadPool.generic()); - - try ( - NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); - var cacheService = new SharedBlobCacheService<>( - environment, - settings, - threadPool, - threadPool.executor(ThreadPool.Names.GENERIC), - BlobCacheMetrics.NOOP - ) - ) { - - final long size = size(randomIntBetween(1, 100)); - final Thread[] threads = new Thread[10]; - for 
(int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - for (int j = 0; j < 1000; j++) { - final var cacheKey = generateCacheKey(); - safeAwait( - (ActionListener listener) -> cacheService.maybeFetchFullEntry( - cacheKey, - size, - ( - channel, - channelPos, - streamFactory, - relativePos, - length, - progressUpdater, - completionListener) -> completeWith(completionListener, () -> progressUpdater.accept(length)), - bulkExecutor, - listener - ) - ); - } - }); - } - for (Thread thread : threads) { - thread.start(); - } - for (Thread thread : threads) { - thread.join(); - } - } finally { - assertTrue(ThreadPool.terminate(threadPool, 10L, TimeUnit.SECONDS)); - } - } - public void testCacheSizeRejectedOnNonFrozenNodes() { String cacheSize = randomBoolean() ? ByteSizeValue.ofBytes(size(500)).getStringRep() @@ -1130,6 +997,195 @@ public void execute(Runnable command) { threadPool.shutdown(); } + public void testFetchRegion() throws Exception { + final long cacheSize = size(500L); + final long regionSize = size(100L); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(cacheSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put("path.home", createTempDir()) + .build(); + + final var bulkTaskCount = new AtomicInteger(0); + final var threadPool = new TestThreadPool("test"); + final var bulkExecutor = new StoppableExecutorServiceWrapper(threadPool.generic()) { + @Override + public void execute(Runnable command) { + super.execute(command); + bulkTaskCount.incrementAndGet(); + } + }; + + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService<>( + environment, + settings, + threadPool, + threadPool.executor(ThreadPool.Names.GENERIC), + 
BlobCacheMetrics.NOOP + ) + ) { + { + // fetch a single region + final var cacheKey = generateCacheKey(); + assertEquals(5, cacheService.freeRegionCount()); + final long blobLength = size(250); // 3 regions + AtomicLong bytesRead = new AtomicLong(0L); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRegion( + cacheKey, + 0, + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + assert streamFactory == null : streamFactory; + bytesRead.addAndGet(length); + progressUpdater.accept(length); + } + ), + bulkExecutor, + true, + future + ); + + var fetched = future.get(10, TimeUnit.SECONDS); + assertThat("Region has been fetched", fetched, is(true)); + assertEquals(regionSize, bytesRead.get()); + assertEquals(4, cacheService.freeRegionCount()); + assertEquals(1, bulkTaskCount.get()); + } + { + // fetch multiple regions to used all the cache + final int remainingFreeRegions = cacheService.freeRegionCount(); + assertEquals(4, cacheService.freeRegionCount()); + + final var cacheKey = generateCacheKey(); + final long blobLength = regionSize * remainingFreeRegions; + AtomicLong bytesRead = new AtomicLong(0L); + + final PlainActionFuture> future = new PlainActionFuture<>(); + final var listener = new GroupedActionListener<>(remainingFreeRegions, future); + for (int region = 0; region < remainingFreeRegions; region++) { + cacheService.fetchRegion( + cacheKey, + region, + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + assert streamFactory == null : streamFactory; + bytesRead.addAndGet(length); + progressUpdater.accept(length); + } + ), + bulkExecutor, + true, + listener + ); + } + + var results = future.get(10, TimeUnit.SECONDS); + assertThat(results.stream().allMatch(result -> result), is(true)); + assertEquals(blobLength, 
bytesRead.get()); + assertEquals(0, cacheService.freeRegionCount()); + assertEquals(1 + remainingFreeRegions, bulkTaskCount.get()); + } + { + // cache fully used, no entry old enough to be evicted and force=false should not evict entries + assertEquals(0, cacheService.freeRegionCount()); + final var cacheKey = generateCacheKey(); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRegion( + cacheKey, + 0, + regionSize, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + throw new AssertionError("should not be executed"); + } + ), + bulkExecutor, + false, + future + ); + assertThat("Listener is immediately completed", future.isDone(), is(true)); + assertThat("Region already exists in cache", future.get(), is(false)); + } + { + // cache fully used, but force=true, so the cache should evict regions to make space for the requested regions + assertEquals(0, cacheService.freeRegionCount()); + AtomicLong bytesRead = new AtomicLong(0L); + final var cacheKey = generateCacheKey(); + final PlainActionFuture> future = new PlainActionFuture<>(); + var regionsToFetch = randomIntBetween(1, (int) (cacheSize / regionSize)); + final var listener = new GroupedActionListener<>(regionsToFetch, future); + long blobLength = regionsToFetch * regionSize; + for (int region = 0; region < regionsToFetch; region++) { + cacheService.fetchRegion( + cacheKey, + region, + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + assert streamFactory == null : streamFactory; + bytesRead.addAndGet(length); + progressUpdater.accept(length); + } + ), + bulkExecutor, + true, + listener + ); + } + + var results = future.get(10, TimeUnit.SECONDS); + assertThat(results.stream().allMatch(result -> result), is(true)); + assertEquals(blobLength, bytesRead.get()); + assertEquals(0, 
cacheService.freeRegionCount()); + assertEquals(regionsToFetch + 5, bulkTaskCount.get()); + } + { + cacheService.computeDecay(); + + // We explicitly called computeDecay, meaning that some regions must have been demoted to level 0, + // therefore there should be enough room to fetch the requested range regardless of the force flag. + final var cacheKey = generateCacheKey(); + assertEquals(0, cacheService.freeRegionCount()); + long blobLength = randomLongBetween(1L, regionSize); + AtomicLong bytesRead = new AtomicLong(0L); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRegion( + cacheKey, + 0, + blobLength, + (channel, channelPos, ignore, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + assert ignore == null : ignore; + bytesRead.addAndGet(length); + progressUpdater.accept(length); + } + ), + bulkExecutor, + randomBoolean(), + future + ); + + var fetched = future.get(10, TimeUnit.SECONDS); + assertThat("Region has been fetched", fetched, is(true)); + assertEquals(blobLength, bytesRead.get()); + assertEquals(0, cacheService.freeRegionCount()); + } + } finally { + TestThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + public void testMaybeFetchRange() throws Exception { final long cacheSize = size(500L); final long regionSize = size(100L); @@ -1301,6 +1357,208 @@ public void execute(Runnable command) { threadPool.shutdown(); } + public void testFetchRange() throws Exception { + final long cacheSize = size(500L); + final long regionSize = size(100L); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(cacheSize).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put("path.home", createTempDir()) + .build(); + + final var bulkTaskCount = new 
AtomicInteger(0); + final var threadPool = new TestThreadPool("test"); + final var bulkExecutor = new StoppableExecutorServiceWrapper(threadPool.generic()) { + @Override + public void execute(Runnable command) { + super.execute(command); + bulkTaskCount.incrementAndGet(); + } + }; + + try ( + NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings)); + var cacheService = new SharedBlobCacheService<>( + environment, + settings, + threadPool, + threadPool.executor(ThreadPool.Names.GENERIC), + BlobCacheMetrics.NOOP + ) + ) { + { + // fetch a random range in a random region of the blob + final var cacheKey = generateCacheKey(); + assertEquals(5, cacheService.freeRegionCount()); + + // blobLength is 1024000 bytes and requires 3 regions + final long blobLength = size(250); + final var regions = List.of( + // region 0: 0-409600 + ByteRange.of(cacheService.getRegionStart(0), cacheService.getRegionEnd(0)), + // region 1: 409600-819200 + ByteRange.of(cacheService.getRegionStart(1), cacheService.getRegionEnd(1)), + // region 2: 819200-1228800 + ByteRange.of(cacheService.getRegionStart(2), cacheService.getRegionEnd(2)) + ); + + long pos = randomLongBetween(0, blobLength - 1L); + long len = randomLongBetween(1, blobLength - pos); + var range = ByteRange.of(pos, pos + len); + var region = between(0, regions.size() - 1); + var regionRange = cacheService.mapSubRangeToRegion(range, region); + + var bytesCopied = new AtomicLong(0L); + var future = new PlainActionFuture(); + cacheService.maybeFetchRange( + cacheKey, + region, + range, + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + assertThat(range.start() + relativePos, equalTo(cacheService.getRegionStart(region) + regionRange.start())); + assertThat(channelPos, equalTo(Math.toIntExact(regionRange.start()))); + assertThat(length, equalTo(Math.toIntExact(regionRange.length()))); 
+ bytesCopied.addAndGet(length); + } + ), + bulkExecutor, + future + ); + var fetched = future.get(10, TimeUnit.SECONDS); + + assertThat(regionRange.length(), equalTo(bytesCopied.get())); + if (regionRange.isEmpty()) { + assertThat(fetched, is(false)); + assertEquals(5, cacheService.freeRegionCount()); + assertEquals(0, bulkTaskCount.get()); + } else { + assertThat(fetched, is(true)); + assertEquals(4, cacheService.freeRegionCount()); + assertEquals(1, bulkTaskCount.get()); + } + } + { + // fetch multiple ranges to use all the cache + final int remainingFreeRegions = cacheService.freeRegionCount(); + assertThat(remainingFreeRegions, greaterThanOrEqualTo(4)); + bulkTaskCount.set(0); + + final var cacheKey = generateCacheKey(); + final long blobLength = regionSize * remainingFreeRegions; + AtomicLong bytesCopied = new AtomicLong(0L); + + final PlainActionFuture> future = new PlainActionFuture<>(); + final var listener = new GroupedActionListener<>(remainingFreeRegions, future); + for (int region = 0; region < remainingFreeRegions; region++) { + cacheService.fetchRange( + cacheKey, + region, + ByteRange.of(0L, blobLength), + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> bytesCopied.addAndGet(length) + ), + bulkExecutor, + true, + listener + ); + } + + var results = future.get(10, TimeUnit.SECONDS); + assertThat(results.stream().allMatch(result -> result), is(true)); + assertEquals(blobLength, bytesCopied.get()); + assertEquals(0, cacheService.freeRegionCount()); + assertEquals(remainingFreeRegions, bulkTaskCount.get()); + } + { + // cache fully used, no entry old enough to be evicted and force=false + assertEquals(0, cacheService.freeRegionCount()); + final var cacheKey = generateCacheKey(); + final var blobLength = randomLongBetween(1L, regionSize); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRange( + cacheKey, + 
randomIntBetween(0, 10), + ByteRange.of(0L, blobLength), + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> { + throw new AssertionError("should not be executed"); + } + ), + bulkExecutor, + false, + future + ); + assertThat("Listener is immediately completed", future.isDone(), is(true)); + assertThat("Region already exists in cache", future.get(), is(false)); + } + { + // cache fully used, since force=true the range should be populated + final var cacheKey = generateCacheKey(); + assertEquals(0, cacheService.freeRegionCount()); + long blobLength = randomLongBetween(1L, regionSize); + AtomicLong bytesCopied = new AtomicLong(0L); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRange( + cacheKey, + 0, + ByteRange.of(0L, blobLength), + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> bytesCopied.addAndGet(length) + ), + bulkExecutor, + true, + future + ); + + var fetched = future.get(10, TimeUnit.SECONDS); + assertThat("Region has been fetched", fetched, is(true)); + assertEquals(blobLength, bytesCopied.get()); + assertEquals(0, cacheService.freeRegionCount()); + } + { + cacheService.computeDecay(); + + // We explicitly called computeDecay, meaning that some regions must have been demoted to level 0, + // therefore there should be enough room to fetch the requested range regardless of the force flag. 
+ final var cacheKey = generateCacheKey(); + assertEquals(0, cacheService.freeRegionCount()); + long blobLength = randomLongBetween(1L, regionSize); + AtomicLong bytesCopied = new AtomicLong(0L); + final PlainActionFuture future = new PlainActionFuture<>(); + cacheService.fetchRange( + cacheKey, + 0, + ByteRange.of(0L, blobLength), + blobLength, + (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith( + completionListener, + () -> bytesCopied.addAndGet(length) + ), + bulkExecutor, + randomBoolean(), + future + ); + + var fetched = future.get(10, TimeUnit.SECONDS); + assertThat("Region has been fetched", fetched, is(true)); + assertEquals(blobLength, bytesCopied.get()); + assertEquals(0, cacheService.freeRegionCount()); + } + } finally { + TestThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + public void testPopulate() throws Exception { final long regionSize = size(1L); Settings settings = Settings.builder() diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index 8839aadb13716..f7d28c1ef8bfe 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -135,9 +135,10 @@ tasks.named("yamlRestCompatTestTransform").configure({ task -> task.skipTest("esql/63_enrich_int_range/Invalid age as double", "TODO: require disable allow_partial_results") task.skipTest("esql/191_lookup_join_on_datastreams/data streams not supported in LOOKUP JOIN", "Added support for aliases in JOINs") task.skipTest("esql/190_lookup_join/non-lookup index", "Error message changed") - task.skipTest("esql/192_lookup_join_on_aliases/alias-pattern-multiple", "Error message changed") task.skipTest("esql/190_lookup_join/fails with non-lookup index", "Error message changed") + task.skipTest("esql/192_lookup_join_on_aliases/alias-pattern-multiple", "Error message changed") task.skipTest("esql/192_lookup_join_on_aliases/fails when alias or pattern resolves to multiple", "Error message changed") + 
task.skipTest("esql/10_basic/Test wrong LIMIT parameter", "Error message changed") }) tasks.named('yamlRestCompatTest').configure { diff --git a/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java b/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java index 7c47237d35bd5..6c56724a6f20a 100644 --- a/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java +++ b/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; @@ -51,6 +52,7 @@ import org.elasticsearch.snapshots.SnapshotId; import org.elasticsearch.snapshots.SnapshotShardSizeInfo; import org.elasticsearch.snapshots.SnapshotsInfoService; +import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.transport.TransportActionProxy; import org.elasticsearch.transport.TransportService; @@ -655,39 +657,39 @@ public void testCcrRepositoryFailsToFetchSnapshotShardSizes() throws Exception { try { final SnapshotsInfoService snapshotsInfoService = getFollowerCluster().getCurrentMasterNodeInstance(SnapshotsInfoService.class); + final ClusterService clusterService = getFollowerCluster().getCurrentMasterNodeInstance(ClusterService.class); final PlainActionFuture waitForAllShardSnapshotSizesFailures = new PlainActionFuture<>(); - final ClusterStateListener listener = event -> { - if (RestoreInProgress.get(event.state()).isEmpty() == false && 
event.state().routingTable().hasIndex(followerIndex)) { - try { - final IndexRoutingTable indexRoutingTable = event.state().routingTable().index(followerIndex); - // this assertBusy completes because the listener is added after the InternalSnapshotsInfoService - // and ClusterService preserves the order of listeners. - assertBusy(() -> { - List sizes = indexRoutingTable.shardsWithState(ShardRoutingState.UNASSIGNED) - .stream() - .filter(shard -> shard.unassignedInfo().lastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) - .sorted(Comparator.comparingInt(ShardRouting::getId)) - .map(shard -> snapshotsInfoService.snapshotShardSizes().getShardSize(shard)) - .filter(Objects::nonNull) - .filter(size -> ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE == size) - .collect(Collectors.toList()); - assertThat(sizes, hasSize(numberOfShards)); - }); - waitForAllShardSnapshotSizesFailures.onResponse(null); - } catch (Exception e) { - throw new AssertionError("Failed to retrieve all snapshot shard sizes", e); - } + ClusterServiceUtils.addTemporaryStateListener( + clusterService, + state -> RestoreInProgress.get(state).isEmpty() == false && state.routingTable().hasIndex(followerIndex) + ).addListener(ActionTestUtils.assertNoFailureListener(ignore -> { + try { + // This listener runs synchronously in the same thread so that clusterService.state() returns the same state + // that satisfied the predicate. + final IndexRoutingTable indexRoutingTable = clusterService.state().routingTable().index(followerIndex); + // this assertBusy completes because the listener is added after the InternalSnapshotsInfoService + // and ClusterService preserves the order of listeners. 
+ assertBusy(() -> { + List sizes = indexRoutingTable.shardsWithState(ShardRoutingState.UNASSIGNED) + .stream() + .filter(shard -> shard.unassignedInfo().lastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) + .sorted(Comparator.comparingInt(ShardRouting::getId)) + .map(shard -> snapshotsInfoService.snapshotShardSizes().getShardSize(shard)) + .filter(Objects::nonNull) + .filter(size -> ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE == size) + .collect(Collectors.toList()); + assertThat(sizes, hasSize(numberOfShards)); + }); + waitForAllShardSnapshotSizesFailures.onResponse(null); + } catch (Exception e) { + throw new AssertionError("Failed to retrieve all snapshot shard sizes", e); } - }; - - final ClusterService clusterService = getFollowerCluster().getCurrentMasterNodeInstance(ClusterService.class); - clusterService.addListener(listener); + })); logger.debug("--> creating follower index [{}]", followerIndex); followerClient().execute(PutFollowAction.INSTANCE, putFollow(leaderIndex, followerIndex, ActiveShardCount.NONE)); waitForAllShardSnapshotSizesFailures.get(30L, TimeUnit.SECONDS); - clusterService.removeListener(listener); assertThat(simulatedFailures.get(), equalTo(numberOfShards)); diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java index 029ea6dcd6871..717ec4761c87e 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java @@ -32,6 +32,7 @@ import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.MappingMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import 
org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; @@ -43,6 +44,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; @@ -118,7 +120,7 @@ public ShardFollowTasksExecutor(Client client, ThreadPool threadPool, ClusterSer } @Override - public void validate(ShardFollowTask params, ClusterState clusterState) { + public void validate(ShardFollowTask params, ClusterState clusterState, @Nullable ProjectId projectId) { final IndexRoutingTable routingTable = clusterState.getRoutingTable().index(params.getFollowShardId().getIndex()); final ShardRouting primaryShard = routingTable.shard(params.getFollowShardId().id()).primaryShard(); if (primaryShard.active() == false) { @@ -129,10 +131,11 @@ public void validate(ShardFollowTask params, ClusterState clusterState) { private static final Assignment NO_ASSIGNMENT = new Assignment(null, "no nodes found with data and remote cluster client roles"); @Override - public Assignment getAssignment( + protected Assignment doGetAssignment( final ShardFollowTask params, - Collection candidateNodes, - final ClusterState clusterState + final Collection candidateNodes, + final ClusterState clusterState, + @Nullable final ProjectId projectId ) { final DiscoveryNode node = selectLeastLoadedNode( clusterState, diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java index 630aab4c78f43..7cb549df52301 100644 --- 
a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -93,7 +94,8 @@ private void runAssignmentTest( final Assignment assignment = executor.getAssignment( mock(ShardFollowTask.class), clusterStateBuilder.nodes().getAllNodes(), - clusterStateBuilder.build() + clusterStateBuilder.build(), + ProjectId.DEFAULT ); consumer.accept(theSpecial, assignment); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackSettings.java index 31edc83c00b3a..fb29dde69ae7c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackSettings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackSettings.java @@ -344,18 +344,14 @@ public static Setting defaultStoredSecureTokenHashAlgorithmSetting( public static final SslClientAuthenticationMode REMOTE_CLUSTER_CLIENT_AUTH_DEFAULT = SslClientAuthenticationMode.NONE; public static final SslVerificationMode VERIFICATION_MODE_DEFAULT = SslVerificationMode.FULL; - // http specific settings public static final String HTTP_SSL_PREFIX = SecurityField.setting("http.ssl."); - private static final SSLConfigurationSettings HTTP_SSL = SSLConfigurationSettings.withPrefix(HTTP_SSL_PREFIX, true); - - // transport specific settings public static final String TRANSPORT_SSL_PREFIX = SecurityField.setting("transport.ssl."); - private static final 
SSLConfigurationSettings TRANSPORT_SSL = SSLConfigurationSettings.withPrefix(TRANSPORT_SSL_PREFIX, true); - - // remote cluster specific settings public static final String REMOTE_CLUSTER_SERVER_SSL_PREFIX = SecurityField.setting("remote_cluster_server.ssl."); public static final String REMOTE_CLUSTER_CLIENT_SSL_PREFIX = SecurityField.setting("remote_cluster_client.ssl."); + private static final SSLConfigurationSettings HTTP_SSL = SSLConfigurationSettings.withPrefix(HTTP_SSL_PREFIX, true); + private static final SSLConfigurationSettings TRANSPORT_SSL = SSLConfigurationSettings.withPrefix(TRANSPORT_SSL_PREFIX, true); + private static final SSLConfigurationSettings REMOTE_CLUSTER_SERVER_SSL = SSLConfigurationSettings.withPrefix( REMOTE_CLUSTER_SERVER_SSL_PREFIX, false diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilter.java index 5bbd0b487e35b..390744abf0125 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilter.java @@ -22,14 +22,25 @@ public abstract class ApiFilteringActionFilter imple private final ThreadContext threadContext; private final String actionName; private final Class responseClass; + private final boolean filterOperatorRequests; protected ApiFilteringActionFilter(ThreadContext threadContext, String actionName, Class responseClass) { + this(threadContext, actionName, responseClass, false); + } + + protected ApiFilteringActionFilter( + ThreadContext threadContext, + String actionName, + Class responseClass, + boolean filterOperatorRequests + ) { assert threadContext != null : "threadContext cannot be null"; assert actionName != null : "actionName cannot be null"; assert responseClass != null : "responseClass cannot be 
null"; this.threadContext = threadContext; this.actionName = actionName; this.responseClass = responseClass; + this.filterOperatorRequests = filterOperatorRequests; } @Override @@ -46,7 +57,7 @@ public void app ActionFilterChain chain ) { final ActionListener responseFilteringListener; - if (isOperator(threadContext) == false && actionName.equals(action)) { + if ((filterOperatorRequests || isOperator(threadContext) == false) && actionName.equals(action)) { responseFilteringListener = listener.map(this::filter); } else { responseFilteringListener = listener; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index f01ac08f922c4..f58387d958fb3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -286,6 +286,12 @@ public final class Messages { public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE = "Requested model ID [{}] does not have a matching trained model and thus cannot be updated."; public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated"; + public static final String INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT = + "Cannot update inference endpoint [{}] for model deployment [{}] as it was created by another inference endpoint. " + + "The model can only be updated using inference endpoint id [{}]."; + public static final String INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED = + "Cannot update inference endpoint [{}] using model deployment [{}]. 
" + + "The model deployment must be updated through the trained models API."; private Messages() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/UpdateIndexMigrationVersionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/UpdateIndexMigrationVersionAction.java index 1ebf129e8f553..674354c15702b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/UpdateIndexMigrationVersionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/UpdateIndexMigrationVersionAction.java @@ -19,7 +19,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.common.Priority; @@ -139,22 +139,19 @@ static class UpdateIndexMigrationVersionTask implements ClusterStateTaskListener } ClusterState execute(ClusterState currentState) { - IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder( - currentState.metadata().getProject().indices().get(indexName) - ); + final var project = currentState.metadata().getProject(); + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(project.indices().get(indexName)); indexMetadataBuilder.putCustom( MIGRATION_VERSION_CUSTOM_KEY, Map.of(MIGRATION_VERSION_CUSTOM_DATA_KEY, Integer.toString(indexMigrationVersion)) ); indexMetadataBuilder.version(indexMetadataBuilder.version() + 1); - final ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder( - currentState.metadata().getProject().indices() - ); + final ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(project.indices()); 
builder.put(indexName, indexMetadataBuilder.build()); return ClusterState.builder(currentState) - .metadata(Metadata.builder(currentState.metadata()).indices(builder.build()).build()) + .putProjectMetadata(ProjectMetadata.builder(project).indices(builder.build()).build()) .build(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLConfigurationSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLConfigurationSettings.java index 055454847e154..6aa50df60a1f5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLConfigurationSettings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLConfigurationSettings.java @@ -13,9 +13,12 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.ssl.SslClientAuthenticationMode; import org.elasticsearch.common.ssl.SslConfigurationKeys; +import org.elasticsearch.common.ssl.SslConfigurationLoader; import org.elasticsearch.common.ssl.SslVerificationMode; import org.elasticsearch.common.ssl.X509Field; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import java.util.ArrayList; @@ -30,6 +33,7 @@ import javax.net.ssl.TrustManagerFactory; import static org.elasticsearch.common.ssl.SslConfigurationLoader.GLOBAL_DEFAULT_RESTRICTED_TRUST_FIELDS; +import static org.elasticsearch.xpack.core.XPackSettings.TRANSPORT_SSL_PREFIX; /** * Bridges SSLConfiguration into the {@link Settings} framework, using {@link Setting} objects. 
@@ -50,6 +54,7 @@ public class SSLConfigurationSettings { final Setting> caPaths; final Setting> clientAuth; final Setting> verificationMode; + final Setting handshakeTimeout; // public for PKI realm private final Setting legacyTruststorePassword; @@ -223,6 +228,11 @@ public class SSLConfigurationSettings { public static final Function>> VERIFICATION_MODE_SETTING_REALM = VERIFICATION_MODE::realm; + public static final SslSetting HANDSHAKE_TIMEOUT = SslSetting.setting( + SslConfigurationKeys.HANDSHAKE_TIMEOUT, + key -> Setting.positiveTimeSetting(key, SslConfigurationLoader.DEFAULT_HANDSHAKE_TIMEOUT, Property.NodeScope) + ); + /** * @param prefix The prefix under which each setting should be defined. Must be either the empty string ("") or a string * ending in "." @@ -246,6 +256,7 @@ private SSLConfigurationSettings(String prefix, boolean acceptNonSecurePasswords caPaths = CERT_AUTH_PATH.withPrefix(prefix); clientAuth = CLIENT_AUTH_SETTING.withPrefix(prefix); verificationMode = VERIFICATION_MODE.withPrefix(prefix); + handshakeTimeout = HANDSHAKE_TIMEOUT.withPrefix(prefix); final List> enabled = CollectionUtils.arrayAsArrayList( ciphers, @@ -270,6 +281,16 @@ private SSLConfigurationSettings(String prefix, boolean acceptNonSecurePasswords enabled.addAll(x509KeyPair.getEnabledSettings()); disabled.addAll(x509KeyPair.getDisabledSettings()); + if (TRANSPORT_SSL_PREFIX.equals(prefix) + || XPackSettings.REMOTE_CLUSTER_CLIENT_SSL_PREFIX.equals(prefix) + || XPackSettings.REMOTE_CLUSTER_SERVER_SSL_PREFIX.equals(prefix)) { + enabled.add(handshakeTimeout); + } else { + // Today the handshake timeout is only adjustable for transport connections - see SecurityNetty4Transport. In principle we + // could extend this to other contexts too, we just haven't done so yet. 
+ disabled.add(handshakeTimeout); + } + this.enabledSettings = Collections.unmodifiableList(enabled); this.disabledSettings = Collections.unmodifiableList(disabled); } @@ -327,7 +348,8 @@ private static Collection> settings() { CERT, CERT_AUTH_PATH, CLIENT_AUTH_SETTING, - VERIFICATION_MODE + VERIFICATION_MODE, + HANDSHAKE_TIMEOUT ); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/TransformMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/TransformMetadata.java index 86651fe241b3d..f76e14c66a57c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/TransformMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/TransformMetadata.java @@ -12,7 +12,9 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.ProjectState; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; @@ -209,6 +211,9 @@ public TransformMetadata build() { } } + /** + * @deprecated use {@link #transformMetadata(ClusterState, ProjectId)} + */ @Deprecated(forRemoval = true) public static TransformMetadata getTransformMetadata(ClusterState state) { TransformMetadata TransformMetadata = (state == null) ? 
null : state.metadata().getSingleProjectCustom(TYPE); @@ -218,6 +223,24 @@ public static TransformMetadata getTransformMetadata(ClusterState state) { return TransformMetadata; } + public static TransformMetadata transformMetadata(@Nullable ClusterState state, @Nullable ProjectId projectId) { + if (state == null || projectId == null) { + return EMPTY_METADATA; + } + return transformMetadata(state.projectState(projectId)); + } + + public static TransformMetadata transformMetadata(@Nullable ProjectState projectState) { + if (projectState == null) { + return EMPTY_METADATA; + } + TransformMetadata transformMetadata = projectState.metadata().custom(TYPE); + if (transformMetadata == null) { + return EMPTY_METADATA; + } + return transformMetadata; + } + public static boolean upgradeMode(ClusterState state) { return getTransformMetadata(state).upgradeMode(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilterTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilterTests.java index c5365c19632b1..51d3efe652a0c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilterTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/api/filtering/ApiFilteringActionFilterTests.java @@ -30,7 +30,7 @@ public void testApply() { boolean isOperator = randomBoolean(); final ThreadContext threadContext = getTestThreadContext(isOperator); String action = "test.action"; - ApiFilteringActionFilter filter = new TestFilter(threadContext); + ApiFilteringActionFilter filter = new TestFilter(threadContext, false); Task task = null; TestRequest request = new TestRequest(); AtomicBoolean listenerCalled = new AtomicBoolean(false); @@ -59,6 +59,37 @@ public void onFailure(Exception e) { assertThat(responseModified.get(), equalTo(isOperator == false)); } + public void testApplyAsOperator() { + final 
ThreadContext threadContext = getTestThreadContext(true); + ApiFilteringActionFilter filter = new TestFilter(threadContext, true); + Task task = null; + TestRequest request = new TestRequest(); + AtomicBoolean listenerCalled = new AtomicBoolean(false); + AtomicBoolean responseModified = new AtomicBoolean(false); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(TestResponse testResponse) { + listenerCalled.set(true); + responseModified.set(testResponse.modified); + } + + @Override + public void onFailure(Exception e) { + fail(Strings.format("Unexpected exception: %s", e.getMessage())); + } + }; + ActionFilterChain chain = (task1, action1, request1, listener1) -> { + listener1.onResponse(new TestResponse()); + }; + filter.apply(task, "wrong.action", request, listener, chain); + assertThat(listenerCalled.get(), equalTo(true)); + assertThat(responseModified.get(), equalTo(false)); + filter.apply(task, "test.action", request, listener, chain); + assertThat(listenerCalled.get(), equalTo(true)); + // The response should always be modified + assertThat(responseModified.get(), equalTo(true)); + } + public void testApplyWithException() { /* * This test makes sure that we have correct behavior if the filter function throws an exception. 
In that case we expect @@ -94,8 +125,8 @@ public void onFailure(Exception e) { private static class TestFilter extends ApiFilteringActionFilter { - TestFilter(ThreadContext threadContext) { - super(threadContext, "test.action", TestResponse.class); + TestFilter(ThreadContext threadContext, boolean filterOperatorRequests) { + super(threadContext, "test.action", TestResponse.class, filterOperatorRequests); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 56dc2a6d0212a..fd5632606867e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -119,7 +119,7 @@ public void testParseAllFields() throws IOException { assertThat(request, is(expected)); assertThat( - Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), + Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), is(XContentHelper.stripWhitespace(requestJson)) ); } diff --git a/x-pack/plugin/deprecation/src/test/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckActionTests.java b/x-pack/plugin/deprecation/src/test/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckActionTests.java index 40a564088aee6..0a323140c7e09 100644 --- a/x-pack/plugin/deprecation/src/test/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckActionTests.java +++ b/x-pack/plugin/deprecation/src/test/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckActionTests.java @@ -167,15 +167,9 @@ public void testCheckDiskLowWatermark() { String nodeId = "123"; 
long totalBytesOnMachine = 100; long totalBytesFree = 70; - ClusterInfo clusterInfo = new ClusterInfo( - Map.of(), - Map.of(nodeId, new DiskUsage(nodeId, "", "", totalBytesOnMachine, totalBytesFree)), - Map.of(), - Map.of(), - Map.of(), - Map.of(), - Map.of() - ); + ClusterInfo clusterInfo = ClusterInfo.builder() + .mostAvailableSpaceUsage(Map.of(nodeId, new DiskUsage(nodeId, "", "", totalBytesOnMachine, totalBytesFree))) + .build(); DeprecationIssue issue = TransportNodeDeprecationCheckAction.checkDiskLowWatermark( nodeSettings, dynamicSettings, diff --git a/x-pack/plugin/downsample/build.gradle b/x-pack/plugin/downsample/build.gradle index 2de12b89b5d3d..b823f1cd2e266 100644 --- a/x-pack/plugin/downsample/build.gradle +++ b/x-pack/plugin/downsample/build.gradle @@ -18,6 +18,8 @@ dependencies { compileOnly project(path: xpackModule('mapper-aggregate-metric')) testImplementation(testArtifact(project(xpackModule('core')))) testImplementation project(xpackModule('ccr')) + testImplementation project(xpackModule('esql')) + testImplementation project(xpackModule('esql-core')) } addQaCheckDependencies(project) diff --git a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java index 70150d4f95bc9..6eb3efcdeb735 100644 --- a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java +++ b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java @@ -7,24 +7,38 @@ package org.elasticsearch.xpack.downsample; +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; +import org.elasticsearch.action.admin.indices.delete.TransportDeleteIndexAction; import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; 
+import org.elasticsearch.action.datastreams.ModifyDataStreamsAction; import org.elasticsearch.action.downsample.DownsampleAction; import org.elasticsearch.action.downsample.DownsampleConfig; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.metadata.DataStreamAction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.esql.action.ColumnInfoImpl; +import org.elasticsearch.xpack.esql.action.EsqlQueryAction; +import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; +import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; import java.io.IOException; import java.time.Instant; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xpack.downsample.DownsampleDataStreamTests.TIMEOUT; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; public class DownsampleIT extends DownsamplingIntegTestCase { @@ -96,4 +110,163 @@ public void testDownsamplingPassthroughDimensions() throws Exception { assertDownsampleIndexFieldsAndDimensions(sourceIndex, targetIndex, downsampleConfig); } + + public void testAggMetricInEsqlTSAfterDownsampling() throws Exception { + String dataStreamName = "metrics-foo"; + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); + putTSDBIndexTemplate("my-template", 
List.of("metrics-foo"), settings, """ + { + "properties": { + "host": { + "type": "keyword", + "time_series_dimension": true + }, + "cluster" : { + "type": "keyword", + "time_series_dimension": true + }, + "cpu": { + "type": "double", + "time_series_metric": "gauge" + } + } + } + """, null, null); + + // Create data stream by indexing documents + final Instant now = Instant.now(); + Supplier sourceSupplier = () -> { + String ts = randomDateForRange(now.minusSeconds(60 * 60).toEpochMilli(), now.plusSeconds(60 * 29).toEpochMilli()); + try { + return XContentFactory.jsonBuilder() + .startObject() + .field("@timestamp", ts) + .field("host", randomFrom("host1", "host2", "host3")) + .field("cluster", randomFrom("cluster1", "cluster2", "cluster3")) + .field("cpu", randomDouble()) + .endObject(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + bulkIndex(dataStreamName, sourceSupplier, 100); + + // Rollover to ensure the index we will downsample is not the write index + assertAcked(client().admin().indices().rolloverIndex(new RolloverRequest(dataStreamName, null))); + List backingIndices = waitForDataStreamBackingIndices(dataStreamName, 2); + String sourceIndex = backingIndices.get(0); + String interval = "5m"; + String targetIndex = "downsample-" + interval + "-" + sourceIndex; + // Set the source index to read-only state + assertAcked( + indicesAdmin().prepareUpdateSettings(sourceIndex) + .setSettings(Settings.builder().put(IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.getKey(), true).build()) + ); + + DownsampleConfig downsampleConfig = new DownsampleConfig(new DateHistogramInterval(interval)); + assertAcked( + client().execute( + DownsampleAction.INSTANCE, + new DownsampleAction.Request(TEST_REQUEST_TIMEOUT, sourceIndex, targetIndex, TIMEOUT, downsampleConfig) + ) + ); + + // Wait for downsampling to complete + SubscribableListener listener = ClusterServiceUtils.addMasterTemporaryStateListener(clusterState -> { + final var indexMetadata = 
clusterState.metadata().getProject().index(targetIndex); + if (indexMetadata == null) { + return false; + } + var downsampleStatus = IndexMetadata.INDEX_DOWNSAMPLE_STATUS.get(indexMetadata.getSettings()); + return downsampleStatus == IndexMetadata.DownsampleTaskStatus.SUCCESS; + }); + safeAwait(listener); + + assertDownsampleIndexFieldsAndDimensions(sourceIndex, targetIndex, downsampleConfig); + + // remove old backing index and replace with downsampled index and delete old so old is not queried + assertAcked( + client().execute( + ModifyDataStreamsAction.INSTANCE, + new ModifyDataStreamsAction.Request( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + List.of( + DataStreamAction.removeBackingIndex(dataStreamName, sourceIndex), + DataStreamAction.addBackingIndex(dataStreamName, targetIndex) + ) + ) + ).actionGet() + ); + assertAcked(client().execute(TransportDeleteIndexAction.TYPE, new DeleteIndexRequest(sourceIndex)).actionGet()); + + // index to the next backing index; random time between 31 and 59m in the future because default look_ahead_time is 30m and we + don't want to conflict with the previous backing index + Supplier nextSourceSupplier = () -> { + String ts = randomDateForRange(now.plusSeconds(60 * 31).toEpochMilli(), now.plusSeconds(60 * 59).toEpochMilli()); + try { + return XContentFactory.jsonBuilder() + .startObject() + .field("@timestamp", ts) + .field("host", randomFrom("host1", "host2", "host3")) + .field("cluster", randomFrom("cluster1", "cluster2", "cluster3")) + .field("cpu", randomDouble()) + .endObject(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + bulkIndex(dataStreamName, nextSourceSupplier, 100); + + // check that TS command is available + var response = clusterAdmin().nodesCapabilities( + new NodesCapabilitiesRequest().method(RestRequest.Method.POST).path("/_query").capabilities(METRICS_COMMAND.capabilityName()) + ).actionGet(); + assumeTrue("TS command must be available for this test", 
response.isSupported().orElse(Boolean.FALSE)); + + // Since the downsampled field (cpu) is downsampled in one index and not in the other, we want to confirm + // first that the field is unsupported and has 2 original types - double and aggregate_metric_double + try (var resp = esqlCommand("TS " + dataStreamName + " | KEEP @timestamp, host, cluster, cpu")) { + var columns = resp.columns(); + assertThat(columns, hasSize(4)); + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("@timestamp", "date", null), + new ColumnInfoImpl("host", "keyword", null), + new ColumnInfoImpl("cluster", "keyword", null), + new ColumnInfoImpl("cpu", "unsupported", List.of("aggregate_metric_double", "double")) + ) + ) + ); + } + + // test _over_time commands with implicit casting of aggregate_metric_double + for (String innerCommand : List.of("min_over_time", "max_over_time", "avg_over_time", "count_over_time")) { + for (String outerCommand : List.of("min", "max", "sum", "count")) { + String command = outerCommand + " (" + innerCommand + "(cpu))"; + String expectedType = innerCommand.equals("count_over_time") || outerCommand.equals("count") ? 
"long" : "double"; + try (var resp = esqlCommand("TS " + dataStreamName + " | STATS " + command + " by cluster, bucket(@timestamp, 1 hour)")) { + var columns = resp.columns(); + assertThat(columns, hasSize(3)); + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl(command, expectedType, null), + new ColumnInfoImpl("cluster", "keyword", null), + new ColumnInfoImpl("bucket(@timestamp, 1 hour)", "date", null) + ) + ) + ); + // TODO: verify the numbers are accurate + } + } + } + } + + private EsqlQueryResponse esqlCommand(String command) throws IOException { + return client().execute(EsqlQueryAction.INSTANCE, new EsqlQueryRequest().query(command)).actionGet(30, TimeUnit.SECONDS); + } } diff --git a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsamplingIntegTestCase.java b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsamplingIntegTestCase.java index 27de42447d3a0..4991a9025956f 100644 --- a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsamplingIntegTestCase.java +++ b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsamplingIntegTestCase.java @@ -45,6 +45,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.aggregatemetric.AggregateMetricMapperPlugin; import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import java.io.IOException; import java.time.LocalDateTime; @@ -82,7 +83,13 @@ public abstract class DownsamplingIntegTestCase extends ESIntegTestCase { @Override protected Collection> nodePlugins() { - return List.of(DataStreamsPlugin.class, LocalStateCompositeXPackPlugin.class, Downsample.class, AggregateMetricMapperPlugin.class); + return List.of( + DataStreamsPlugin.class, + LocalStateCompositeXPackPlugin.class, + Downsample.class, + 
AggregateMetricMapperPlugin.class, + EsqlPlugin.class + ); } /** diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java index 5f91fb18fd58e..76615876c5255 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java @@ -22,6 +22,7 @@ import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; @@ -29,6 +30,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper; import org.elasticsearch.index.shard.ShardId; @@ -116,7 +118,7 @@ protected AllocatedPersistentTask createTask( } @Override - public void validate(DownsampleShardTaskParams params, ClusterState clusterState) { + public void validate(DownsampleShardTaskParams params, ClusterState clusterState, @Nullable ProjectId projectId) { // This is just a pre-check, but doesn't prevent from avoiding from aborting the task when source index disappeared // after initial creation of the persistent task. 
var indexShardRouting = findShardRoutingTable(params.shardId(), clusterState); @@ -126,10 +128,11 @@ public void validate(DownsampleShardTaskParams params, ClusterState clusterState } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( final DownsampleShardTaskParams params, final Collection candidateNodes, - final ClusterState clusterState + final ClusterState clusterState, + @Nullable final ProjectId projectId ) { // NOTE: downsampling works by running a task per each shard of the source index. // Here we make sure we assign the task to the actual node holding the shard identified by diff --git a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java index 39e92f06ada16..5a4e14dc24015 100644 --- a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java +++ b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java @@ -96,7 +96,7 @@ public void testGetAssignment() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(node), clusterState); + var result = executor.getAssignment(params, Set.of(node), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(node.getId())); } @@ -128,7 +128,7 @@ public void testGetAssignmentMissingIndex() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(node), clusterState); + var result = executor.getAssignment(params, Set.of(node), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(node.getId())); assertThat(result.getExplanation(), equalTo("a node to fail and stop this persistent task")); } @@ 
-165,7 +165,7 @@ public void testGetStatelessAssignment() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState); + var result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState, projectId); assertThat(result.getExecutorNode(), nullValue()); // Assign a copy of the shard to a search node @@ -185,7 +185,7 @@ public void testGetStatelessAssignment() { .build() ) .build(); - result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState); + result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(searchNode.getId())); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java index 1af98f4b21dc5..6b700f0ee6a7f 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java @@ -134,4 +134,19 @@ public String nodeString() { } protected abstract String label(); + + /** + * Compares the size and datatypes of two lists of attributes for equality. 
+ */ + public static boolean dataTypeEquals(List left, List right) { + if (left.size() != right.size()) { + return false; + } + for (int i = 0; i < left.size(); i++) { + if (left.get(i).dataType() != right.get(i).dataType()) { + return false; + } + } + return true; + } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/AttributeSet.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/AttributeSet.java index d281db4e6bf63..fefaf3098319e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/AttributeSet.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/AttributeSet.java @@ -261,5 +261,9 @@ public boolean isEmpty() { public AttributeSet build() { return new AttributeSet(mapBuilder.build()); } + + public void clear() { + mapBuilder.keySet().clear(); + } } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePattern.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePattern.java index 0744977170911..72b8c2efb2eba 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePattern.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePattern.java @@ -9,10 +9,14 @@ import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.Operations; import org.apache.lucene.util.automaton.RegExp; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import java.io.IOException; import java.util.Objects; -public class RLikePattern extends AbstractStringPattern { +public class RLikePattern extends AbstractStringPattern implements Writeable 
{ private final String regexpPattern; @@ -20,6 +24,15 @@ public RLikePattern(String regexpPattern) { this.regexpPattern = regexpPattern; } + public RLikePattern(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(regexpPattern); + } + @Override public Automaton createAutomaton(boolean ignoreCase) { int matchFlags = ignoreCase ? RegExp.CASE_INSENSITIVE : 0; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePatternList.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePatternList.java new file mode 100644 index 0000000000000..be62d189bafa4 --- /dev/null +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLikePatternList.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.esql.core.expression.predicate.regex; + +import org.apache.lucene.util.automaton.Automaton; +import org.apache.lucene.util.automaton.Operations; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class RLikePatternList extends AbstractStringPattern implements Writeable { + + private final List patternList; + + public RLikePatternList(List patternList) { + this.patternList = patternList; + } + + public RLikePatternList(StreamInput in) throws IOException { + this(in.readCollectionAsList(RLikePattern::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(patternList, (o, pattern) -> pattern.writeTo(o)); + } + + public List patternList() { + return patternList; + } + + /** + * Creates an automaton that matches any of the patterns in the list. + * We create a single automaton that is the union of all individual automatons to improve performance + */ + @Override + public Automaton createAutomaton(boolean ignoreCase) { + List automatonList = patternList.stream().map(x -> x.createAutomaton(ignoreCase)).toList(); + Automaton result = Operations.union(automatonList); + return Operations.determinize(result, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT); + } + + /** + * Returns a Java regex that matches any of the patterns in the list. + * The patterns are joined with the '|' operator to create a single regex. 
+ */ + @Override + public String asJavaRegex() { + return patternList.stream().map(RLikePattern::asJavaRegex).collect(Collectors.joining("|")); + } + + @Override + public int hashCode() { + return Objects.hash(patternList); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RLikePatternList other = (RLikePatternList) obj; + return patternList.equals(other.patternList); + } + + /** + * Returns a string that matches any of the patterns in the list. + * The patterns are joined with the '|' operator to create a single regex string. + */ + @Override + public String pattern() { + if (patternList.isEmpty()) { + return ""; + } + if (patternList.size() == 1) { + return patternList.get(0).pattern(); + } + return "(\"" + patternList.stream().map(RLikePattern::pattern).collect(Collectors.joining("\", \"")) + "\")"; + } +} diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 1704f4cbeb1fe..8dc6f594ca47a 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -199,14 +200,13 @@ private TypeSpec type() { builder.addMethod(ctor()); builder.addMethod(intermediateStateDesc()); builder.addMethod(intermediateBlockCount()); - builder.addMethod(prepareProcessPage()); + builder.addMethod(prepareProcessRawInputPage()); for (ClassName groupIdClass : GROUP_IDS_CLASSES) { 
builder.addMethod(addRawInputLoop(groupIdClass, blockType(aggParam.type()))); builder.addMethod(addRawInputLoop(groupIdClass, vectorType(aggParam.type()))); + builder.addMethod(addIntermediateInput(groupIdClass)); } builder.addMethod(selectedMayContainUnseenGroups()); - builder.addMethod(addIntermediateInput()); - builder.addMethod(addIntermediateRowInput()); builder.addMethod(evaluateIntermediate()); builder.addMethod(evaluateFinal()); builder.addMethod(toStringMethod()); @@ -314,10 +314,10 @@ private MethodSpec intermediateBlockCount() { } /** - * Prepare to process a single page of results. + * Prepare to process a single raw input page. */ - private MethodSpec prepareProcessPage() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("prepareProcessPage"); + private MethodSpec prepareProcessRawInputPage() { + MethodSpec.Builder builder = MethodSpec.methodBuilder("prepareProcessRawInputPage"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT); builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds").addParameter(PAGE, "page"); @@ -411,10 +411,16 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); { - if (groupsIsBlock) { - builder.beginControlFlow("if (groups.isNull(groupPosition))"); + if (groupsIsBlock || valuesIsBlock) { + String conditions = Stream.of( + groupsIsBlock ? "groups.isNull(groupPosition)" : null, + valuesIsBlock ? 
"values.isNull(groupPosition + positionOffset)" : null + ).filter(Objects::nonNull).collect(Collectors.joining(" || ")); + builder.beginControlFlow("if (" + conditions + ")"); builder.addStatement("continue"); builder.endControlFlow(); + } + if (groupsIsBlock) { builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); @@ -430,9 +436,6 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { } if (valuesIsBlock) { - builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))"); - builder.addStatement("continue"); - builder.endControlFlow(); builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)"); builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)"); if (aggParam.isArray()) { @@ -580,11 +583,12 @@ private MethodSpec selectedMayContainUnseenGroups() { return builder.build(); } - private MethodSpec addIntermediateInput() { + private MethodSpec addIntermediateInput(TypeName groupsType) { + boolean groupsIsBlock = groupsType.toString().endsWith("Block"); MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); builder.addParameter(TypeName.INT, "positionOffset"); - builder.addParameter(INT_VECTOR, "groups"); + builder.addParameter(groupsType, "groups"); builder.addParameter(PAGE, "page"); builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS); @@ -610,7 +614,18 @@ private MethodSpec addIntermediateInput() { } builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); { - builder.addStatement("int groupId = groups.getInt(groupPosition)"); + if (groupsIsBlock) { + 
builder.beginControlFlow("if (groups.isNull(groupPosition))"); + builder.addStatement("continue"); + builder.endControlFlow(); + builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); + builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); + builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); + builder.addStatement("int groupId = groups.getInt(g)"); + } else { + builder.addStatement("int groupId = groups.getInt(groupPosition)"); + } + if (aggState.declaredType().isPrimitive()) { if (warnExceptions.isEmpty()) { assert intermediateState.size() == 2; @@ -661,43 +676,14 @@ private MethodSpec addIntermediateInput() { declarationType ); } + if (groupsIsBlock) { + builder.endControlFlow(); + } builder.endControlFlow(); } return builder.build(); } - private MethodSpec addIntermediateRowInput() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput"); - builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); - builder.addParameter(int.class, "groupId").addParameter(GROUPING_AGGREGATOR_FUNCTION, "input").addParameter(int.class, "position"); - builder.beginControlFlow("if (input.getClass() != getClass())"); - { - builder.addStatement("throw new IllegalArgumentException($S + getClass() + $S + input.getClass())", "expected ", "; got "); - } - builder.endControlFlow(); - builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation); - builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS); - if (aggState.declaredType().isPrimitive()) { - builder.beginControlFlow("if (inState.hasValue(position))"); - builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType); - builder.endControlFlow(); - } else { - requireStaticMethod( - declarationType, - requireVoidType(), - requireName("combineStates"), - requireArgs( - 
requireType(aggState.declaredType()), - requireType(TypeName.INT), - requireType(aggState.declaredType()), - requireType(TypeName.INT) - ) - ); - builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType); - } - return builder.build(); - } - private MethodSpec evaluateIntermediate() { MethodSpec.Builder builder = MethodSpec.methodBuilder("evaluateIntermediate"); builder.addAnnotation(Override.class) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleAggregator.java index 50b00c998a8af..3c3ca7f6cb2d9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeFloatAggregator.java index 69ad3c6eb3db5..407c6f9d2ff6d 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeFloatAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeIntAggregator.java index 134af879b1d04..d006907f1800e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeIntAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, 
GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeLongAggregator.java index b052f43e3aff4..62f9a46ef9ed7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/FirstOverTimeLongAggregator.java @@ -54,14 +54,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeDoubleAggregator.java index 77aafed555519..91373a508784c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeDoubleAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, 
GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeFloatAggregator.java index d55cbfc09dc12..4bb863ad6474f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeFloatAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeIntAggregator.java index 5ea8cc7f27bd7..ee2cc6f3049b2 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeIntAggregator.java @@ -56,14 +56,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeLongAggregator.java index 781c52a627649..d32144ef1fc19 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LastOverTimeLongAggregator.java @@ -54,14 +54,6 @@ public static void combineIntermediate( } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext 
evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java index 15d4c0b060440..92f8886712d35 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java @@ -24,8 +24,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import java.util.Arrays; - /** * A rate grouping aggregation definition for double. * This class is generated. Edit `X-RateAggregator.java.st` instead. @@ -60,15 +58,6 @@ public static void combineIntermediate( current.combine(groupId, timestamps, values, sampleCount, reset, otherPosition); } - public static void combineStates( - DoubleRateGroupingState current, - int currentGroupId, // make the stylecheck happy - DoubleRateGroupingState otherState, - int otherGroupId - ) { - current.combineState(currentGroupId, otherState, otherGroupId); - } - public static Block evaluateFinal(DoubleRateGroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } @@ -217,25 +206,6 @@ void merge(DoubleRateState curr, DoubleRateState dst, int firstIndex, int rightC } } - void combineState(int groupId, DoubleRateGroupingState otherState, int otherGroupId) { - var other = otherGroupId < otherState.states.size() ? 
otherState.states.get(otherGroupId) : null; - if (other == null) { - return; - } - ensureCapacity(groupId); - var curr = states.get(groupId); - if (curr == null) { - var len = other.entries(); - adjustBreaker(DoubleRateState.bytesUsed(len)); - curr = new DoubleRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len)); - curr.reset = other.reset; - curr.sampleCount = other.sampleCount; - states.set(groupId, curr); - } else { - states.set(groupId, mergeState(curr, other)); - } - } - DoubleRateState mergeState(DoubleRateState s1, DoubleRateState s2) { var newLen = s1.entries() + s2.entries(); adjustBreaker(DoubleRateState.bytesUsed(newLen)); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java index f19b4a91596eb..eb8d6a194e6e5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java @@ -25,8 +25,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import java.util.Arrays; - /** * A rate grouping aggregation definition for float. * This class is generated. Edit `X-RateAggregator.java.st` instead. 
@@ -61,15 +59,6 @@ public static void combineIntermediate( current.combine(groupId, timestamps, values, sampleCount, reset, otherPosition); } - public static void combineStates( - FloatRateGroupingState current, - int currentGroupId, // make the stylecheck happy - FloatRateGroupingState otherState, - int otherGroupId - ) { - current.combineState(currentGroupId, otherState, otherGroupId); - } - public static Block evaluateFinal(FloatRateGroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } @@ -218,25 +207,6 @@ void merge(FloatRateState curr, FloatRateState dst, int firstIndex, int rightCou } } - void combineState(int groupId, FloatRateGroupingState otherState, int otherGroupId) { - var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null; - if (other == null) { - return; - } - ensureCapacity(groupId); - var curr = states.get(groupId); - if (curr == null) { - var len = other.entries(); - adjustBreaker(FloatRateState.bytesUsed(len)); - curr = new FloatRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len)); - curr.reset = other.reset; - curr.sampleCount = other.sampleCount; - states.set(groupId, curr); - } else { - states.set(groupId, mergeState(curr, other)); - } - } - FloatRateState mergeState(FloatRateState s1, FloatRateState s2) { var newLen = s1.entries() + s2.entries(); adjustBreaker(FloatRateState.bytesUsed(newLen)); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java index fcb744720a3db..fdacd473264a6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java @@ -25,8 +25,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import java.util.Arrays; - /** * A rate grouping aggregation definition for int. * This class is generated. Edit `X-RateAggregator.java.st` instead. @@ -61,15 +59,6 @@ public static void combineIntermediate( current.combine(groupId, timestamps, values, sampleCount, reset, otherPosition); } - public static void combineStates( - IntRateGroupingState current, - int currentGroupId, // make the stylecheck happy - IntRateGroupingState otherState, - int otherGroupId - ) { - current.combineState(currentGroupId, otherState, otherGroupId); - } - public static Block evaluateFinal(IntRateGroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } @@ -218,25 +207,6 @@ void merge(IntRateState curr, IntRateState dst, int firstIndex, int rightCount, } } - void combineState(int groupId, IntRateGroupingState otherState, int otherGroupId) { - var other = otherGroupId < otherState.states.size() ? 
otherState.states.get(otherGroupId) : null; - if (other == null) { - return; - } - ensureCapacity(groupId); - var curr = states.get(groupId); - if (curr == null) { - var len = other.entries(); - adjustBreaker(IntRateState.bytesUsed(len)); - curr = new IntRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len)); - curr.reset = other.reset; - curr.sampleCount = other.sampleCount; - states.set(groupId, curr); - } else { - states.set(groupId, mergeState(curr, other)); - } - } - IntRateState mergeState(IntRateState s1, IntRateState s2) { var newLen = s1.entries() + s2.entries(); adjustBreaker(IntRateState.bytesUsed(newLen)); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java index 39eba21a7be7d..ea9f7802656fb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java @@ -24,8 +24,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import java.util.Arrays; - /** * A rate grouping aggregation definition for long. * This class is generated. Edit `X-RateAggregator.java.st` instead. 
@@ -60,15 +58,6 @@ public static void combineIntermediate( current.combine(groupId, timestamps, values, sampleCount, reset, otherPosition); } - public static void combineStates( - LongRateGroupingState current, - int currentGroupId, // make the stylecheck happy - LongRateGroupingState otherState, - int otherGroupId - ) { - current.combineState(currentGroupId, otherState, otherGroupId); - } - public static Block evaluateFinal(LongRateGroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } @@ -217,25 +206,6 @@ void merge(LongRateState curr, LongRateState dst, int firstIndex, int rightCount } } - void combineState(int groupId, LongRateGroupingState otherState, int otherGroupId) { - var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null; - if (other == null) { - return; - } - ensureCapacity(groupId); - var curr = states.get(groupId); - if (curr == null) { - var len = other.entries(); - adjustBreaker(LongRateState.bytesUsed(len)); - curr = new LongRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len)); - curr.reset = other.reset; - curr.sampleCount = other.sampleCount; - states.set(groupId, curr); - } else { - states.set(groupId, mergeState(curr, other)); - } - } - LongRateState mergeState(LongRateState s1, LongRateState s2) { var newLen = s1.entries() + s2.entries(); adjustBreaker(LongRateState.bytesUsed(newLen)); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java index 7ef2b2c52f685..6017256910e44 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java @@ -84,10 +84,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ public void add(int groupId, boolean value) { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ public void add(boolean value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java index c9e42350dd060..2917efbff7ad1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java @@ -84,10 +84,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void 
combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ public void add(int groupId, BytesRef value) { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ public void add(BytesRef value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java index f526c54c6ddff..14578b7c2c15d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java @@ -84,10 +84,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return 
stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ public void add(int groupId, double value) { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ public void add(double value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java index d6172006e46df..6d25d4005f100 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java @@ -84,10 +84,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ public void add(int groupId, int value) { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, 
otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ public void add(int value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java index cd97db7155d58..680b513f5d01b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java @@ -84,10 +84,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ public void add(int groupId, long value) { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ public void add(long value) { internalState.add(0, value); } 
- public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java index 3a1185d34fa23..cb2550d1ce726 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java @@ -52,10 +52,6 @@ public static void combine(StdDevStates.GroupingState current, int groupId, doub current.add(groupId, value); } - public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { - current.combine(groupId, state.getOrNull(statePosition)); - } - public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java index 51c22e7e29c1e..8ac7a21817abe 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java @@ -52,10 +52,6 @@ public static void combine(StdDevStates.GroupingState current, int groupId, floa current.add(groupId, value); } - public static void combineStates(StdDevStates.GroupingState 
current, int groupId, StdDevStates.GroupingState state, int statePosition) { - current.combine(groupId, state.getOrNull(statePosition)); - } - public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java index 24eae35cb3249..991382560269a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevIntAggregator.java @@ -52,10 +52,6 @@ public static void combine(StdDevStates.GroupingState current, int groupId, int current.add(groupId, value); } - public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { - current.combine(groupId, state.getOrNull(statePosition)); - } - public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java index 888ace30a0c8e..5df0e5ae061a4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/StdDevLongAggregator.java @@ -52,10 +52,6 @@ public static void combine(StdDevStates.GroupingState current, int groupId, long current.add(groupId, value); } - public 
static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { - current.combine(groupId, state.getOrNull(statePosition)); - } - public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java index a2e86b3b09340..1b683b99c9df5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java @@ -65,10 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, Boolean } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -84,10 +80,6 @@ public void add(int groupId, boolean value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -119,10 +111,6 @@ public void add(boolean value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext 
driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java index 0a965899c0775..5ec303500451a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java @@ -69,10 +69,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -90,10 +86,6 @@ public void add(int groupId, BytesRef value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -125,10 +117,6 @@ public void add(BytesRef value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java index 6a20ed99bc236..ac833dee81922 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java @@ -65,10 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, DoubleB } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -84,10 +80,6 @@ public void add(int groupId, double value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -119,10 +111,6 @@ public void add(double value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java index cf6ad0f9017de..50e652d0af6f0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java @@ -65,10 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, FloatBl } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -84,10 +80,6 @@ public void add(int groupId, float value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -119,10 +111,6 @@ public void add(float value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java index f4ac83c438063..d0d93b2d971f5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java @@ -65,10 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, IntBloc } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int 
statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -84,10 +80,6 @@ public void add(int groupId, int value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -119,10 +111,6 @@ public void add(int value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java index 292dd539edeb5..16a26c5f01f52 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java @@ -68,10 +68,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -87,10 +83,6 @@ public void add(int groupId, BytesRef value) { sort.collect(value, 
groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -122,10 +114,6 @@ public void add(BytesRef value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java index c5af92956bec1..754a4e513eaab 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java @@ -65,10 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, LongBlo } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -84,10 +80,6 @@ public void add(int groupId, long value) { sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), 
selected); @@ -119,10 +111,6 @@ public void add(long value) { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index cb0dff8a86dc5..337c8cde768f9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -89,19 +89,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; - current.addValueOrdinal(currentGroupId, id); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -148,8 +135,6 @@ public void close() { * and then use it to iterate over the values in order. * * @param ids positions of the {@link GroupingState#values} to read. - * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, - * these are ordinals referring to the {@link GroupingState#bytes} in the target state. 
*/ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override @@ -171,8 +156,6 @@ public static class GroupingState implements GroupingAggregatorState { private final LongLongHash values; BytesRefHash bytes; - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); LongLongHash _values = null; @@ -312,34 +295,6 @@ private Sorted buildSorted(IntVector selected) { } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); - // hash all the bytes to the destination to avoid hashing them multiple times - BytesRef scratch = new BytesRef(); - final int totalValue = Math.toIntExact(bytes.size()); - blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); - try { - final int[] mappedIds = new int[totalValue]; - for (int i = 0; i < totalValue; i++) { - var v = bytes.get(i, scratch); - mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); - } - // no longer need the bytes - bytes.close(); - bytes = null; - for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { - sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; - } - } finally { - blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); - } - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -416,7 +371,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values, bytes); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 3c0dcd58c29ee..5f01ad586976f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -66,19 +66,6 @@ public static void combineIntermediate(GroupingState state, int groupId, DoubleB } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? 
sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; - current.addValue(currentGroupId, state.getValue(id)); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -144,8 +131,6 @@ public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; private final LongLongHash values; - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); values = new LongLongHash(1, driverContext.bigArrays()); @@ -263,15 +248,6 @@ private Sorted buildSorted(IntVector selected) { } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -310,7 +286,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index a25d69b712538..9acaaccd80a85 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -65,19 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, FloatBl } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? 
sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; - current.addValue(currentGroupId, state.getValue(id)); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -143,8 +130,6 @@ public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; private final LongHash values; - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); values = new LongHash(1, driverContext.bigArrays()); @@ -268,15 +253,6 @@ private Sorted buildSorted(IntVector selected) { } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -316,7 +292,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 2c8c0f409dd5b..3690df739552b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -65,19 +65,6 @@ public static void combineIntermediate(GroupingState state, int groupId, IntBloc } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? 
sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; - current.addValue(currentGroupId, state.getValue(id)); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -143,8 +130,6 @@ public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; private final LongHash values; - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); values = new LongHash(1, driverContext.bigArrays()); @@ -268,15 +253,6 @@ private Sorted buildSorted(IntVector selected) { } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -316,7 +292,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 2790a182d5041..9514e9147e05d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -66,19 +66,6 @@ public static void combineIntermediate(GroupingState state, int groupId, LongBlo } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? 
sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; - current.addValue(currentGroupId, state.getValue(id)); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -144,8 +131,6 @@ public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; private final LongLongHash values; - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); values = new LongLongHash(1, driverContext.bigArrays()); @@ -263,15 +248,6 @@ private Sorted buildSorted(IntVector selected) { } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -310,7 +286,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java index 4fce90e84add6..9d04612e02511 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -109,16 +109,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + 
values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -142,7 +139,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block fbitUncast = page.getBlock(channels.get(0)); + if (fbitUncast.areAllValuesNull()) { + return; + } + BooleanVector fbit = ((BooleanBlock) fbitUncast).asVector(); + Block tbitUncast = page.getBlock(channels.get(1)); + if (tbitUncast.areAllValuesNull()) { + return; + } + BooleanVector tbit = ((BooleanBlock) tbitUncast).asVector(); + assert fbit.getPositionCount() == tbit.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -151,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctBooleanAggregator.combineIntermediate(state, groupId, fbit.getBoolean(groupPosition + positionOffset), tbit.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < 
groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -177,12 +199,40 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block fbitUncast = page.getBlock(channels.get(0)); + if (fbitUncast.areAllValuesNull()) { + return; + } + BooleanVector fbit = ((BooleanBlock) fbitUncast).asVector(); + Block tbitUncast = page.getBlock(channels.get(1)); + if (tbitUncast.areAllValuesNull()) { + return; + } + BooleanVector tbit = ((BooleanBlock) tbitUncast).asVector(); + assert fbit.getPositionCount() == tbit.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctBooleanAggregator.combineIntermediate(state, groupId, fbit.getBoolean(groupPosition + positionOffset), tbit.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + 
positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -198,11 +248,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +270,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - CountDistinctBooleanAggregator.GroupingState inState = ((CountDistinctBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctBooleanAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java index 2d005a17dd182..e73d20887e29e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public 
GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +162,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + 
groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctBytesRefAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +201,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + 
CountDistinctBytesRefAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +247,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,13 +264,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - HllStates.GroupingState inState = ((CountDistinctBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java index 0f0dfd4fa5b2c..9011e9ea7de07 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctDoubleAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if 
(hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctDoubleAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - HllStates.GroupingState inState = 
((CountDistinctDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java index 8e2fa1d71419a..6296aac243bcc 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctFloatAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = 
valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctFloatAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector 
groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - HllStates.GroupingState inState = ((CountDistinctFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java index 08768acfa5261..8ff5b6636bc57 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if 
(groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctIntAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) 
{ + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +198,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctIntAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + 
values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +259,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - HllStates.GroupingState inState = ((CountDistinctIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java index 0b1caa1c3727c..e6c746887f6f9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public 
GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < 
groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + CountDistinctLongAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctLongAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void 
addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - HllStates.GroupingState inState = ((CountDistinctLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - CountDistinctLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java index c0e299d57f6bb..08e11f0ddb3d6 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) 
timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -190,13 +212,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; 
+ } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -241,13 +286,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw 
new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FirstOverTimeDoubleAggregator.GroupingState inState = ((FirstOverTimeDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - FirstOverTimeDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java index df4b6c843ff75..f17f5facc8c85 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < 
groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + 
groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -190,13 +212,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -241,13 +286,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FirstOverTimeFloatAggregator.GroupingState inState = ((FirstOverTimeFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - FirstOverTimeFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java index d0252f8b420d0..a973f01dcda3a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java @@ -57,7 +57,7 @@ public int intermediateBlockCount() { } @Override - public 
GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -116,16 +116,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,19 +148,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -189,13 +211,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -213,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -240,13 +285,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FirstOverTimeIntAggregator.GroupingState inState = ((FirstOverTimeIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - FirstOverTimeIntAggregator.combineStates(state, groupId, inState, position); + 
public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java index 8506d1e8d527b..0d88b3190f1f3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -150,19 +147,44 @@ private 
void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { 
@@ -188,13 +210,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -212,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - 
state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -239,13 +284,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FirstOverTimeLongAggregator.GroupingState inState = ((FirstOverTimeLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - FirstOverTimeLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java index 8a32e5552dd1c..ad935063a95b4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int 
positionOffset, IntArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + 
} + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -190,13 +212,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, 
groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -241,13 +286,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LastOverTimeDoubleAggregator.GroupingState inState = ((LastOverTimeDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - LastOverTimeDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java index 250c5cd755a12..249b27bf7ee70 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == 
intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -190,13 +212,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -241,13 +286,8 @@ public void addIntermediateInput(int positionOffset, 
IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LastOverTimeFloatAggregator.GroupingState inState = ((LastOverTimeFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - LastOverTimeFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java index 9b118c7dea9be..fe25154290aac 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java @@ -57,7 +57,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -116,16 +116,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { 
continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,19 +148,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || 
values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -189,13 +211,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if 
(values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -213,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -240,13 +285,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LastOverTimeIntAggregator.GroupingState inState = ((LastOverTimeIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - LastOverTimeIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java index 82bfc732969e5..3772f8bf186c1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -150,19 +147,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = 
(LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -188,13 +210,41 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + 
return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -212,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -239,13 +284,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LastOverTimeLongAggregator.GroupingState inState = 
((LastOverTimeLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - LastOverTimeLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java index f7390f55bc52b..ab6177f82e6e4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -109,16 +109,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int 
valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -142,7 +139,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -151,9 +162,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), max.getBoolean(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + 
groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -177,12 +201,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), max.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + 
positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -198,11 +252,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -227,15 +276,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - BooleanArrayState inState = ((MaxBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java index 41f98d962bd2f..588144c23162f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); BytesRef scratch = new BytesRef(); 
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +168,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MaxBytesRefAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +207,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = 
((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxBytesRefAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +282,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - 
MaxBytesRefAggregator.GroupingState inState = ((MaxBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MaxBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java index 53273dad7c0f0..cf06006a24150 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + DoubleVector max = ((DoubleBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), max.getDouble(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int 
groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + DoubleVector max = ((DoubleBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), max.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - DoubleArrayState inState = ((MaxDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java index 49afaf3c7265d..5d1ac766b590d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public 
GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + FloatVector max = ((FloatBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), max.getFloat(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + FloatVector max = ((FloatBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == 
seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), max.getFloat(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FloatArrayState inState = ((MaxFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java index 3d97bf9df5dd9..ee501aed26bc2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java @@ -57,7 +57,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -110,16 +110,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition 
+ positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -143,7 +140,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + IntVector max = ((IntBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -152,9 +163,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), max.getInt(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = 
groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -178,12 +202,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + IntVector max = ((IntBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), max.getInt(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = 
valuesStart; v < valuesEnd; v++) { @@ -199,11 +253,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,15 +277,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - IntArrayState inState = ((MaxIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java index fd38873655edd..cfc13a77b2984 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page 
page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +168,21 @@ 
private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MaxIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +207,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition 
= 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +282,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - MaxIpAggregator.GroupingState inState = ((MaxIpGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - 
MaxIpAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java index fcaea869f84d4..36e2101baaae7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ 
private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + LongVector max = ((LongBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), max.getLong(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int 
valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + LongVector max = ((LongBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), max.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, 
IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LongArrayState inState = ((MaxLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java index c380146094f44..bdc7ebfeb03f2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page 
page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +159,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition 
+ positionOffset)) { - continue; - } + MedianAbsoluteDeviationDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +196,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector 
groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,13 +257,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MedianAbsoluteDeviationDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java index a895ebc9eda6b..b789cae8704a3 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if 
(quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +159,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MedianAbsoluteDeviationFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +196,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); 
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,13 +257,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - 
MedianAbsoluteDeviationFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java index f9b9934520f06..6cc6271982921 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java @@ -57,7 +57,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -110,16 +110,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + 
values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -143,7 +140,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -152,9 +158,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MedianAbsoluteDeviationIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = 
valuesStart; v < valuesEnd; v++) { @@ -178,12 +195,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -199,11 +239,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); @@ -221,13 +256,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MedianAbsoluteDeviationIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java index e1693d7475c6f..cccdec47b3030 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +159,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MedianAbsoluteDeviationLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition 
< groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +196,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition 
+ positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,13 +257,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MedianAbsoluteDeviationLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java index 4ca346913a25b..52231c0e8975e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -109,16 +109,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -142,7 +139,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -151,9 +162,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), min.getBoolean(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -177,12 +201,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == 
seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), min.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -198,11 +252,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -227,15 +276,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - BooleanArrayState inState = ((MinBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java index dc721573876ab..e7baef1459eb8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + 
positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BytesRefVector min = ((BytesRefBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +168,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MinBytesRefAggregator.combineIntermediate(state, groupId, min.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +207,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BytesRefVector min = ((BytesRefBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MinBytesRefAggregator.combineIntermediate(state, groupId, min.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + 
positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +282,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - MinBytesRefAggregator.GroupingState inState = ((MinBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MinBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java index 3212ca644aee7..ea1ecf6c1f271 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int 
intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + DoubleVector min = ((DoubleBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); 
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), min.getDouble(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + DoubleVector min = ((DoubleBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) 
seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), min.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - DoubleArrayState inState = ((MinDoubleGroupingAggregatorFunction) 
input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java index 2e7b089e7592a..bf489b7bf6dc9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); 
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + FloatVector min = ((FloatBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), min.getFloat(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + 
for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + FloatVector min = ((FloatBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), min.getFloat(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + 
values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - FloatArrayState inState = ((MinFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java index 50c5e80a55b0c..51102c5dff22a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java @@ -57,7 +57,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public 
GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -110,16 +110,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -143,7 +140,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + IntVector min = ((IntBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -152,9 +163,22 @@ 
private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), min.getInt(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -178,12 +202,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + IntVector min = ((IntBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), min.getInt(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -199,11 +253,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,15 +277,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - IntArrayState inState = ((MinIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), 
inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java index c89c1feb6790f..542f744c04a8a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,21 @@ 
private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +168,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + MinIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +207,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MinIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + 
positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +282,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - MinIpAggregator.GroupingState inState = ((MinIpGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - MinIpAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java index dc92d712ddb6a..e5683a154285d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock 
valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + LongVector min = ((LongBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + 
groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), min.getLong(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + LongVector min = ((LongBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), min.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LongArrayState inState = ((MinLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java index 1264bff20abf6..4e88aa944f6b5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int 
positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + PercentileDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + 
public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int 
groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((PercentileDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - PercentileDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java index f844efae8d218..04f057ff87cb8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + 
groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + PercentileFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + 
groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value 
} } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((PercentileFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - PercentileFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java index e0dd21ecc80d1..402c928970893 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public 
void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + PercentileIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, 
scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +198,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if 
(values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +259,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((PercentileIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - PercentileIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java index 1baa4a662175c..8509057d6202f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + PercentileLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +199,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +260,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - QuantileStates.GroupingState inState = ((PercentileLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - PercentileLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java index 25923bf02a761..79db8bb3401a1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -120,16 +120,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -155,19 +152,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock 
groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = 
groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -193,13 +225,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < 
groupEnd; g++) { + int groupId = groups.getInt(g); + RateDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -217,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,13 +319,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - RateDoubleAggregator.DoubleRateGroupingState inState = ((RateDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - RateDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java index 7dbe1a2de02bd..892ed9c2eb25f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -122,16 +122,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -157,19 +154,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, 
Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateFloatAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = 
groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -195,13 +227,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < 
groupEnd; g++) { + int groupId = groups.getInt(g); + RateFloatAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -219,11 +289,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -256,13 +321,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - RateFloatAggregator.FloatRateGroupingState inState = ((RateFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - RateFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java index 4650ebf0c5bb2..bc8445cc9f069 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -120,16 +120,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -155,19 +152,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateIntAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = 
groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -193,13 +225,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + 
RateIntAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -217,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,13 +319,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - RateIntAggregator.IntRateGroupingState inState = ((RateIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - RateIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java index a219a58068ea0..16d8c8b0b2fa8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -120,16 +120,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -155,19 +152,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page 
page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateLongAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); 
int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -193,13 +225,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); 
+ RateLongAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -217,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,13 +319,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - RateLongAggregator.LongRateGroupingState inState = ((RateLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - RateLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java index cec8ea8b6c21a..5ddd5fc7e4a30 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) 
{ + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SampleBooleanAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +198,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == 
intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBooleanAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +259,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - 
SampleBooleanAggregator.GroupingState inState = ((SampleBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SampleBooleanAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java index 60e38edd06d1f..0ce4aa997db6e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - 
continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -157,9 +162,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SampleBytesRefAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + 
positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,13 +201,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBytesRefAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +247,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void 
addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,13 +264,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SampleBytesRefAggregator.GroupingState inState = ((SampleBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SampleBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java index cd76527394432..05e1dc8ba1783 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SampleDoubleAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +198,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleDoubleAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = 
valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +259,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SampleDoubleAggregator.GroupingState inState = ((SampleDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SampleDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java index b2cf3114fa951..9935a7bfe654d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public 
GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -145,7 +142,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -154,9 +160,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; 
g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SampleIntAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -180,12 +197,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleIntAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock 
values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -201,11 +241,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -223,13 +258,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SampleIntAggregator.GroupingState inState = ((SampleIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SampleIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java index afb1e94a23f5a..7570225ce5f38 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,16 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + 
BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SampleLongAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +198,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleLongAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +259,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SampleLongAggregator.GroupingState inState = ((SampleLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SampleLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + 
state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java index 7cf0ab3e7b148..4dd4649472948 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -145,7 +142,26 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int 
positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -154,9 +170,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < 
groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -180,12 +207,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if 
(values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -201,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,13 +288,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - StdDevStates.GroupingState inState = ((StdDevDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDevDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java index e3bbbb5d4d624..c78ea039edb63 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java @@ 
-61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,26 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = 
page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +172,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +209,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = 
page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int 
positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +290,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - StdDevStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDevFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java index b0c780b232fe7..32839ee533cbf 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java @@ -60,7 +60,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -113,16 +113,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if 
(groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,26 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -155,9 +171,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = 
groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -181,12 +208,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -202,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -234,13 +289,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - StdDevStates.GroupingState inState = ((StdDevIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDevIntAggregator.combineStates(state, 
groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java index 7e33a0c70c145..e06207363bbc6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -145,7 +142,26 @@ private void addRawInput(int 
positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -154,9 +170,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = 
groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -180,12 +207,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -201,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,13 +288,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - StdDevStates.GroupingState inState = ((StdDevLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - StdDevLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java index 303bb3d0ff5dc..88d147c3fd451 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -145,7 +142,26 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if 
(deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -154,9 +170,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -180,12 +207,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -201,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,13 +288,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SumDoubleAggregator.GroupingSumState inState = ((SumDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SumDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java index 154057db5f462..d7f0f6185d318 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int 
positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,26 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +172,20 @@ private void addRawInput(int 
positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SumFloatAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +209,45 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = 
((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SumFloatAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +290,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + 
input.getClass()); - } - SumDoubleAggregator.GroupingSumState inState = ((SumFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SumFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java index 9b5cba8cd5a89..05b29459d1e02 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -145,7 +142,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -154,9 +165,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = 
groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -180,12 +204,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + 
positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -201,11 +255,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -230,15 +279,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LongArrayState inState = ((SumIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index a2969a4dddaa8..31779335e80c0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -58,7 +58,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -111,16 +111,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -144,7 +141,21 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if 
(groups.isNull(groupPosition)) { continue; @@ -153,9 +164,22 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -179,12 +203,42 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -200,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -229,15 +278,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - LongArrayState inState = ((SumLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.set(groupId, 
SumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); - } + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java index 1fa211364cfcc..f6238670a776a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java @@ -62,7 +62,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < 
valuesEnd; v++) { @@ -148,7 +145,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BooleanBlock top = (BooleanBlock) topUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -157,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopBooleanAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -183,12 +199,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void 
addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BooleanBlock top = (BooleanBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBooleanAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -204,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +258,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - 
throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopBooleanAggregator.GroupingState inState = ((TopBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopBooleanAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java index 4ab5bb9875107..12f1456327264 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId 
= groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,7 +148,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -161,9 +166,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopBytesRefAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int 
valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -188,13 +205,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBytesRefAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -211,11 +251,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - 
state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,13 +268,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopBytesRefAggregator.GroupingState inState = ((TopBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java index 8a2f4aef9cf35..11ba0cbea0d6b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java @@ -62,7 +62,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { 
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -148,7 +145,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -157,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopDoubleAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -183,12 +199,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopDoubleAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v 
< valuesEnd; v++) { @@ -204,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +258,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopDoubleAggregator.GroupingState inState = ((TopDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java index d09bf60c82aca..32dfcaaffcde9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java @@ -62,7 +62,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = 
page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -148,7 +145,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -157,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopFloatAggregator.combineIntermediate(state, groupId, top, 
groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -183,12 +199,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopFloatAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = 
groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -204,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +258,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopFloatAggregator.GroupingState inState = ((TopFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java index 786f0660ea06f..1a0dea4b8d0eb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public 
GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -147,7 +144,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -156,9 +161,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for 
(int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopIntAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -182,12 +198,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopIntAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -203,11 +241,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -224,13 +257,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopIntAggregator.GroupingState inState = ((TopIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java index 3d1137486fb75..ad0e75b625e3d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -117,16 +117,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,7 +148,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; BytesRef scratch = new BytesRef(); for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -161,9 +166,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopIpAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -188,13 +205,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopIpAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -211,11 +251,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,13 +268,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopIpAggregator.GroupingState inState = ((TopIpGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - TopIpAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java index 820aa3c6c63e1..71e17e29be5fb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java @@ -62,7 +62,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -115,16 +115,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -148,7 +145,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void 
addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -157,9 +162,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + TopLongAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -183,12 +199,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if 
(topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopLongAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -204,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -225,13 +258,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - TopLongAggregator.GroupingState inState = ((TopLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - 
TopLongAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java index a928d0908eb8e..896d037cf68fb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java @@ -55,7 +55,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BooleanBlock valuesBlock = page.getBlock(channels.get(0)); BooleanVector valuesVector = valuesBlock.asVector(); @@ -108,16 +108,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < 
valuesEnd; v++) { @@ -141,7 +138,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BooleanBlock values = (BooleanBlock) valuesUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -150,9 +155,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesBooleanAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -176,12 +192,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + 
public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BooleanBlock values = (BooleanBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBooleanAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -197,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,13 +251,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if 
(input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesBooleanAggregator.GroupingState inState = ((ValuesBooleanGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesBooleanAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index ebbc4cd5eb8a3..da8e93f9cf61a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -56,7 +56,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -112,16 +112,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); 
for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -146,7 +143,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BytesRefBlock values = (BytesRefBlock) valuesUncast; BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -156,9 +161,21 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int 
g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -183,13 +200,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BytesRefBlock values = (BytesRefBlock) valuesUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -206,11 +246,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,13 +263,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesBytesRefAggregator.GroupingState inState = ((ValuesBytesRefGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesBytesRefAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index e61ffacb17274..3a35f48fee5f8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -55,7 +55,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { DoubleBlock valuesBlock = page.getBlock(channels.get(0)); DoubleVector valuesVector = valuesBlock.asVector(); @@ -108,16 +108,13 @@ public void close() { 
private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -141,7 +138,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -150,9 +155,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesDoubleAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, 
IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -176,12 +192,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesDoubleAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = 
values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -197,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,13 +251,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesDoubleAggregator.GroupingState inState = ((ValuesDoubleGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesDoubleAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index d7eb4bc97bacb..4917f61a23f8d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -55,7 +55,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { FloatBlock valuesBlock = page.getBlock(channels.get(0)); FloatVector valuesVector = valuesBlock.asVector(); @@ -108,16 +108,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -141,7 +138,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -150,9 +155,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = 
groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesFloatAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -176,12 +192,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesFloatAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for 
(int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -197,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,13 +251,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesFloatAggregator.GroupingState inState = ((ValuesFloatGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesFloatAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index bd34ac0d27098..d0e094099af4e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java 
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -54,7 +54,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -107,16 +107,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -140,7 +137,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { 
if (groups.isNull(groupPosition)) { continue; @@ -149,9 +154,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesIntAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -175,12 +191,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = 
groups.getInt(g); + ValuesIntAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -196,11 +234,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,13 +250,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesIntAggregator.GroupingState inState = ((ValuesIntGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesIntAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index 39f485f3b174d..287013a1dc136 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -55,7 +55,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -108,16 +108,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -141,7 +138,15 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -150,9 +155,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + ValuesLongAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -176,12 +192,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesLongAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -197,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,13 +251,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - ValuesLongAggregator.GroupingState inState = ((ValuesLongGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - ValuesLongAggregator.combineStates(state, groupId, inState, position); + public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java index a959f808e438b..5116ea389510a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java @@ -65,7 +65,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -118,16 +118,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + 
positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,7 +148,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -160,9 +186,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + 
SpatialCentroidCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -186,12 +223,55 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = 
((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page 
page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -249,13 +324,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - CentroidPointAggregator.GroupingCentroidState inState = ((SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialCentroidCartesianPointDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java index a3593b8152dd7..e0508288abbe3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -68,7 +68,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -122,16 +122,13 @@ public void 
close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -156,19 +153,59 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); 
+ assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -193,13 +230,56 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -216,11 +296,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -258,13 +333,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - CentroidPointAggregator.GroupingCentroidState inState = ((SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialCentroidCartesianPointSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java index 77a959e654862..23936066d214b 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java @@ -65,7 +65,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -118,16 +118,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,7 +148,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = 
page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -160,9 +186,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SpatialCentroidGeoPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -186,12 +223,55 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -249,13 +324,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - CentroidPointAggregator.GroupingCentroidState inState = 
((SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialCentroidGeoPointDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java index fc05c0932f50c..2707f5c78cf62 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java @@ -68,7 +68,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -122,16 +122,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < 
groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -156,19 +153,59 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = 
groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -193,13 +230,56 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + 
return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -216,11 +296,6 @@ private void addRawInput(int positionOffset, IntVector groups, 
BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -258,13 +333,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - CentroidPointAggregator.GroupingCentroidState inState = ((SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialCentroidGeoPointSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java index 76f66cf41d569..17c887a5e0035 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput 
prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -116,16 +116,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -149,7 +146,31 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { 
+ return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -158,9 +179,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SpatialExtentCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -184,12 +216,50 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + 
Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -205,11 
+275,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -242,13 +307,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingState inState = ((SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentCartesianPointDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java index 3c1159eb0de11..1c4169263e9f0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -64,7 +64,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -118,16 +118,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = 
page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -189,13 +221,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + 
positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -212,11 +282,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -249,13 +314,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingState inState = ((SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentCartesianPointSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java index 7057281c2ec6f..d9d834d96c2c6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java @@ -61,7 +61,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -114,16 +114,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -139,7 +136,31 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val // This type does not support vectors because all values are multi-valued } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = 
((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -148,9 +169,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SpatialExtentCartesianShapeDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int 
valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -166,12 +198,50 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector // This type does not support vectors because all values are multi-valued } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void 
addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -186,11 +256,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) // This type does not support vectors because all values are multi-valued } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -223,13 +288,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingState inState = ((SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentCartesianShapeDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java index 21241efbf3198..c568de2dbd6be 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java @@ -64,7 +64,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -118,16 +118,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -152,19 +149,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock 
groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); 
int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -189,13 +221,51 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeSourceValuesAggregator.combineIntermediate(state, 
groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -212,11 +282,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -249,13 +314,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingState inState = ((SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentCartesianShapeSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java index 387ed0abc34bb..e80e6d4391dc3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java @@ -65,7 +65,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { LongBlock valuesBlock = page.getBlock(channels.get(0)); LongVector valuesVector = valuesBlock.asVector(); @@ -118,16 +118,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -151,7 +148,41 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } - 
private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -160,9 +191,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - 
if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SpatialExtentGeoPointDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -186,12 +228,60 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast 
= page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition 
+ positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -207,11 +297,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,13 +339,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingStateWrappedLongitudeState inState = ((SpatialExtentGeoPointDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentGeoPointDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java index 9d9c10902ada6..43a2662a229c4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java @@ -66,7 +66,7 @@ public int intermediateBlockCount() { } @Override - 
public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -120,16 +120,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -154,19 +151,64 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) 
negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int 
groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -191,13 +233,61 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() 
&& top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +304,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -261,13 +346,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if 
(input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingStateWrappedLongitudeState inState = ((SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentGeoPointSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java index 82553910e1587..6ad4a92e83c7e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { IntBlock valuesBlock = page.getBlock(channels.get(0)); IntVector valuesVector = valuesBlock.asVector(); @@ -116,16 +116,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = 
groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -141,7 +138,41 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val // This type does not support vectors because all values are multi-valued } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert 
top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -150,9 +181,20 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } + SpatialExtentGeoShapeDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -168,12 +210,60 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector // This type does not support vectors because all values are multi-valued } 
+ @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + 
positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); int[] valuesArray = new int[valuesEnd - valuesStart]; @@ -188,11 +278,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) // This type does not support vectors because all values are multi-valued } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,13 +320,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingStateWrappedLongitudeState inState = ((SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentGeoShapeDocValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + 
state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java index ccab0870e206d..7d8f8fefc722b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java @@ -66,7 +66,7 @@ public int intermediateBlockCount() { } @Override - public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); BytesRefVector valuesVector = valuesBlock.asVector(); @@ -120,16 +120,13 @@ public void close() { private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < 
valuesEnd; v++) { @@ -154,19 +151,64 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = 
groups.getInt(g); + SpatialExtentGeoShapeSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -191,13 +233,61 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } 
+ IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); if (values.isNull(groupPosition + positionOffset)) { continue; } + int groupId = groups.getInt(groupPosition); int valuesStart 
= values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { @@ -214,11 +304,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -261,13 +346,8 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); - } - SpatialExtentGroupingStateWrappedLongitudeState inState = ((SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction) input).state; - state.enableGroupIdTracking(new SeenGroupIds.Empty()); - SpatialExtentGeoShapeSourceValuesAggregator.combineStates(state, groupId, inState, position); + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java index 218af8fcb705e..033ac62e45fd4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java @@ -50,10 +50,6 @@ public static void combine(GroupingState current, int groupId, boolean v) { 
current.collect(groupId, v); } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - current.combineStates(currentGroupId, state); - } - public static void combineIntermediate(GroupingState current, int groupId, boolean fbit, boolean tbit) { if (fbit) current.bits.set(groupId * 2); if (tbit) current.bits.set(groupId * 2 + 1); @@ -120,11 +116,6 @@ void collect(int groupId, boolean v) { trackGroupId(groupId); } - void combineStates(int currentGroupId, GroupingState state) { - bits.or(state.bits); - trackGroupId(currentGroupId); - } - /** Extracts an intermediate view of the contents of this state. */ @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBytesRefAggregator.java index 13a9e00bb28ab..84b7f4558b9e9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBytesRefAggregator.java @@ -50,15 +50,6 @@ public static void combineIntermediate(HllStates.GroupingState current, int grou current.merge(groupId, inValue, 0); } - public static void combineStates( - HllStates.GroupingState current, - int currentGroupId, - HllStates.GroupingState state, - int statePosition - ) { - current.merge(currentGroupId, state.hll, statePosition); - } - public static Block evaluateFinal(HllStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctDoubleAggregator.java index 46a0d24cec8c4..ef001408bdf63 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctDoubleAggregator.java @@ -50,15 +50,6 @@ public static void combineIntermediate(HllStates.GroupingState current, int grou current.merge(groupId, inValue, 0); } - public static void combineStates( - HllStates.GroupingState current, - int currentGroupId, - HllStates.GroupingState state, - int statePosition - ) { - current.merge(currentGroupId, state.hll, statePosition); - } - public static Block evaluateFinal(HllStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctFloatAggregator.java index 2159f0864e1cf..e744bea7bf338 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctFloatAggregator.java @@ -50,15 +50,6 @@ public static void combineIntermediate(HllStates.GroupingState current, int grou current.merge(groupId, inValue, 0); } - public static void combineStates( - HllStates.GroupingState current, - int currentGroupId, - HllStates.GroupingState state, - int statePosition - ) { - current.merge(currentGroupId, 
state.hll, statePosition); - } - public static Block evaluateFinal(HllStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregator.java index 9c29eb98f2987..e87e9e5a593ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctIntAggregator.java @@ -50,15 +50,6 @@ public static void combineIntermediate(HllStates.GroupingState current, int grou current.merge(groupId, inValue, 0); } - public static void combineStates( - HllStates.GroupingState current, - int currentGroupId, - HllStates.GroupingState state, - int statePosition - ) { - current.merge(currentGroupId, state.hll, statePosition); - } - public static Block evaluateFinal(HllStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregator.java index 59570e2f5a7ef..ccdc1c289d5fe 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctLongAggregator.java @@ -50,15 
+50,6 @@ public static void combineIntermediate(HllStates.GroupingState current, int grou current.merge(groupId, inValue, 0); } - public static void combineStates( - HllStates.GroupingState current, - int currentGroupId, - HllStates.GroupingState state, - int statePosition - ) { - current.merge(currentGroupId, state.hll, statePosition); - } - public static Block evaluateFinal(HllStates.GroupingState state, IntVector selected, DriverContext driverContext) { try (LongBlock.Builder builder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount())) { for (int i = 0; i < selected.getPositionCount(); i++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java index 611118d03872b..f5b7a73a54a1d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java @@ -59,7 +59,7 @@ public int intermediateBlockCount() { } @Override - public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { Block valuesBlock = page.getBlock(blockIndex()); if (countAll == false) { Vector valuesVector = valuesBlock.asVector(); @@ -112,10 +112,10 @@ public void close() {} private void addRawInput(int positionOffset, IntVector groups, Block values) { int position = positionOffset; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { - int groupId = Math.toIntExact(groups.getInt(groupPosition)); if (values.isNull(position)) { continue; } + int groupId = groups.getInt(groupPosition); state.increment(groupId, values.getValueCount(position)); } } @@ -123,16 
+123,13 @@ private void addRawInput(int positionOffset, IntVector groups, Block values) { private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) { int position = positionOffset; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(position)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); - if (values.isNull(position)) { - continue; - } + int groupId = groups.getInt(g); state.increment(groupId, values.getValueCount(position)); } } @@ -141,16 +138,13 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block values) { int position = positionOffset; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { - if (groups.isNull(groupPosition)) { + if (groups.isNull(groupPosition) || values.isNull(position)) { continue; } int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); - if (values.isNull(position)) { - continue; - } + int groupId = groups.getInt(g); state.increment(groupId, values.getValueCount(position)); } } @@ -161,7 +155,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block valu */ private void addRawInput(IntVector groups) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = Math.toIntExact(groups.getInt(groupPosition)); + int groupId = groups.getInt(groupPosition); state.increment(groupId, 1); } } @@ -177,7 +171,7 @@ private 
void addRawInput(IntArrayBlock groups) { int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); + int groupId = groups.getInt(g); state.increment(groupId, 1); } } @@ -194,7 +188,7 @@ private void addRawInput(IntBigArrayBlock groups) { int groupStart = groups.getFirstValueIndex(groupPosition); int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); + int groupId = groups.getInt(g); state.increment(groupId, 1); } } @@ -206,7 +200,7 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { } @Override - public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { assert channels.size() == intermediateBlockCount(); assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -214,19 +208,49 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page BooleanVector seen = page.getBlock(channels.get(1)).asVector(); assert count.getPositionCount() == seen.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - state.increment(Math.toIntExact(groups.getInt(groupPosition)), count.getLong(groupPosition + positionOffset)); + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getLong(groupPosition + positionOffset)); + } } } @Override - public void addIntermediateRowInput(int groupId, 
GroupingAggregatorFunction input, int position) { - if (input.getClass() != getClass()) { - throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + LongVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getLong(groupPosition + positionOffset)); + } } - final LongArrayState inState = ((CountGroupingAggregatorFunction) input).state; + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); state.enableGroupIdTracking(new SeenGroupIds.Empty()); - if (inState.hasValue(position)) { - state.increment(groupId, inState.get(position)); + LongVector count = page.getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + state.increment(groups.getInt(groupPosition), count.getLong(groupPosition + positionOffset)); } } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java index 8b7734fe33ab7..121d8e213dcbd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java @@ -39,13 +39,13 @@ record FilteredGroupingAggregatorFunction(GroupingAggregatorFunction next, EvalO } @Override - public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { try (BooleanBlock filterResult = ((BooleanBlock) filter.eval(page))) { ToMask mask = filterResult.toMask(); // TODO warn on mv fields AddInput nextAdd = null; try { - nextAdd = next.prepareProcessPage(seenGroupIds, page); + nextAdd = next.prepareProcessRawInputPage(seenGroupIds, page); AddInput result = new FilteredAddInput(mask.mask(), nextAdd, page.getPositionCount()); mask = null; nextAdd = null; @@ -101,13 +101,18 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { } @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { next.addIntermediateInput(positionOffset, groupIdVector, page); } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - next.addIntermediateRowInput(groupId, ((FilteredGroupingAggregatorFunction) input).next(), position); + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + next.addIntermediateInput(positionOffset, groupIdVector, page); + } + + @Override + public void 
addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + next.addIntermediateInput(positionOffset, groupIdVector, page); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java index 19012cabce5a1..d87ca338c6589 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java @@ -40,7 +40,7 @@ public FromPartialGroupingAggregatorFunction(GroupingAggregatorFunction delegate } @Override - public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { return new AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { @@ -76,17 +76,21 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { } @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { final CompositeBlock inputBlock = page.getBlock(inputChannel); delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input instanceof FromPartialGroupingAggregatorFunction toPartial) { - input = toPartial.delegate; - } - delegate.addIntermediateRowInput(groupId, input, position); + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(inputChannel); + 
delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(inputChannel); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java index e0d82b1f145b8..b0edca4ae25ba 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java @@ -11,7 +11,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -24,6 +23,10 @@ public class GroupingAggregator implements Releasable { private final AggregatorMode mode; + public AggregatorMode getMode() { + return mode; + } + public interface Factory extends Function, Describable {} public GroupingAggregator(GroupingAggregatorFunction aggregatorFunction, AggregatorMode mode) { @@ -42,19 +45,14 @@ public int evaluateBlockCount() { public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { if (mode.isInputPartial()) { return new GroupingAggregatorFunction.AddInput() { - @Override - public void add(int positionOffset, IntBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - @Override public void add(int positionOffset, IntArrayBlock groupIds) { - 
throw new IllegalStateException("Intermediate group id must not have nulls"); + aggregatorFunction.addIntermediateInput(positionOffset, groupIds, page); } @Override public void add(int positionOffset, IntBigArrayBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); + aggregatorFunction.addIntermediateInput(positionOffset, groupIds, page); } @Override @@ -66,17 +64,10 @@ public void add(int positionOffset, IntVector groupIds) { public void close() {} }; } else { - return aggregatorFunction.prepareProcessPage(seenGroupIds, page); + return aggregatorFunction.prepareProcessRawInputPage(seenGroupIds, page); } } - /** - * Add the position-th row from the intermediate output of the given aggregator to this aggregator at the groupId position - */ - public void addIntermediateRow(int groupId, GroupingAggregator input, int position) { - aggregatorFunction.addIntermediateRowInput(groupId, input.aggregatorFunction, position); - } - /** * Build the results for this aggregation. * @param selected the groupIds that have been selected to be included in diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index 556902174f213..a60bcb1523ffc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -105,13 +105,13 @@ default void add(int positionOffset, IntBlock groupIds) { } /** - * Prepare to process a single page of input. + * Prepare to process a single page of raw input. *

* This should load the input {@link Block}s and check their types and * select an optimal path and return that path as an {@link AddInput}. *

*/ - AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page); // TODO allow returning null to opt out of the callback loop + AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page); // TODO allow returning null to opt out of the callback loop /** * Call this to signal to the aggregation that the {@code selected} @@ -126,12 +126,17 @@ default void add(int positionOffset, IntBlock groupIds) { /** * Add data produced by {@link #evaluateIntermediate}. */ - void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page); + void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page); /** - * Add the position-th row from the intermediate output of the given aggregator function to the groupId + * Add data produced by {@link #evaluateIntermediate}. */ - void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position); + void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page); + + /** + * Add data produced by {@link #evaluateIntermediate}. + */ + void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page); /** * Build the intermediate results for this aggregation. 
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java index 64a970c2acc07..9ffea07121dd4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java @@ -168,10 +168,6 @@ long cardinality(int groupId) { return hll.cardinality(groupId); } - void merge(int groupId, AbstractHyperLogLogPlusPlus other, int otherGroup) { - hll.merge(groupId, other, otherGroup); - } - void merge(int groupId, BytesRef other, int otherGroup) { hll.merge(groupId, deserializeHLL(other), otherGroup); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java index 049642c350917..7d731c1e9ac0c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java @@ -62,10 +62,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState state, int groupId, GroupingState otherState, int otherGroupId) { - state.combine(groupId, otherState, otherGroupId); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(selected, driverContext); } @@ -83,12 +79,6 @@ public void add(int groupId, BytesRef value) { } } - public void combine(int groupId, GroupingState otherState, int otherGroupId) { - if (otherState.internalState.hasValue(otherGroupId)) { - add(groupId, otherState.internalState.get(otherGroupId)); - } - } - @Override public void 
toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java index 43b4a4a2fe0a1..d8c8146f09807 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java @@ -58,10 +58,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState state, int groupId, GroupingState otherState, int otherGroupId) { - state.combine(groupId, otherState, otherGroupId); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(selected, driverContext); } @@ -80,12 +76,6 @@ public void add(int groupId, BytesRef value) { } } - public void combine(int groupId, GroupingState otherState, int otherGroupId) { - if (otherState.internalState.hasValue(otherGroupId)) { - add(groupId, otherState.internalState.get(otherGroupId, otherState.scratch)); - } - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleAggregator.java index d1c21cef90f30..7a195613fdb08 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleAggregator.java +++ 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext driverContext) { return state.evaluateMedianAbsoluteDeviation(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatAggregator.java index 743ed5d4ca5e1..2b07899174ebd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatAggregator.java @@ -48,15 +48,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext driverContext) { return state.evaluateMedianAbsoluteDeviation(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java index 09f521eb7c0a0..f1482cba2949c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext driverContext) { return state.evaluateMedianAbsoluteDeviation(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregator.java index 723959bb87827..7867a6ccebd21 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext 
driverContext) { return state.evaluateMedianAbsoluteDeviation(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java index 677b38a9af3a7..142deb4058f3e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java @@ -62,10 +62,6 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState state, int groupId, GroupingState otherState, int otherGroupId) { - state.combine(groupId, otherState, otherGroupId); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(selected, driverContext); } @@ -83,12 +79,6 @@ public void add(int groupId, BytesRef value) { } } - public void combine(int groupId, GroupingState otherState, int otherGroupId) { - if (otherState.internalState.hasValue(otherGroupId)) { - add(groupId, otherState.internalState.get(otherGroupId)); - } - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java index c4ee93db89cf8..c5cbdab2c68a5 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java @@ -58,10 +58,6 @@ public static void 
combineIntermediate(GroupingState state, int groupId, BytesRe } } - public static void combineStates(GroupingState state, int groupId, GroupingState otherState, int otherGroupId) { - state.combine(groupId, otherState, otherGroupId); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(selected, driverContext); } @@ -80,12 +76,6 @@ public void add(int groupId, BytesRef value) { } } - public void combine(int groupId, GroupingState otherState, int otherGroupId) { - if (otherState.internalState.hasValue(otherGroupId)) { - add(groupId, otherState.internalState.get(otherGroupId, otherState.scratch)); - } - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileDoubleAggregator.java index dabdba38566ae..7b20903907806 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileDoubleAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext driverContext) { return state.evaluatePercentile(selected, driverContext); } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileFloatAggregator.java index 6a0c7e40285a5..07feae7afa9bd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileFloatAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, DriverContext driverContext) { return state.evaluatePercentile(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileIntAggregator.java index cf6e66ca9ca6c..f814b6d43d148 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileIntAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selected, 
DriverContext driverContext) { return state.evaluatePercentile(selected, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileLongAggregator.java index 2138033a80437..4814f3d25ac5c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PercentileLongAggregator.java @@ -47,15 +47,6 @@ public static void combineIntermediate(QuantileStates.GroupingState state, int g state.add(groupId, inValue); } - public static void combineStates( - QuantileStates.GroupingState current, - int currentGroupId, - QuantileStates.GroupingState state, - int statePosition - ) { - current.add(currentGroupId, state.getOrNull(statePosition)); - } - public static Block evaluateFinal(QuantileStates.GroupingState state, IntVector selectedGroups, DriverContext driverContext) { return state.evaluatePercentile(selectedGroups, driverContext); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java index 5e46225a873f8..d3ad2c07ac7ef 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java @@ -69,12 +69,6 @@ public static void combine(GroupingSumState current, int groupId, double v) { current.add(v, groupId); } - public static void combineStates(GroupingSumState current, int groupId, GroupingSumState state, int statePosition) { - if (state.hasValue(statePosition)) { - current.add(state.values.get(statePosition), state.deltas.get(statePosition), 
groupId); - } - } - public static void combineIntermediate(GroupingSumState current, int groupId, double inValue, double inDelta, boolean seen) { if (seen) { current.add(inValue, inDelta, groupId); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java index e0087a0ad2340..5aa489f6e2fd9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java @@ -10,6 +10,8 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.CompositeBlock; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; @@ -55,8 +57,8 @@ public ToPartialGroupingAggregatorFunction(GroupingAggregatorFunction delegate, } @Override - public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { - return delegate.prepareProcessPage(seenGroupIds, page); + public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) { + return delegate.prepareProcessRawInputPage(seenGroupIds, page); } @Override @@ -65,17 +67,21 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { } @Override - public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { final CompositeBlock inputBlock = page.getBlock(channels.get(0)); delegate.addIntermediateInput(positionOffset, groupIdVector, 
inputBlock.asPage()); } @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - if (input instanceof ToPartialGroupingAggregatorFunction toPartial) { - input = toPartial.delegate; - } - delegate.addIntermediateRowInput(groupId, input, position); + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(channels.get(0)); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(channels.get(0)); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java index e19d3107172e3..c070d02de627e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java @@ -66,19 +66,6 @@ public static void combineIntermediate(GroupingState state, int groupId, Boolean } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - long stateOffset = ((long) statePosition) << 1; - boolean seenFalse = state.values.get(stateOffset); - boolean seenTrue = state.values.get(stateOffset | 1); - - if (seenFalse) { - combine(current, currentGroupId, false); - } - if (seenTrue) { - combine(current, currentGroupId, true); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return 
state.toBlock(driverContext.blockFactory(), selected); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st index 68accfd6ad3de..639c7d4ab1831 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st @@ -29,8 +29,6 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import java.util.Arrays; - /** * A rate grouping aggregation definition for $type$. * This class is generated. Edit `X-RateAggregator.java.st` instead. @@ -65,15 +63,6 @@ public class Rate$Type$Aggregator { current.combine(groupId, timestamps, values, sampleCount, reset, otherPosition); } - public static void combineStates( - $Type$RateGroupingState current, - int currentGroupId, // make the stylecheck happy - $Type$RateGroupingState otherState, - int otherGroupId - ) { - current.combineState(currentGroupId, otherState, otherGroupId); - } - public static Block evaluateFinal($Type$RateGroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } @@ -222,25 +211,6 @@ public class Rate$Type$Aggregator { } } - void combineState(int groupId, $Type$RateGroupingState otherState, int otherGroupId) { - var other = otherGroupId < otherState.states.size() ? 
otherState.states.get(otherGroupId) : null; - if (other == null) { - return; - } - ensureCapacity(groupId); - var curr = states.get(groupId); - if (curr == null) { - var len = other.entries(); - adjustBreaker($Type$RateState.bytesUsed(len)); - curr = new $Type$RateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len)); - curr.reset = other.reset; - curr.sampleCount = other.sampleCount; - states.set(groupId, curr); - } else { - states.set(groupId, mergeState(curr, other)); - } - } - $Type$RateState mergeState($Type$RateState s1, $Type$RateState s2) { var newLen = s1.entries() + s2.entries(); adjustBreaker($Type$RateState.bytesUsed(newLen)); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st index 90539cb9ecf68..46de952d5c7c5 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st @@ -84,10 +84,6 @@ class Sample$Type$Aggregator { } } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); } @@ -150,10 +146,6 @@ class Sample$Type$Aggregator { bytesRefBuilder.clear(); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -185,10 +177,6 @@ 
class Sample$Type$Aggregator { internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st index 510d770f90d62..4338b1e5aee9b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st @@ -52,10 +52,6 @@ public class StdDev$Type$Aggregator { current.add(groupId, value); } - public static void combineStates(StdDevStates.GroupingState current, int groupId, StdDevStates.GroupingState state, int statePosition) { - current.combine(groupId, state.getOrNull(statePosition)); - } - public static void combineIntermediate(StdDevStates.GroupingState state, int groupId, double mean, double m2, long count) { state.combine(groupId, mean, m2, count); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st index 761b70791e946..839fc69c6645f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st @@ -90,10 +90,6 @@ $else$ $endif$ } - public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { - current.merge(groupId, state, statePosition); - } - public static Block evaluateFinal(GroupingState state, 
IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -115,10 +111,6 @@ $endif$ sort.collect(value, groupId); } - public void merge(int groupId, GroupingState other, int otherGroupId) { - sort.merge(groupId, other.sort, otherGroupId); - } - @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); @@ -150,10 +142,6 @@ $endif$ internalState.add(0, value); } - public void merge(GroupingState other) { - internalState.merge(0, other, 0); - } - @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValueOverTimeAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValueOverTimeAggregator.java.st index 13ebd58aa1023..31633b151fa39 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValueOverTimeAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValueOverTimeAggregator.java.st @@ -60,14 +60,6 @@ public class $Occurrence$OverTime$Type$Aggregator { } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState otherState, int otherGroupId) { - if (otherGroupId < otherState.timestamps.size() && otherState.hasValue(otherGroupId)) { - var timestamp = otherState.timestamps.get(otherGroupId); - var value = otherState.values.get(otherGroupId); - current.collectValue(currentGroupId, timestamp, value); - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext evalContext) { return state.evaluateFinal(selected, evalContext); } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index fa8ffecea052d..cc084644832ca 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -128,23 +128,6 @@ $endif$ } } - public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { - if (statePosition > state.maxGroupId) { - return; - } - var sorted = state.sortedForOrdinalMerging(current); - var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0; - var end = sorted.counts[statePosition]; - for (int i = start; i < end; i++) { - int id = sorted.ids[i]; -$if(BytesRef)$ - current.addValueOrdinal(currentGroupId, id); -$else$ - current.addValue(currentGroupId, state.getValue(id)); -$endif$ - } - } - public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(driverContext.blockFactory(), selected); } @@ -222,10 +205,6 @@ $endif$ * and then use it to iterate over the values in order. * * @param ids positions of the {@link GroupingState#values} to read. -$if(BytesRef)$ - * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)}, - * these are ordinals referring to the {@link GroupingState#bytes} in the target state. 
-$endif$ */ private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { @Override @@ -255,8 +234,6 @@ $elseif(int||float)$ private final LongHash values; $endif$ - private Sorted sortedForOrdinalMerging = null; - private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); $if(long||double)$ @@ -436,36 +413,6 @@ $endif$ } } - private Sorted sortedForOrdinalMerging(GroupingState other) { - if (sortedForOrdinalMerging == null) { - try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) { - sortedForOrdinalMerging = buildSorted(selected); -$if(BytesRef)$ - // hash all the bytes to the destination to avoid hashing them multiple times - BytesRef scratch = new BytesRef(); - final int totalValue = Math.toIntExact(bytes.size()); - blockFactory.adjustBreaker((long) totalValue * Integer.BYTES); - try { - final int[] mappedIds = new int[totalValue]; - for (int i = 0; i < totalValue; i++) { - var v = bytes.get(i, scratch); - mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v))); - } - // no longer need the bytes - bytes.close(); - bytes = null; - for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) { - sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))]; - } - } finally { - blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES); - } -$endif$ - } - } - return sortedForOrdinalMerging; - } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
@@ -559,9 +506,9 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values, bytes); $else$ - Releasables.closeExpectNoException(values, sortedForOrdinalMerging); + Releasables.closeExpectNoException(values); $endif$ } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 1cae296f09c02..5aec8a931a587 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -127,17 +127,43 @@ public abstract class BlockHash implements Releasable, SeenGroupIds { */ public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {} + public interface EmptyBucketGenerator { + int getEmptyBucketCount(); + + void generate(Block.Builder blockBuilder); + } + /** - * @param isCategorize Whether this group is a CATEGORIZE() or not. - * May be changed in the future when more stateful grouping functions are added. + * Configuration for a BlockHash group spec that is doing text categorization. 
*/ - public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) { + public record CategorizeDef(String analyzer, OutputFormat outputFormat, int similarityThreshold) { + public enum OutputFormat { + REGEX, + TOKENS + } + } + + public record GroupSpec( + int channel, + ElementType elementType, + @Nullable CategorizeDef categorizeDef, + @Nullable TopNDef topNDef, + @Nullable EmptyBucketGenerator emptyBucketGenerator + ) { public GroupSpec(int channel, ElementType elementType) { - this(channel, elementType, false, null); + this(channel, elementType, null, null, null); + } + + public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) { + this(channel, elementType, categorizeDef, null, null); + } + + public GroupSpec(int channel, ElementType elementType, EmptyBucketGenerator emptyBucketGenerator) { + this(channel, elementType, null, null, emptyBucketGenerator); } - public GroupSpec(int channel, ElementType elementType, boolean isCategorize) { - this(channel, elementType, isCategorize, null); + public boolean isCategorize() { + return categorizeDef != null; } } @@ -207,7 +240,13 @@ public static BlockHash buildCategorizeBlockHash( int emitBatchSize ) { if (groups.size() == 1) { - return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry); + return new CategorizeBlockHash( + blockFactory, + groups.get(0).channel, + aggregatorMode, + groups.get(0).categorizeDef, + analysisRegistry + ); } else { assert groups.get(0).isCategorize(); assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java index 5e716d8c9d5ff..fcc1a7f3d271e 100644 --- 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java @@ -18,7 +18,6 @@ import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -47,12 +46,13 @@ */ public class CategorizeBlockHash extends BlockHash { - private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig + private static final CategorizationAnalyzerConfig DEFAULT_ANALYZER_CONFIG = CategorizationAnalyzerConfig .buildStandardEsqlCategorizationAnalyzer(); private static final int NULL_ORD = 0; private final int channel; private final AggregatorMode aggregatorMode; + private final CategorizeDef categorizeDef; private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; private final CategorizeEvaluator evaluator; @@ -64,28 +64,38 @@ public class CategorizeBlockHash extends BlockHash { */ private boolean seenNull = false; - CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) { + CategorizeBlockHash( + BlockFactory blockFactory, + int channel, + AggregatorMode aggregatorMode, + CategorizeDef categorizeDef, + AnalysisRegistry analysisRegistry + ) { super(blockFactory); this.channel = channel; this.aggregatorMode = aggregatorMode; + this.categorizeDef = categorizeDef; this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer( new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())), CategorizationPartOfSpeechDictionary.getInstance(), - 
0.70f + categorizeDef.similarityThreshold() / 100.0f ); if (aggregatorMode.isInputPartial() == false) { - CategorizationAnalyzer analyzer; + CategorizationAnalyzer categorizationAnalyzer; try { Objects.requireNonNull(analysisRegistry); - analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG); - } catch (Exception e) { + CategorizationAnalyzerConfig config = categorizeDef.analyzer() == null + ? DEFAULT_ANALYZER_CONFIG + : new CategorizationAnalyzerConfig.Builder().setAnalyzer(categorizeDef.analyzer()).build(); + categorizationAnalyzer = new CategorizationAnalyzer(analysisRegistry, config); + } catch (IOException e) { categorizer.close(); throw new RuntimeException(e); } - this.evaluator = new CategorizeEvaluator(analyzer); + this.evaluator = new CategorizeEvaluator(categorizationAnalyzer); } else { this.evaluator = null; } @@ -114,7 +124,7 @@ public IntVector nonEmpty() { @Override public BitArray seenGroupIds(BigArrays bigArrays) { - return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); + return new Range(seenNull ? 
0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); } @Override @@ -222,7 +232,7 @@ private Block buildFinalBlock() { try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) { result.appendNull(); for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); + scratch.copyChars(getKeyString(category)); result.appendBytesRef(scratch.get()); scratch.clear(); } @@ -232,7 +242,7 @@ private Block buildFinalBlock() { try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); + scratch.copyChars(getKeyString(category)); result.appendBytesRef(scratch.get()); scratch.clear(); } @@ -240,6 +250,13 @@ private Block buildFinalBlock() { } } + private String getKeyString(SerializableTokenListCategory category) { + return switch (categorizeDef.outputFormat()) { + case REGEX -> category.getRegex(); + case TOKENS -> category.getKeyTokensString(); + }; + } + /** * Similar implementation to an Evaluator. 
*/ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java index 20874cb10ceb8..bb5f0dee8ca2d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java @@ -56,6 +56,8 @@ public class CategorizePackedValuesBlockHash extends BlockHash { int emitBatchSize ) { super(blockFactory); + assert specs.get(0).categorizeDef() != null; + this.specs = specs; this.aggregatorMode = aggregatorMode; blocks = new Block[specs.size()]; @@ -68,7 +70,13 @@ public class CategorizePackedValuesBlockHash extends BlockHash { boolean success = false; try { - categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry); + categorizeBlockHash = new CategorizeBlockHash( + blockFactory, + specs.get(0).channel(), + aggregatorMode, + specs.get(0).categorizeDef(), + analysisRegistry + ); packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize); success = true; } finally { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 6eb3aef6dfc8b..6f08a97a2a01e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -49,6 +49,8 @@ * 2, 3, 4 * 2, 3, 5 * 3, 2, 4 + * 3, 2, 5 + * 3, 3, 4 * 3, 3, 5 * } *

diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java index c3b07d069cf11..b4df19cddae97 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java @@ -44,10 +44,6 @@ public static void combine(CentroidState current, double xVal, double xDel, doub current.add(xVal, xDel, yVal, yDel, count); } - public static void combineStates(CentroidState current, CentroidState state) { - current.add(state); - } - public static void combineIntermediate(CentroidState state, double xIn, double dx, double yIn, double dy, long count) { if (count > 0) { combine(state, xIn, dx, yIn, dy, count); @@ -68,19 +64,6 @@ public static Block evaluateFinal(CentroidState state, DriverContext driverConte return state.toBlock(driverContext.blockFactory()); } - public static void combineStates(GroupingCentroidState current, int groupId, GroupingCentroidState state, int statePosition) { - if (state.hasValue(statePosition)) { - current.add( - state.xValues.get(statePosition), - state.xDeltas.get(statePosition), - state.yValues.get(statePosition), - state.yDeltas.get(statePosition), - state.counts.get(statePosition), - groupId - ); - } - } - public static void combineIntermediate( GroupingCentroidState current, int groupId, @@ -170,12 +153,6 @@ public void count(long count) { this.count = count; } - public void add(CentroidState other) { - xSum.add(other.xSum.value(), other.xSum.delta()); - ySum.add(other.ySum.value(), other.ySum.delta()); - count += other.count; - } - public void add(double x, double y) { xSum.add(x); ySum.add(y); diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentAggregator.java index 91e0f098d795e..ac32abac23898 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentAggregator.java @@ -29,8 +29,4 @@ public static Block evaluateFinal(SpatialExtentState state, DriverContext driver public static Block evaluateFinal(SpatialExtentGroupingState state, IntVector selected, DriverContext driverContext) { return state.toBlock(selected, driverContext); } - - public static void combineStates(SpatialExtentGroupingState current, int groupId, SpatialExtentGroupingState inState, int inPosition) { - current.add(groupId, inState, inPosition); - } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingState.java index 9fb548dceaad9..b2411d59ac298 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingState.java @@ -176,17 +176,4 @@ public Block toBlock(IntVector selected, DriverContext driverContext) { return builder.build(); } } - - public void add(int groupId, SpatialExtentGroupingState inState, int inPosition) { - ensureCapacity(groupId); - if (inState.hasValue(inPosition)) { - add( - groupId, - inState.minXs.get(inPosition), - inState.maxXs.get(inPosition), - inState.maxYs.get(inPosition), - inState.minYs.get(inPosition) - ); - } - } } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingStateWrappedLongitudeState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingStateWrappedLongitudeState.java index 9f8fca5236d14..5eadbc83435b0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingStateWrappedLongitudeState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGroupingStateWrappedLongitudeState.java @@ -108,21 +108,6 @@ public void add(int groupId, Geometry geo) { } } - public void add(int groupId, SpatialExtentGroupingStateWrappedLongitudeState inState, int inPosition) { - ensureCapacity(groupId); - if (inState.hasValue(inPosition)) { - add( - groupId, - inState.tops.get(inPosition), - inState.bottoms.get(inPosition), - inState.negLefts.get(inPosition), - inState.negRights.get(inPosition), - inState.posLefts.get(inPosition), - inState.posRights.get(inPosition) - ); - } - } - /** * This method is used when the field is a geo_point or cartesian_point and is loaded from doc-values. * This optimization is enabled when the field has doc-values and is only used in a spatial aggregation. 
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentLongitudeWrappingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentLongitudeWrappingAggregator.java index 2d89ba78d1025..c15761d451f5b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentLongitudeWrappingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/SpatialExtentLongitudeWrappingAggregator.java @@ -50,13 +50,4 @@ public static Block evaluateFinal( ) { return state.toBlock(selected, driverContext); } - - public static void combineStates( - SpatialExtentGroupingStateWrappedLongitudeState current, - int groupId, - SpatialExtentGroupingStateWrappedLongitudeState inState, - int inPosition - ) { - current.add(groupId, inState, inPosition); - } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java index 20ca4ed70e3f8..ccd0f82343401 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java @@ -138,7 +138,6 @@ private boolean checkIfSingleSegmentNonDecreasing() { prev = v; } return true; - } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java index 626f0b00f0e2c..cded3a3494738 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java @@ -58,7 +58,7 @@ public Factory( taskConcurrency, limit, 
false, - ScoreMode.COMPLETE_NO_SCORES + shardContext -> ScoreMode.COMPLETE_NO_SCORES ); this.shardRefCounters = contexts; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java index 82d766349ce9e..7e0003efaf669 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java @@ -129,7 +129,7 @@ public LuceneMaxFactory( taskConcurrency, limit, false, - ScoreMode.COMPLETE_NO_SCORES + shardContext -> ScoreMode.COMPLETE_NO_SCORES ); this.contexts = contexts; this.fieldName = fieldName; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java index 505e5cd3f0d75..000ade1b19562 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java @@ -130,7 +130,7 @@ public LuceneMinFactory( taskConcurrency, limit, false, - ScoreMode.COMPLETE_NO_SCORES + shardContext -> ScoreMode.COMPLETE_NO_SCORES ); this.shardRefCounters = contexts; this.fieldName = fieldName; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java index 366715530f665..f3eec4147f237 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java @@ -112,11 +112,18 @@ protected Factory( int taskConcurrency, int limit, boolean needsScore, - 
ScoreMode scoreMode + Function scoreModeFunction ) { this.limit = limit; this.dataPartitioning = dataPartitioning; - this.sliceQueue = LuceneSliceQueue.create(contexts, queryFunction, dataPartitioning, autoStrategy, taskConcurrency, scoreMode); + this.sliceQueue = LuceneSliceQueue.create( + contexts, + queryFunction, + dataPartitioning, + autoStrategy, + taskConcurrency, + scoreModeFunction + ); this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency); this.needsScore = needsScore; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java index ee9f217303195..1a0b349b45f3f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java @@ -112,12 +112,14 @@ public static LuceneSliceQueue create( DataPartitioning dataPartitioning, Function autoStrategy, int taskConcurrency, - ScoreMode scoreMode + Function scoreModeFunction ) { List slices = new ArrayList<>(); Map partitioningStrategies = new HashMap<>(contexts.size()); + for (ShardContext ctx : contexts) { for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) { + var scoreMode = scoreModeFunction.apply(ctx); Query query = queryAndExtra.query; query = scoreMode.needsScores() ? 
query : new ConstantScoreQuery(query); /* diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java index 9fedc595641b4..5201eede502df 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java @@ -81,7 +81,7 @@ public Factory( taskConcurrency, limit, needsScore, - needsScore ? COMPLETE : COMPLETE_NO_SCORES + shardContext -> needsScore ? COMPLETE : COMPLETE_NO_SCORES ); this.contexts = contexts; this.maxPageSize = maxPageSize; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java index d93a5493a3aba..553b4319f22e9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java @@ -12,14 +12,13 @@ import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.LeafCollector; -import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopFieldCollectorManager; import org.apache.lucene.search.TopScoreDocCollectorManager; -import org.apache.lucene.search.Weight; import org.elasticsearch.common.Strings; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocBlock; @@ -44,9 +43,6 @@ import 
java.util.function.Function; import java.util.stream.Collectors; -import static org.apache.lucene.search.ScoreMode.TOP_DOCS; -import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES; - /** * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN) */ @@ -75,7 +71,7 @@ public Factory( taskConcurrency, limit, needsScore, - needsScore ? TOP_DOCS_WITH_SCORES : TOP_DOCS + scoreModeFunction(sorts, needsScore) ); this.contexts = contexts; this.maxPageSize = maxPageSize; @@ -331,18 +327,11 @@ static final class ScoringPerShardCollector extends PerShardCollector { } } - private static Function weightFunction( - Function queryFunction, - List> sorts, - boolean needsScore - ) { + private static Function scoreModeFunction(List> sorts, boolean needsScore) { return ctx -> { - final var query = queryFunction.apply(ctx); - final var searcher = ctx.searcher(); try { // we create a collector with a limit of 1 to determine the appropriate score mode to use. - var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode(); - return searcher.createWeight(searcher.rewrite(query), scoreMode, 1); + return newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode(); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java index 089846f9939ae..ba6da814542e4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java @@ -183,16 +183,20 @@ void readDocsForNextPage() throws IOException { for (LeafIterator leaf : oneTsidQueue) { leaf.reinitializeIfNeeded(executingThread); } - do { - PriorityQueue sub = subQueueForNextTsid(); - if 
(sub.size() == 0) { - break; - } - tsHashesBuilder.appendNewTsid(sub.top().timeSeriesHash); - if (readValuesForOneTsid(sub)) { - break; - } - } while (mainQueue.size() > 0); + if (mainQueue.size() + oneTsidQueue.size() == 1) { + readValuesFromSingleRemainingLeaf(); + } else { + do { + PriorityQueue sub = subQueueForNextTsid(); + if (sub.size() == 0) { + break; + } + tsHashesBuilder.appendNewTsid(sub.top().timeSeriesHash); + if (readValuesForOneTsid(sub)) { + break; + } + } while (mainQueue.size() > 0); + } } private boolean readValuesForOneTsid(PriorityQueue sub) throws IOException { @@ -236,6 +240,38 @@ private PriorityQueue subQueueForNextTsid() { return oneTsidQueue; } + private void readValuesFromSingleRemainingLeaf() throws IOException { + if (oneTsidQueue.size() == 0) { + oneTsidQueue.add(getMainQueue().pop()); + tsidsLoaded++; + } + final LeafIterator sub = oneTsidQueue.top(); + int lastTsid = -1; + do { + currentPagePos++; + remainingDocs--; + docCollector.collect(sub.segmentOrd, sub.docID); + if (lastTsid != sub.lastTsidOrd) { + tsHashesBuilder.appendNewTsid(sub.timeSeriesHash); + lastTsid = sub.lastTsidOrd; + } + tsHashesBuilder.appendOrdinal(); + timestampsBuilder.appendLong(sub.timestamp); + if (sub.nextDoc() == false) { + if (sub.docID == DocIdSetIterator.NO_MORE_DOCS) { + oneTsidQueue.clear(); + return; + } else { + ++tsidsLoaded; + } + } + } while (remainingDocs > 0 && currentPagePos < maxPageSize); + } + + private PriorityQueue getMainQueue() { + return mainQueue; + } + boolean completed() { return mainQueue.size() == 0 && oneTsidQueue.size() == 0; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java index 97286761b7bcf..bb1d889db3f85 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java +++ 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperatorFactory.java @@ -45,7 +45,7 @@ private TimeSeriesSourceOperatorFactory( taskConcurrency, limit, false, - ScoreMode.COMPLETE_NO_SCORES + shardContext -> ScoreMode.COMPLETE_NO_SCORES ); this.contexts = contexts; this.maxPageSize = maxPageSize; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java index f7f5f541c747f..20e7ffc4ca2cb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java @@ -14,18 +14,16 @@ import org.elasticsearch.core.Releasable; class ComputeBlockLoaderFactory extends DelegatingBlockLoaderFactory implements Releasable { - private final int pageSize; private Block nullBlock; - ComputeBlockLoaderFactory(BlockFactory factory, int pageSize) { + ComputeBlockLoaderFactory(BlockFactory factory) { super(factory); - this.pageSize = pageSize; } @Override - public Block constantNulls() { + public Block constantNulls(int count) { if (nullBlock == null) { - nullBlock = factory.newConstantNullBlock(pageSize); + nullBlock = factory.newConstantNullBlock(count); } nullBlock.incRef(); return nullBlock; @@ -39,7 +37,7 @@ public void close() { } @Override - public BytesRefBlock constantBytes(BytesRef value) { - return factory.newConstantBytesRefBlockWith(value, pageSize); + public BytesRefBlock constantBytes(BytesRef value, int count) { + return factory.newConstantBytesRefBlockWith(value, count); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/DelegatingBlockLoaderFactory.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/DelegatingBlockLoaderFactory.java index 8dc5b6cc43ecf..c5e3628b268d4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/DelegatingBlockLoaderFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/DelegatingBlockLoaderFactory.java @@ -8,10 +8,10 @@ package org.elasticsearch.compute.lucene.read; import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedSetDocValues; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.SingletonOrdinalsBuilder; import org.elasticsearch.index.mapper.BlockLoader; public abstract class DelegatingBlockLoaderFactory implements BlockLoader.BlockFactory { @@ -86,6 +86,11 @@ public BlockLoader.SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocVa return new SingletonOrdinalsBuilder(factory, ordinals, count); } + @Override + public BlockLoader.SortedSetOrdinalsBuilder sortedSetOrdinalsBuilder(SortedSetDocValues ordinals, int count) { + return new SortedSetOrdinalsBuilder(factory, ordinals, count); + } + @Override public BlockLoader.AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int count) { return factory.newAggregateMetricDoubleBlockBuilder(count); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilder.java similarity index 95% rename from x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilder.java rename to x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilder.java index cfcc75c7c396a..ef3c39e30a7fe 100644 --- 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilder.java @@ -5,11 +5,17 @@ * 2.0. */ -package org.elasticsearch.compute.data; +package org.elasticsearch.compute.lucene.read; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilder.java new file mode 100644 index 0000000000000..7e7fb8a8abb82 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilder.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.lucene.read; + +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.mapper.BlockLoader; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; + +public final class SortedSetOrdinalsBuilder implements BlockLoader.SortedSetOrdinalsBuilder, Releasable, Block.Builder { + private final BlockFactory blockFactory; + private final SortedSetDocValues docValues; + private int minOrd = Integer.MAX_VALUE; + private int maxOrd = Integer.MIN_VALUE; + private int totalValueCount; + private final IntBlock.Builder ordsBuilder; + + public SortedSetOrdinalsBuilder(BlockFactory blockFactory, SortedSetDocValues docValues, int count) { + this.blockFactory = blockFactory; + this.docValues = docValues; + this.ordsBuilder = blockFactory.newIntBlockBuilder(count).mvOrdering(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING); + } + + @Override + public SortedSetOrdinalsBuilder appendNull() { + ordsBuilder.appendNull(); + return this; + } + + @Override + public SortedSetOrdinalsBuilder appendOrd(int ord) { + minOrd = Math.min(minOrd, ord); + maxOrd = Math.max(maxOrd, ord); + ordsBuilder.appendInt(ord); + totalValueCount++; + return this; + } + + @Override + public SortedSetOrdinalsBuilder beginPositionEntry() { + ordsBuilder.beginPositionEntry(); + return this; + } + + @Override + public SortedSetOrdinalsBuilder endPositionEntry() { + ordsBuilder.endPositionEntry(); + return this; + } + + private BytesRefBlock buildBlock(IntBlock 
ordinals) { + final int numOrds = maxOrd - minOrd + 1; + final long breakerSize = arraySize(numOrds); + blockFactory.adjustBreaker(breakerSize); + BytesRefVector dict = null; + IntBlock mappedOrds = null; + try { + final int[] newOrds = new int[numOrds]; + Arrays.fill(newOrds, -1); + for (int p = 0; p < ordinals.getPositionCount(); p++) { + int count = ordinals.getValueCount(p); + if (count > 0) { + int first = ordinals.getFirstValueIndex(p); + for (int i = 0; i < count; i++) { + int oldOrd = ordinals.getInt(first + i); + newOrds[oldOrd - minOrd] = 0; + } + } + } + int nextOrd = -1; + try (BytesRefVector.Builder dictBuilder = blockFactory.newBytesRefVectorBuilder(Math.min(newOrds.length, totalValueCount))) { + for (int i = 0; i < newOrds.length; i++) { + if (newOrds[i] != -1) { + newOrds[i] = ++nextOrd; + dictBuilder.appendBytesRef(docValues.lookupOrd(i + minOrd)); + } + } + dict = dictBuilder.build(); + } catch (IOException e) { + throw new UncheckedIOException("error resolving ordinals", e); + } + mappedOrds = remapOrdinals(ordinals, newOrds, minOrd); + final OrdinalBytesRefBlock result = new OrdinalBytesRefBlock(mappedOrds, dict); + dict = null; + mappedOrds = null; + return result; + } finally { + Releasables.close(() -> blockFactory.adjustBreaker(-breakerSize), mappedOrds, dict); + } + } + + private IntBlock remapOrdinals(IntBlock ordinals, int[] newOrds, int shiftOrd) { + try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(totalValueCount)) { + for (int p = 0; p < ordinals.getPositionCount(); p++) { + int valueCount = ordinals.getValueCount(p); + switch (valueCount) { + case 0 -> builder.appendNull(); + case 1 -> { + int ord = ordinals.getInt(ordinals.getFirstValueIndex(p)); + builder.appendInt(newOrds[ord - shiftOrd]); + } + default -> { + int first = ordinals.getFirstValueIndex(p); + builder.beginPositionEntry(); + int last = first + valueCount; + for (int i = first; i < last; i++) { + int ord = ordinals.getInt(i); + 
builder.appendInt(newOrds[ord - shiftOrd]); + } + builder.endPositionEntry(); + } + } + } + builder.mvOrdering(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING); + return builder.build(); + } + } + + @Override + public long estimatedBytes() { + /* + * This is a *terrible* estimate because we have no idea how big the + * values in the ordinals are. + */ + final int numOrds = minOrd <= maxOrd ? maxOrd - minOrd + 1 : 0; + return totalValueCount * 4L + Math.min(numOrds, totalValueCount) * 20L; + } + + @Override + public BytesRefBlock build() { + try (IntBlock ordinals = ordsBuilder.build()) { + if (ordinals.areAllValuesNull()) { + return (BytesRefBlock) blockFactory.newConstantNullBlock(ordinals.getPositionCount()); + } + return buildBlock(ordinals); + } + } + + @Override + public void close() { + ordsBuilder.close(); + } + + @Override + public Block.Builder copyFrom(Block block, int beginInclusive, int endExclusive) { + throw new UnsupportedOperationException(); + } + + @Override + public Block.Builder mvOrdering(Block.MvOrdering mvOrdering) { + throw new UnsupportedOperationException(); + } + + private static long arraySize(int ordsCount) { + return RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) ordsCount * Integer.BYTES; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java index 9ec5802b43f98..e197861e9b701 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java @@ -198,12 +198,12 @@ static class BlockLoaderFactory extends DelegatingBlockLoaderFactory { } @Override - public BlockLoader.Block constantNulls() { + public BlockLoader.Block constantNulls(int count) { throw new 
UnsupportedOperationException("must not be used by column readers"); } @Override - public BlockLoader.Block constantBytes(BytesRef value) { + public BlockLoader.Block constantBytes(BytesRef value, int count) { throw new UnsupportedOperationException("must not be used by column readers"); } @@ -254,7 +254,8 @@ static final class ShardLevelFieldsReader implements Releasable { this.storedFieldsSpec = storedFieldsSpec; this.dimensions = new boolean[fields.size()]; for (int i = 0; i < fields.size(); i++) { - dimensions[i] = shardContext.fieldType(fields.get(i).name()).isDimension(); + final var mappedFieldType = shardContext.fieldType(fields.get(i).name()); + dimensions[i] = mappedFieldType != null && mappedFieldType.isDimension(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java index 7ff6e7211b7f2..6f00e97a1f9f2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java @@ -16,6 +16,8 @@ import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.BlockLoaderStoredFieldsFromLeafLoader; import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; @@ -24,6 +26,8 @@ * Loads values from a many leaves. Much less efficient than {@link ValuesFromSingleReader}. 
*/ class ValuesFromManyReader extends ValuesReader { + private static final Logger log = LogManager.getLogger(ValuesFromManyReader.class); + private final int[] forwards; private final int[] backwards; private final BlockLoader.RowStrideReader[] rowStride; @@ -35,6 +39,7 @@ class ValuesFromManyReader extends ValuesReader { forwards = docs.shardSegmentDocMapForwards(); backwards = docs.shardSegmentDocMapBackwards(); rowStride = new BlockLoader.RowStrideReader[operator.fields.length]; + log.debug("initializing {} positions", docs.getPositionCount()); } @Override @@ -70,9 +75,7 @@ void run(int offset) throws IOException { builders[f] = new Block.Builder[operator.shardContexts.size()]; converters[f] = new BlockLoader[operator.shardContexts.size()]; } - try ( - ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory, docs.getPositionCount()) - ) { + try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory)) { int p = forwards[offset]; int shard = docs.shards().getInt(p); int segment = docs.segments().getInt(p); @@ -84,7 +87,9 @@ void run(int offset) throws IOException { read(firstDoc, shard); int i = offset + 1; - while (i < forwards.length) { + long estimated = estimatedRamBytesUsed(); + long dangerZoneBytes = Long.MAX_VALUE; // TODO danger_zone if ascending + while (i < forwards.length && estimated < dangerZoneBytes) { p = forwards[i]; shard = docs.shards().getInt(p); segment = docs.segments().getInt(p); @@ -96,8 +101,17 @@ void run(int offset) throws IOException { verifyBuilders(loaderBlockFactory, shard); read(docs.docs().getInt(p), shard); i++; + estimated = estimatedRamBytesUsed(); + log.trace("{}: bytes loaded {}/{}", p, estimated, dangerZoneBytes); } buildBlocks(); + if (log.isDebugEnabled()) { + long actual = 0; + for (Block b : target) { + actual += b.ramBytesUsed(); + } + log.debug("loaded {} positions total estimated/actual {}/{} bytes", p, estimated, actual); + } } } 
@@ -115,6 +129,9 @@ private void buildBlocks() { } operator.sanityCheckBlock(rowStride[f], backwards.length, target[f], f); } + if (target[0].getPositionCount() != docs.getPositionCount()) { + throw new IllegalStateException("partial pages not yet supported"); + } } private void verifyBuilders(ComputeBlockLoaderFactory loaderBlockFactory, int shard) { @@ -141,6 +158,18 @@ public void close() { Releasables.closeExpectNoException(builders[f]); } } + + private long estimatedRamBytesUsed() { + long estimated = 0; + for (Block.Builder[] builders : this.builders) { + for (Block.Builder builder : builders) { + if (builder != null) { + estimated += builder.estimatedBytes(); + } + } + } + return estimated; + } } private void fieldsMoved(LeafReaderContext ctx, int shard) throws IOException { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java index 1bee68160e024..d47a015c24578 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java @@ -16,6 +16,8 @@ import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.BlockLoaderStoredFieldsFromLeafLoader; import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; @@ -26,6 +28,8 @@ * Loads values from a single leaf. Much more efficient than {@link ValuesFromManyReader}. 
*/ class ValuesFromSingleReader extends ValuesReader { + private static final Logger log = LogManager.getLogger(ValuesFromSingleReader.class); + /** * Minimum number of documents for which it is more efficient to use a * sequential stored field reader when reading stored fields. @@ -45,39 +49,27 @@ class ValuesFromSingleReader extends ValuesReader { super(operator, docs); this.shard = docs.shards().getInt(0); this.segment = docs.segments().getInt(0); + log.debug("initialized {} positions", docs.getPositionCount()); } @Override protected void load(Block[] target, int offset) throws IOException { - assert offset == 0; // TODO allow non-0 offset to support splitting pages if (docs.singleSegmentNonDecreasing()) { - loadFromSingleLeaf(target, new BlockLoader.Docs() { - @Override - public int count() { - return docs.getPositionCount(); - } - - @Override - public int get(int i) { - return docs.docs().getInt(i); - } - }); + loadFromSingleLeaf(operator.jumboBytes, target, new ValuesReaderDocs(docs), offset); return; } + if (offset != 0) { + throw new IllegalStateException("can only load partial pages with single-segment non-decreasing pages"); + } int[] forwards = docs.shardSegmentDocMapForwards(); Block[] unshuffled = new Block[target.length]; try { - loadFromSingleLeaf(unshuffled, new BlockLoader.Docs() { - @Override - public int count() { - return docs.getPositionCount(); - } - - @Override - public int get(int i) { - return docs.docs().getInt(forwards[i]); - } - }); + loadFromSingleLeaf( + Long.MAX_VALUE, // Effectively disable splitting pages when we're not loading in order + unshuffled, + new ValuesReaderDocs(docs).mapped(forwards), + 0 + ); final int[] backwards = docs.shardSegmentDocMapBackwards(); for (int i = 0; i < unshuffled.length; i++) { target[i] = unshuffled[i].filter(backwards); @@ -89,24 +81,25 @@ public int get(int i) { } } - private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IOException { - int firstDoc = docs.get(0); + private 
void loadFromSingleLeaf(long jumboBytes, Block[] target, ValuesReaderDocs docs, int offset) throws IOException { + int firstDoc = docs.get(offset); operator.positionFieldWork(shard, segment, firstDoc); StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS; - List rowStrideReaders = new ArrayList<>(operator.fields.length); LeafReaderContext ctx = operator.ctx(shard, segment); - try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory, docs.count())) { + + List columnAtATimeReaders = new ArrayList<>(operator.fields.length); + List rowStrideReaders = new ArrayList<>(operator.fields.length); + try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory)) { for (int f = 0; f < operator.fields.length; f++) { ValuesSourceReaderOperator.FieldWork field = operator.fields[f]; BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime(ctx); if (columnAtATime != null) { - target[f] = (Block) columnAtATime.read(loaderBlockFactory, docs); - operator.sanityCheckBlock(columnAtATime, docs.count(), target[f], f); + columnAtATimeReaders.add(new ColumnAtATimeWork(columnAtATime, f)); } else { rowStrideReaders.add( new RowStrideReaderWork( field.rowStride(ctx), - (Block.Builder) field.loader.builder(loaderBlockFactory, docs.count()), + (Block.Builder) field.loader.builder(loaderBlockFactory, docs.count() - offset), field.loader, f ) @@ -116,7 +109,18 @@ private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IO } if (rowStrideReaders.isEmpty() == false) { - loadFromRowStrideReaders(target, storedFieldsSpec, rowStrideReaders, ctx, docs); + loadFromRowStrideReaders(jumboBytes, target, storedFieldsSpec, rowStrideReaders, ctx, docs, offset); + } + for (ColumnAtATimeWork r : columnAtATimeReaders) { + target[r.idx] = (Block) r.reader.read(loaderBlockFactory, docs, offset); + operator.sanityCheckBlock(r.reader, docs.count() - offset, target[r.idx], r.idx); + 
} + if (log.isDebugEnabled()) { + long total = 0; + for (Block b : target) { + total += b.ramBytesUsed(); + } + log.debug("loaded {} positions total ({} bytes)", target[0].getPositionCount(), total); } } finally { Releasables.close(rowStrideReaders); @@ -124,11 +128,13 @@ private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IO } private void loadFromRowStrideReaders( + long jumboBytes, Block[] target, StoredFieldsSpec storedFieldsSpec, List rowStrideReaders, LeafReaderContext ctx, - BlockLoader.Docs docs + ValuesReaderDocs docs, + int offset ) throws IOException { SourceLoader sourceLoader = null; ValuesSourceReaderOperator.ShardContext shardContext = operator.shardContexts.get(shard); @@ -153,18 +159,29 @@ private void loadFromRowStrideReaders( storedFieldLoader.getLoader(ctx, null), sourceLoader != null ? sourceLoader.leaf(ctx.reader(), null) : null ); - int p = 0; - while (p < docs.count()) { + int p = offset; + long estimated = 0; + while (p < docs.count() && estimated < jumboBytes) { int doc = docs.get(p++); storedFields.advanceTo(doc); for (RowStrideReaderWork work : rowStrideReaders) { work.read(doc, storedFields); } + estimated = estimatedRamBytesUsed(rowStrideReaders); + log.trace("{}: bytes loaded {}/{}", p, estimated, jumboBytes); } for (RowStrideReaderWork work : rowStrideReaders) { - target[work.offset] = work.build(); - operator.sanityCheckBlock(work.reader, p, target[work.offset], work.offset); + target[work.idx] = work.build(); + operator.sanityCheckBlock(work.reader, p - offset, target[work.idx], work.idx); } + if (log.isDebugEnabled()) { + long actual = 0; + for (RowStrideReaderWork work : rowStrideReaders) { + actual += target[work.idx].ramBytesUsed(); + } + log.debug("loaded {} positions row stride estimated/actual {}/{} bytes", p - offset, estimated, actual); + } + docs.setCount(p); } /** @@ -180,7 +197,21 @@ private boolean useSequentialStoredFieldsReader(BlockLoader.Docs docs, double st return range * 
storedFieldsSequentialProportion <= count; } - private record RowStrideReaderWork(BlockLoader.RowStrideReader reader, Block.Builder builder, BlockLoader loader, int offset) + /** + * Work for building a column-at-a-time. + * @param reader reads the values + * @param idx destination in array of {@linkplain Block}s we build + */ + private record ColumnAtATimeWork(BlockLoader.ColumnAtATimeReader reader, int idx) {} + + /** + * Work for + * @param reader + * @param builder + * @param loader + * @param idx + */ + private record RowStrideReaderWork(BlockLoader.RowStrideReader reader, Block.Builder builder, BlockLoader loader, int idx) implements Releasable { void read(int doc, BlockLoaderStoredFieldsFromLeafLoader storedFields) throws IOException { @@ -196,4 +227,12 @@ public void close() { builder.close(); } } + + private long estimatedRamBytesUsed(List rowStrideReaders) { + long estimated = 0; + for (RowStrideReaderWork r : rowStrideReaders) { + estimated += r.builder.estimatedBytes(); + } + return estimated; + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java index ebfac0cb24f7f..d3b8b0edcec3d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java @@ -36,9 +36,6 @@ public Block[] next() { boolean success = false; try { load(target, offset); - if (target[0].getPositionCount() != docs.getPositionCount()) { - throw new IllegalStateException("partial pages not yet supported"); - } success = true; for (Block b : target) { operator.valuesLoaded += b.getTotalValueCount(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java new file mode 100644 index 0000000000000..2e138dc2d0446 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.lucene.read; + +import org.elasticsearch.compute.data.DocVector; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.BlockLoader; + +/** + * Implementation of {@link BlockLoader.Docs} for ESQL. It's important that + * only this implementation, and the implementation returned by {@link #mapped} + * exist. This allows the jvm to inline the {@code invokevirtual}s to call + * the interface in hot, hot code. + *

+ * We've investigated moving the {@code offset} parameter from the + * {@link BlockLoader.ColumnAtATimeReader#read} into this. That's more + * readable, but a clock cycle slower. + *

+ *

+ * When we tried having a {@link Nullable} map member instead of a subclass + * that was also slower. + *

+ */ +class ValuesReaderDocs implements BlockLoader.Docs { + private final DocVector docs; + private int count; + + ValuesReaderDocs(DocVector docs) { + this.docs = docs; + this.count = docs.getPositionCount(); + } + + final Mapped mapped(int[] forwards) { + return new Mapped(docs, forwards); + } + + public final void setCount(int count) { + this.count = count; + } + + @Override + public final int count() { + return count; + } + + @Override + public int get(int i) { + return docs.docs().getInt(i); + } + + private class Mapped extends ValuesReaderDocs { + private final int[] forwards; + + private Mapped(DocVector docs, int[] forwards) { + super(docs); + this.forwards = forwards; + } + + @Override + public int get(int i) { + return super.get(forwards[i]); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java index 2fd4784224087..6d0ebb9c312d0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java @@ -9,6 +9,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocBlock; @@ -42,7 +43,9 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingToIteratorOpe * @param shardContexts per-shard loading information * @param docChannel the channel containing the shard, leaf/segment and doc id */ - public record Factory(List fields, List shardContexts, int docChannel) implements OperatorFactory { + public record Factory(ByteSizeValue jumboSize, List fields, List 
shardContexts, int docChannel) + implements + OperatorFactory { public Factory { if (fields.isEmpty()) { throw new IllegalStateException("ValuesSourceReaderOperator doesn't support empty fields"); @@ -51,7 +54,7 @@ public record Factory(List fields, List shardContexts, @Override public Operator get(DriverContext driverContext) { - return new ValuesSourceReaderOperator(driverContext.blockFactory(), fields, shardContexts, docChannel); + return new ValuesSourceReaderOperator(driverContext.blockFactory(), jumboSize.getBytes(), fields, shardContexts, docChannel); } @Override @@ -85,10 +88,21 @@ public record FieldInfo(String name, ElementType type, IntFunction public record ShardContext(IndexReader reader, Supplier newSourceLoader, double storedFieldsSequentialProportion) {} + final BlockFactory blockFactory; + /** + * When the loaded fields {@link Block}s' estimated size grows larger than this, + * we finish loading the {@linkplain Page} and return it, even if + * the {@linkplain Page} is shorter than the incoming {@linkplain Page}. + *

+ * NOTE: This only applies when loading single segment non-descending + * row stride bytes. This is the most common way to get giant fields, + * but it isn't all the ways. + *

+ */ + final long jumboBytes; final FieldWork[] fields; final List shardContexts; private final int docChannel; - final BlockFactory blockFactory; private final Map readersBuilt = new TreeMap<>(); long valuesLoaded; @@ -101,14 +115,21 @@ public record ShardContext(IndexReader reader, Supplier newSourceL * @param fields fields to load * @param docChannel the channel containing the shard, leaf/segment and doc id */ - public ValuesSourceReaderOperator(BlockFactory blockFactory, List fields, List shardContexts, int docChannel) { + public ValuesSourceReaderOperator( + BlockFactory blockFactory, + long jumboBytes, + List fields, + List shardContexts, + int docChannel + ) { if (fields.isEmpty()) { throw new IllegalStateException("ValuesSourceReaderOperator doesn't support empty fields"); } + this.blockFactory = blockFactory; + this.jumboBytes = jumboBytes; this.fields = fields.stream().map(FieldWork::new).toArray(FieldWork[]::new); this.shardContexts = shardContexts; this.docChannel = docChannel; - this.blockFactory = blockFactory; } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnLoadOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnLoadOperator.java index 05f60c1b6834d..9eb00adb58146 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnLoadOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ColumnLoadOperator.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.core.Releasables; /** * {@link Block#lookup Looks up} values from a provided {@link Block} and @@ -44,8 +45,19 @@ public String describe() { private final int positionsOrd; public ColumnLoadOperator(Values values, int positionsOrd) { - this.values = values; this.positionsOrd = positionsOrd; + 
this.values = clone(values); + } + + // FIXME: Since we don't have a thread-safe RefCounted for blocks/vectors, we have to clone the values block to avoid + // data races of reference when sharing blocks/vectors across threads. Remove this when we have a thread-safe RefCounted + // for blocks/vectors. + static Values clone(Values values) { + final Block block = values.block; + try (var builder = block.elementType().newBlockBuilder(block.getPositionCount(), block.blockFactory())) { + builder.copyFrom(block, 0, block.getPositionCount()); + return new Values(values.name, builder.build()); + } } /** @@ -67,6 +79,11 @@ protected ReleasableIterator receive(Page page) { return appendBlocks(page, values.block.lookup(page.getBlock(positionsOrd), TARGET_BLOCK_SIZE)); } + @Override + public void close() { + Releasables.closeExpectNoException(values.block, super::close); + } + @Override public String toString() { return "ColumnLoad[values=" + values + ", positions=" + positionsOrd + "]"; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index cbce712ed9cdb..20734fe8c1a44 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -20,6 +20,7 @@ import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; @@ -34,6 +35,7 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import static java.util.Objects.requireNonNull; @@ -52,6 +54,7 @@ public record HashAggregationOperatorFactory( public Operator get(DriverContext driverContext) { if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) { return new HashAggregationOperator( + groups, aggregators, () -> BlockHash.buildCategorizeBlockHash( groups, @@ -64,6 +67,7 @@ public Operator get(DriverContext driverContext) { ); } return new HashAggregationOperator( + groups, aggregators, () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false), driverContext @@ -83,6 +87,7 @@ public String describe() { private boolean finished; private Page output; + private final List groups; private final BlockHash blockHash; protected final List aggregators; @@ -117,10 +122,12 @@ public String describe() { @SuppressWarnings("this-escape") public HashAggregationOperator( + List groups, List aggregators, Supplier blockHash, DriverContext driverContext ) { + this.groups = groups; this.aggregators = new ArrayList<>(aggregators.size()); this.driverContext = driverContext; boolean success = false; @@ -142,8 +149,22 @@ public boolean needsInput() { return finished == false; } + private final AtomicBoolean isInitialPage = new AtomicBoolean(true); + @Override public void addInput(Page page) { + if (isInitialPage.compareAndSet(true, false) + && (aggregators.size() == 0 || AggregatorMode.INITIAL.equals(aggregators.get(0).getMode()))) { + Page initialPage = createInitialPage(page); + if (initialPage != null) { + addInputInternal(initialPage); + return; + } + } + addInputInternal(page); + } + + private void addInputInternal(Page page) { try { GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()]; class AddInput implements GroupingAggregatorFunction.AddInput { @@ -289,6 +310,42 @@ protected Page wrapPage(Page page) { return page; } + private Page 
createInitialPage(Page page) { + // If no groups are generating bucket keys, move on + if (groups.stream().allMatch(g -> g.emptyBucketGenerator() == null)) { + return page; + } + Block.Builder[] blockBuilders = new Block.Builder[page.getBlockCount()]; + for (int channel = 0; channel < page.getBlockCount(); channel++) { + Block block = page.getBlock(channel); + blockBuilders[channel] = block.elementType().newBlockBuilder(block.getPositionCount(), driverContext.blockFactory()); + blockBuilders[channel].copyFrom(block, 0, block.getPositionCount()); + } + for (BlockHash.GroupSpec group : groups) { + BlockHash.EmptyBucketGenerator emptyBucketGenerator = group.emptyBucketGenerator(); + if (emptyBucketGenerator != null) { + for (int channel = 0; channel < page.getBlockCount(); channel++) { + if (group.channel() == channel) { + emptyBucketGenerator.generate(blockBuilders[channel]); + } else { + for (int i = 0; i < emptyBucketGenerator.getEmptyBucketCount(); i++) { + if (page.getBlock(channel) instanceof DocBlock) { + // TODO: DocBlock doesn't allow appending nulls + ((DocBlock.Builder) blockBuilders[channel]).appendShard(0).appendSegment(0).appendDoc(0); + } else { + blockBuilders[channel].appendNull(); + } + } + } + } + } + } + Block[] blocks = Arrays.stream(blockBuilders).map(Block.Builder::build).toArray(Block[]::new); + Releasables.closeExpectNoException(blockBuilders); + page.releaseBlocks(); + return new Page(blocks); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java deleted file mode 100644 index 58466cffee78e..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ /dev/null @@ -1,647 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.operator; - -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedSetDocValues; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; -import org.apache.lucene.util.PriorityQueue; -import org.elasticsearch.common.CheckedSupplier; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.compute.Describable; -import org.elasticsearch.compute.aggregation.GroupingAggregator; -import org.elasticsearch.compute.aggregation.GroupingAggregator.Factory; -import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext; -import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.aggregation.SeenGroupIds; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash.GroupSpec; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocBlock; -import org.elasticsearch.compute.data.DocVector; -import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator; -import org.elasticsearch.core.RefCounted; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.index.mapper.BlockLoader; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; 
-import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.function.IntFunction; -import java.util.function.Supplier; -import java.util.stream.Collectors; - -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.joining; - -/** - * Unlike {@link HashAggregationOperator}, this hash operator also extracts values or ordinals of the input documents. - */ -public class OrdinalsGroupingOperator implements Operator { - public record OrdinalsGroupingOperatorFactory( - IntFunction blockLoaders, - List shardContexts, - ElementType groupingElementType, - int docChannel, - String groupingField, - List aggregators, - int maxPageSize - ) implements OperatorFactory { - - @Override - public Operator get(DriverContext driverContext) { - return new OrdinalsGroupingOperator( - blockLoaders, - shardContexts, - groupingElementType, - docChannel, - groupingField, - aggregators, - maxPageSize, - driverContext - ); - } - - @Override - public String describe() { - return "OrdinalsGroupingOperator(aggs = " + aggregators.stream().map(Describable::describe).collect(joining(", ")) + ")"; - } - } - - private final IntFunction blockLoaders; - private final List shardContexts; - private final int docChannel; - private final String groupingField; - - private final List aggregatorFactories; - private final ElementType groupingElementType; - private final Map ordinalAggregators; - - private final DriverContext driverContext; - - private boolean finished = false; - - // used to extract and aggregate values - private final int maxPageSize; - private ValuesAggregator valuesAggregator; - - public OrdinalsGroupingOperator( - IntFunction blockLoaders, - List shardContexts, - ElementType groupingElementType, - int docChannel, - String groupingField, - List aggregatorFactories, - int maxPageSize, - DriverContext driverContext - ) { - Objects.requireNonNull(aggregatorFactories); - 
this.blockLoaders = blockLoaders; - this.shardContexts = shardContexts; - this.groupingElementType = groupingElementType; - this.docChannel = docChannel; - this.groupingField = groupingField; - this.aggregatorFactories = aggregatorFactories; - this.ordinalAggregators = new HashMap<>(); - this.maxPageSize = maxPageSize; - this.driverContext = driverContext; - } - - @Override - public boolean needsInput() { - return finished == false; - } - - @Override - public void addInput(Page page) { - checkState(needsInput(), "Operator is already finishing"); - requireNonNull(page, "page is null"); - DocVector docVector = page.getBlock(docChannel).asVector(); - final int shardIndex = docVector.shards().getInt(0); - RefCounted shardRefCounter = docVector.shardRefCounted().get(shardIndex); - final var blockLoader = blockLoaders.apply(shardIndex); - boolean pagePassed = false; - try { - if (docVector.singleSegmentNonDecreasing() && blockLoader.supportsOrdinals()) { - final IntVector segmentIndexVector = docVector.segments(); - assert segmentIndexVector.isConstant(); - final OrdinalSegmentAggregator ordinalAggregator = this.ordinalAggregators.computeIfAbsent( - new SegmentID(shardIndex, segmentIndexVector.getInt(0)), - k -> { - try { - return new OrdinalSegmentAggregator( - driverContext.blockFactory(), - this::createGroupingAggregators, - () -> blockLoader.ordinals(shardContexts.get(k.shardIndex).reader().leaves().get(k.segmentIndex)), - driverContext.bigArrays(), - shardRefCounter - ); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - ); - pagePassed = true; - ordinalAggregator.addInput(docVector.docs(), page); - } else { - if (valuesAggregator == null) { - int channelIndex = page.getBlockCount(); // extractor will append a new block at the end - valuesAggregator = new ValuesAggregator( - blockLoaders, - shardContexts, - groupingElementType, - docChannel, - groupingField, - channelIndex, - aggregatorFactories, - maxPageSize, - driverContext - ); - } - 
pagePassed = true; - valuesAggregator.addInput(page); - } - } finally { - if (pagePassed == false) { - Releasables.closeExpectNoException(page::releaseBlocks); - } - } - } - - private List createGroupingAggregators() { - boolean success = false; - List aggregators = new ArrayList<>(aggregatorFactories.size()); - try { - for (GroupingAggregator.Factory aggregatorFactory : aggregatorFactories) { - aggregators.add(aggregatorFactory.apply(driverContext)); - } - success = true; - return aggregators; - } finally { - if (success == false) { - Releasables.close(aggregators); - } - } - } - - @Override - public Page getOutput() { - if (finished == false) { - return null; - } - if (valuesAggregator != null) { - try { - return valuesAggregator.getOutput(); - } finally { - final ValuesAggregator aggregator = this.valuesAggregator; - this.valuesAggregator = null; - Releasables.close(aggregator); - } - } - if (ordinalAggregators.isEmpty() == false) { - try { - return mergeOrdinalsSegmentResults(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } finally { - Releasables.close(() -> Releasables.close(ordinalAggregators.values()), ordinalAggregators::clear); - } - } - return null; - } - - @Override - public void finish() { - finished = true; - if (valuesAggregator != null) { - valuesAggregator.finish(); - } - } - - private Page mergeOrdinalsSegmentResults() throws IOException { - // TODO: Should we also combine from the results from ValuesAggregator - final PriorityQueue pq = new PriorityQueue<>(ordinalAggregators.size()) { - @Override - protected boolean lessThan(AggregatedResultIterator a, AggregatedResultIterator b) { - return a.currentTerm.compareTo(b.currentTerm) < 0; - } - }; - final List aggregators = createGroupingAggregators(); - try { - boolean seenNulls = false; - for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) { - if (agg.seenNulls()) { - seenNulls = true; - for (int i = 0; i < aggregators.size(); i++) { - 
aggregators.get(i).addIntermediateRow(0, agg.aggregators.get(i), 0); - } - } - } - for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) { - final AggregatedResultIterator it = agg.getResultIterator(); - if (it.next()) { - pq.add(it); - } - } - final int startPosition = seenNulls ? 0 : -1; - int position = startPosition; - final BytesRefBuilder lastTerm = new BytesRefBuilder(); - final Block[] blocks; - final int[] aggBlockCounts; - try (var keysBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(1)) { - if (seenNulls) { - keysBuilder.appendNull(); - } - while (pq.size() > 0) { - final AggregatedResultIterator top = pq.top(); - if (position == startPosition || lastTerm.get().equals(top.currentTerm) == false) { - position++; - lastTerm.copyBytes(top.currentTerm); - keysBuilder.appendBytesRef(top.currentTerm); - } - for (int i = 0; i < top.aggregators.size(); i++) { - aggregators.get(i).addIntermediateRow(position, top.aggregators.get(i), top.currentPosition()); - } - if (top.next()) { - pq.updateTop(); - } else { - pq.pop(); - } - } - aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray(); - blocks = new Block[1 + Arrays.stream(aggBlockCounts).sum()]; - blocks[0] = keysBuilder.build(); - } - boolean success = false; - try { - try (IntVector selected = IntVector.range(0, blocks[0].getPositionCount(), driverContext.blockFactory())) { - int offset = 1; - for (int i = 0; i < aggregators.size(); i++) { - aggregators.get(i).evaluate(blocks, offset, selected, new GroupingAggregatorEvaluationContext(driverContext)); - offset += aggBlockCounts[i]; - } - } - success = true; - return new Page(blocks); - } finally { - if (success == false) { - Releasables.closeExpectNoException(blocks); - } - } - } finally { - Releasables.close(() -> Releasables.close(aggregators)); - } - } - - @Override - public boolean isFinished() { - return finished && valuesAggregator == null && ordinalAggregators.isEmpty(); - } - - @Override 
- public void close() { - Releasables.close(() -> Releasables.close(ordinalAggregators.values()), valuesAggregator); - } - - private static void checkState(boolean condition, String msg) { - if (condition == false) { - throw new IllegalArgumentException(msg); - } - } - - @Override - public String toString() { - String aggregatorDescriptions = aggregatorFactories.stream() - .map(factory -> "\"" + factory.describe() + "\"") - .collect(Collectors.joining(", ")); - - return this.getClass().getSimpleName() + "[" + "aggregators=[" + aggregatorDescriptions + "]]"; - } - - record SegmentID(int shardIndex, int segmentIndex) { - - } - - static final class OrdinalSegmentAggregator implements Releasable, SeenGroupIds { - private final BlockFactory blockFactory; - private final List aggregators; - private final CheckedSupplier docValuesSupplier; - private final BitArray visitedOrds; - private final RefCounted shardRefCounted; - private BlockOrdinalsReader currentReader; - - OrdinalSegmentAggregator( - BlockFactory blockFactory, - Supplier> aggregatorsSupplier, - CheckedSupplier docValuesSupplier, - BigArrays bigArrays, - RefCounted shardRefCounted - ) throws IOException { - boolean success = false; - this.shardRefCounted = shardRefCounted; - this.shardRefCounted.mustIncRef(); - List groupingAggregators = null; - BitArray bitArray = null; - try { - final SortedSetDocValues sortedSetDocValues = docValuesSupplier.get(); - bitArray = new BitArray(sortedSetDocValues.getValueCount(), bigArrays); - groupingAggregators = aggregatorsSupplier.get(); - this.currentReader = BlockOrdinalsReader.newReader(blockFactory, sortedSetDocValues); - this.blockFactory = blockFactory; - this.docValuesSupplier = docValuesSupplier; - this.aggregators = groupingAggregators; - this.visitedOrds = bitArray; - success = true; - } finally { - if (success == false) { - if (bitArray != null) Releasables.close(bitArray); - if (groupingAggregators != null) Releasables.close(groupingAggregators); - // There is no 
danger of double decRef here, since this decRef is called only if the constructor throws, so it would be - // impossible to call close on the instance. - shardRefCounted.decRef(); - } - } - } - - void addInput(IntVector docs, Page page) { - GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()]; - try { - for (int i = 0; i < prepared.length; i++) { - prepared[i] = aggregators.get(i).prepareProcessPage(this, page); - } - - if (BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0)) == false) { - currentReader = BlockOrdinalsReader.newReader(blockFactory, docValuesSupplier.get()); - } - try (IntBlock ordinals = currentReader.readOrdinalsAdded1(docs)) { - final IntVector ordinalsVector = ordinals.asVector(); - if (ordinalsVector != null) { - addOrdinalsInput(ordinalsVector, prepared); - } else { - addOrdinalsInput(ordinals, prepared); - } - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } finally { - Releasables.close(page::releaseBlocks, Releasables.wrap(prepared)); - } - } - - void addOrdinalsInput(IntBlock ordinals, GroupingAggregatorFunction.AddInput[] prepared) { - for (int p = 0; p < ordinals.getPositionCount(); p++) { - int start = ordinals.getFirstValueIndex(p); - int end = start + ordinals.getValueCount(p); - for (int i = start; i < end; i++) { - long ord = ordinals.getInt(i); - visitedOrds.set(ord); - } - } - for (GroupingAggregatorFunction.AddInput addInput : prepared) { - addInput.add(0, ordinals); - } - } - - void addOrdinalsInput(IntVector ordinals, GroupingAggregatorFunction.AddInput[] prepared) { - for (int p = 0; p < ordinals.getPositionCount(); p++) { - long ord = ordinals.getInt(p); - visitedOrds.set(ord); - } - for (GroupingAggregatorFunction.AddInput addInput : prepared) { - addInput.add(0, ordinals); - } - } - - AggregatedResultIterator getResultIterator() throws IOException { - return new AggregatedResultIterator(aggregators, visitedOrds, docValuesSupplier.get()); - 
} - - boolean seenNulls() { - return visitedOrds.get(0); - } - - @Override - public BitArray seenGroupIds(BigArrays bigArrays) { - final BitArray seen = new BitArray(0, bigArrays); - boolean success = false; - try { - // the or method can grow the `seen` bits - seen.or(visitedOrds); - success = true; - return seen; - } finally { - if (success == false) { - Releasables.close(seen); - } - } - } - - @Override - public void close() { - Releasables.close(visitedOrds, () -> Releasables.close(aggregators), Releasables.fromRefCounted(shardRefCounted)); - } - } - - private static class AggregatedResultIterator { - private BytesRef currentTerm; - private long currentOrd = 0; - private final List aggregators; - private final BitArray ords; - private final SortedSetDocValues dv; - - AggregatedResultIterator(List aggregators, BitArray ords, SortedSetDocValues dv) { - this.aggregators = aggregators; - this.ords = ords; - this.dv = dv; - } - - int currentPosition() { - assert currentOrd != Long.MAX_VALUE : "Must not read position when iterator is exhausted"; - return Math.toIntExact(currentOrd); - } - - boolean next() throws IOException { - currentOrd = ords.nextSetBit(currentOrd + 1); - assert currentOrd > 0 : currentOrd; - if (currentOrd < Long.MAX_VALUE) { - currentTerm = dv.lookupOrd(currentOrd - 1); - return true; - } else { - currentTerm = null; - return false; - } - } - } - - private static class ValuesAggregator implements Releasable { - private final ValuesSourceReaderOperator extractor; - private final HashAggregationOperator aggregator; - - ValuesAggregator( - IntFunction blockLoaders, - List shardContexts, - ElementType groupingElementType, - int docChannel, - String groupingField, - int channelIndex, - List aggregatorFactories, - int maxPageSize, - DriverContext driverContext - ) { - this.extractor = new ValuesSourceReaderOperator( - driverContext.blockFactory(), - List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, groupingElementType, blockLoaders)), - 
shardContexts, - docChannel - ); - this.aggregator = new HashAggregationOperator( - aggregatorFactories, - () -> BlockHash.build( - List.of(new GroupSpec(channelIndex, groupingElementType)), - driverContext.blockFactory(), - maxPageSize, - false - ), - driverContext - ); - } - - void addInput(Page page) { - extractor.addInput(page); - Page out = extractor.getOutput(); - if (out != null) { - aggregator.addInput(out); - } - } - - void finish() { - aggregator.finish(); - } - - Page getOutput() { - return aggregator.getOutput(); - } - - @Override - public void close() { - Releasables.close(extractor, aggregator); - } - } - - abstract static class BlockOrdinalsReader { - protected final Thread creationThread; - protected final BlockFactory blockFactory; - - BlockOrdinalsReader(BlockFactory blockFactory) { - this.blockFactory = blockFactory; - this.creationThread = Thread.currentThread(); - } - - static BlockOrdinalsReader newReader(BlockFactory blockFactory, SortedSetDocValues sortedSetDocValues) { - SortedDocValues singleValues = DocValues.unwrapSingleton(sortedSetDocValues); - if (singleValues != null) { - return new SortedDocValuesBlockOrdinalsReader(blockFactory, singleValues); - } else { - return new SortedSetDocValuesBlockOrdinalsReader(blockFactory, sortedSetDocValues); - } - } - - abstract IntBlock readOrdinalsAdded1(IntVector docs) throws IOException; - - abstract int docID(); - - /** - * Checks if the reader can be used to read a range documents starting with the given docID by the current thread. 
- */ - static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) { - return reader != null && reader.creationThread == Thread.currentThread() && reader.docID() <= startingDocID; - } - } - - private static class SortedSetDocValuesBlockOrdinalsReader extends BlockOrdinalsReader { - private final SortedSetDocValues sortedSetDocValues; - - SortedSetDocValuesBlockOrdinalsReader(BlockFactory blockFactory, SortedSetDocValues sortedSetDocValues) { - super(blockFactory); - this.sortedSetDocValues = sortedSetDocValues; - } - - @Override - IntBlock readOrdinalsAdded1(IntVector docs) throws IOException { - final int positionCount = docs.getPositionCount(); - try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - int doc = docs.getInt(p); - if (false == sortedSetDocValues.advanceExact(doc)) { - builder.appendInt(0); - continue; - } - int count = sortedSetDocValues.docValueCount(); - if (count == 1) { - builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1)); - continue; - } - builder.beginPositionEntry(); - for (int i = 0; i < count; i++) { - builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1)); - } - builder.endPositionEntry(); - } - return builder.build(); - } - } - - @Override - int docID() { - return sortedSetDocValues.docID(); - } - } - - private static class SortedDocValuesBlockOrdinalsReader extends BlockOrdinalsReader { - private final SortedDocValues sortedDocValues; - - SortedDocValuesBlockOrdinalsReader(BlockFactory blockFactory, SortedDocValues sortedDocValues) { - super(blockFactory); - this.sortedDocValues = sortedDocValues; - } - - @Override - IntBlock readOrdinalsAdded1(IntVector docs) throws IOException { - final int positionCount = docs.getPositionCount(); - try (IntVector.FixedBuilder builder = blockFactory.newIntVectorFixedBuilder(positionCount)) { - for (int p = 0; p < positionCount; p++) { - if (sortedDocValues.advanceExact(docs.getInt(p))) { 
- builder.appendInt(p, sortedDocValues.ordValue() + 1); - } else { - builder.appendInt(p, 0); - } - } - return builder.build().asBlock(); - } - } - - @Override - int docID() { - return sortedDocValues.docID(); - } - } -} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java index 6ab0291c718a7..9a5f78132b266 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java @@ -40,7 +40,7 @@ public record Factory( @Override public Operator get(DriverContext driverContext) { // TODO: use TimeSeriesBlockHash when possible - return new TimeSeriesAggregationOperator(timeBucket, aggregators, () -> { + return new TimeSeriesAggregationOperator(timeBucket, groups, aggregators, () -> { if (sortedInput && groups.size() == 2) { return new TimeSeriesBlockHash(groups.get(0).channel(), groups.get(1).channel(), driverContext.blockFactory()); } else { @@ -68,11 +68,12 @@ public String describe() { public TimeSeriesAggregationOperator( Rounding.Prepared timeBucket, + List groups, List aggregators, Supplier blockHash, DriverContext driverContext ) { - super(aggregators, blockHash, driverContext); + super(groups, aggregators, blockHash, driverContext); this.timeBucket = timeBucket; } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index e3ddbe0b58aed..8185b045029b3 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -11,8 +11,6 @@ import 
org.apache.lucene.document.Field; import org.apache.lucene.document.LongField; import org.apache.lucene.document.LongPoint; -import org.apache.lucene.document.SortedNumericDocValuesField; -import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; @@ -20,13 +18,11 @@ import org.apache.lucene.search.Collector; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; -import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; -import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.settings.Settings; @@ -35,12 +31,8 @@ import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; -import org.elasticsearch.compute.aggregation.CountAggregatorFunction; -import org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; @@ -55,15 +47,10 @@ import org.elasticsearch.compute.lucene.LuceneSourceOperatorTests; import org.elasticsearch.compute.lucene.ShardContext; import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator; -import 
org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.HashAggregationOperator; -import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.RowInTableLookupOperator; -import org.elasticsearch.compute.operator.ShuffleDocsOperator; import org.elasticsearch.compute.test.BlockTestUtils; import org.elasticsearch.compute.test.OperatorTestCase; import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator; @@ -77,7 +64,6 @@ import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperServiceTestCase; -import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.Uid; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.search.lookup.SearchLookup; @@ -92,8 +78,6 @@ import java.util.Set; import java.util.TreeMap; -import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL; -import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -168,200 +152,6 @@ public void testQueryOperator() throws IOException { } } - public void testGroupingWithOrdinals() throws Exception { - DriverContext driverContext = driverContext(); - BlockFactory blockFactory = driverContext.blockFactory(); - - final String gField = "g"; - final int numDocs = 2856; // between(100, 10000); - final Map expectedCounts = new HashMap<>(); - int keyLength = randomIntBetween(1, 10); - try (BaseDirectoryWrapper dir = 
newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) { - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - BytesRef key = new BytesRef(randomByteArrayOfLength(keyLength)); - SortedSetDocValuesField docValuesField = new SortedSetDocValuesField(gField, key); - doc.add(docValuesField); - writer.addDocument(doc); - expectedCounts.compute(key, (k, v) -> v == null ? 1 : v + 1); - } - writer.commit(); - Map actualCounts = new HashMap<>(); - - try (DirectoryReader reader = writer.getReader()) { - List operators = new ArrayList<>(); - if (randomBoolean()) { - operators.add(new ShuffleDocsOperator(blockFactory)); - } - operators.add(new AbstractPageMappingOperator() { - @Override - protected Page process(Page page) { - return page.appendBlock(driverContext.blockFactory().newConstantIntBlockWith(1, page.getPositionCount())); - } - - @Override - public String toString() { - return "Add(1)"; - } - }); - operators.add( - new OrdinalsGroupingOperator( - shardIdx -> new KeywordFieldMapper.KeywordFieldType("g").blockLoader(mockBlContext()), - List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE, 0.2)), - ElementType.BYTES_REF, - 0, - gField, - List.of(CountAggregatorFunction.supplier().groupingAggregatorFactory(INITIAL, List.of(1))), - randomPageSize(), - driverContext - ) - ); - operators.add( - new HashAggregationOperator( - List.of(CountAggregatorFunction.supplier().groupingAggregatorFactory(FINAL, List.of(1, 2))), - () -> BlockHash.build( - List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)), - driverContext.blockFactory(), - randomPageSize(), - false - ), - driverContext - ) - ); - Driver driver = TestDriverFactory.create( - driverContext, - luceneOperatorFactory( - reader, - List.of(new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of())), - LuceneOperator.NO_LIMIT - ).get(driverContext), - operators, - new PageConsumerOperator(page -> { - BytesRefBlock keys = 
page.getBlock(0); - LongBlock counts = page.getBlock(1); - for (int i = 0; i < keys.getPositionCount(); i++) { - BytesRef spare = new BytesRef(); - keys.getBytesRef(i, spare); - actualCounts.put(BytesRef.deepCopyOf(spare), counts.getLong(i)); - } - page.releaseBlocks(); - }) - ); - OperatorTestCase.runDriver(driver); - assertThat(actualCounts, equalTo(expectedCounts)); - assertDriverContext(driverContext); - org.elasticsearch.common.util.MockBigArrays.ensureAllArraysAreReleased(); - } - } - assertThat(blockFactory.breaker().getUsed(), equalTo(0L)); - } - - // TODO: Remove ordinals grouping operator or enable it GroupingAggregatorFunctionTestCase - public void testValuesWithOrdinalGrouping() throws Exception { - DriverContext driverContext = driverContext(); - BlockFactory blockFactory = driverContext.blockFactory(); - - final int numDocs = between(100, 1000); - Map> expectedValues = new HashMap<>(); - try (BaseDirectoryWrapper dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) { - String VAL_NAME = "val"; - String KEY_NAME = "key"; - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - BytesRef key = new BytesRef(Integer.toString(between(1, 100))); - SortedSetDocValuesField keyField = new SortedSetDocValuesField(KEY_NAME, key); - doc.add(keyField); - if (randomBoolean()) { - int numValues = between(0, 2); - for (int v = 0; v < numValues; v++) { - long val = between(1, 1000); - var valuesField = new SortedNumericDocValuesField(VAL_NAME, val); - doc.add(valuesField); - expectedValues.computeIfAbsent(key, k -> new HashSet<>()).add(val); - } - } - writer.addDocument(doc); - } - writer.commit(); - try (DirectoryReader reader = writer.getReader()) { - List operators = new ArrayList<>(); - if (randomBoolean()) { - operators.add(new ShuffleDocsOperator(blockFactory)); - } - operators.add( - new ValuesSourceReaderOperator( - blockFactory, - List.of( - new ValuesSourceReaderOperator.FieldInfo( - VAL_NAME, - 
ElementType.LONG, - unused -> new BlockDocValuesReader.LongsBlockLoader(VAL_NAME) - ) - ), - List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> { - throw new UnsupportedOperationException(); - }, 0.2)), - 0 - ) - ); - operators.add( - new OrdinalsGroupingOperator( - shardIdx -> new KeywordFieldMapper.KeywordFieldType(KEY_NAME).blockLoader(mockBlContext()), - List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE, 0.2)), - ElementType.BYTES_REF, - 0, - KEY_NAME, - List.of(new ValuesLongAggregatorFunctionSupplier().groupingAggregatorFactory(INITIAL, List.of(1))), - randomPageSize(), - driverContext - ) - ); - operators.add( - new HashAggregationOperator( - List.of(new ValuesLongAggregatorFunctionSupplier().groupingAggregatorFactory(FINAL, List.of(1))), - () -> BlockHash.build( - List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)), - driverContext.blockFactory(), - randomPageSize(), - false - ), - driverContext - ) - ); - Map> actualValues = new HashMap<>(); - Driver driver = TestDriverFactory.create( - driverContext, - luceneOperatorFactory( - reader, - List.of(new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of())), - LuceneOperator.NO_LIMIT - ).get(driverContext), - operators, - new PageConsumerOperator(page -> { - BytesRefBlock keyBlock = page.getBlock(0); - LongBlock valueBlock = page.getBlock(1); - BytesRef spare = new BytesRef(); - for (int p = 0; p < page.getPositionCount(); p++) { - var key = keyBlock.getBytesRef(p, spare); - int valueCount = valueBlock.getValueCount(p); - for (int i = 0; i < valueCount; i++) { - long val = valueBlock.getLong(valueBlock.getFirstValueIndex(p) + i); - boolean added = actualValues.computeIfAbsent(BytesRef.deepCopyOf(key), k -> new HashSet<>()).add(val); - assertTrue(actualValues.toString(), added); - } - } - page.releaseBlocks(); - }) - ); - OperatorTestCase.runDriver(driver); - assertDriverContext(driverContext); - assertThat(actualValues, 
equalTo(expectedValues)); - org.elasticsearch.common.util.MockBigArrays.ensureAllArraysAreReleased(); - } - } - assertThat(blockFactory.breaker().getUsed(), equalTo(0L)); - } - public void testPushRoundToToQuery() throws IOException { long firstGroupMax = randomLong(); long secondGroupMax = randomLong(); @@ -382,6 +172,7 @@ public void testPushRoundToToQuery() throws IOException { LuceneOperator.NO_LIMIT ); ValuesSourceReaderOperator.Factory load = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("v", ElementType.LONG, f -> new BlockDocValuesReader.LongsBlockLoader("v")) ), @@ -408,7 +199,6 @@ public void testPushRoundToToQuery() throws IOException { boolean sawSecondMax = false; boolean sawThirdMax = false; for (Page page : pages) { - logger.error("ADFA {}", page); LongVector group = page.getBlock(1).asVector(); LongVector value = page.getBlock(2).asVector(); for (int p = 0; p < page.getPositionCount(); p++) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunctionTests.java index dbd5d0cc167d1..948aca115f829 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunctionTests.java @@ -11,14 +11,12 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; import 
org.elasticsearch.compute.operator.LongIntBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; -import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Tuple; import org.junit.After; @@ -105,43 +103,6 @@ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { ); } - /** - * Tests {@link GroupingAggregator#addIntermediateRow} by building results using the traditional - * add mechanism and using {@link GroupingAggregator#addIntermediateRow} then asserting that they - * produce the same output. - */ - public void testAddIntermediateRowInput() { - DriverContext ctx = driverContext(); - AggregatorFunctionSupplier supplier = aggregatorFunction(); - List channels = channels(AggregatorMode.SINGLE); - Block[] results = new Block[2]; - try ( - GroupingAggregatorFunction main = supplier.groupingAggregator(ctx, channels); - GroupingAggregatorFunction leaf = supplier.groupingAggregator(ctx, channels); - SourceOperator source = simpleInput(ctx.blockFactory(), 10); - ) { - Page p; - while ((p = source.getOutput()) != null) { - try ( - IntVector group = ctx.blockFactory().newConstantIntVector(0, p.getPositionCount()); - GroupingAggregatorFunction.AddInput addInput = leaf.prepareProcessPage(null, p) - ) { - addInput.add(0, group); - } finally { - p.releaseBlocks(); - } - } - main.addIntermediateRowInput(0, leaf, 0); - try (IntVector selected = ctx.blockFactory().newConstantIntVector(0, 1)) { - main.evaluateFinal(results, 0, selected, new GroupingAggregatorEvaluationContext(ctx)); - leaf.evaluateFinal(results, 1, selected, new GroupingAggregatorEvaluationContext(ctx)); - } - assertThat(results[0], equalTo(results[1])); - } finally { - Releasables.close(results); - } - } - @After public void checkUnclosed() { for (Exception tracker : unclosed) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index 9e5039e8fd9b9..f35935e9b5e9b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -50,6 +50,7 @@ import java.util.SortedSet; import java.util.TreeSet; import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.DoubleStream; import java.util.stream.IntStream; @@ -221,6 +222,18 @@ public final void testNullGroupsAndValues() { assertSimpleOutput(origInput, results); } + public final void testMixedMultivaluedNullGroupsAndValues() { + DriverContext driverContext = driverContext(); + BlockFactory blockFactory = driverContext.blockFactory(); + int end = between(50, 60); + List input = CannedSourceOperator.collectPages( + nullGroups(nullValues(mergeAll(simpleInput(blockFactory, end), blockFactory), blockFactory), blockFactory) + ); + List origInput = BlockTestUtils.deepCopyOf(input, TestBlockFactory.getNonBreakingInstance()); + List results = drive(simple().get(driverContext), input.iterator(), driverContext); + assertSimpleOutput(origInput, results); + } + public final void testNullGroups() { DriverContext driverContext = driverContext(); BlockFactory blockFactory = driverContext.blockFactory(); @@ -550,11 +563,18 @@ protected void assertOutputFromNullOnly(Block b, int position) { } private SourceOperator mergeValues(SourceOperator orig, BlockFactory blockFactory) { + return merge(orig, blockFactory, blockIndex -> blockIndex != 0); + } + + private SourceOperator mergeAll(SourceOperator orig, BlockFactory blockFactory) { + return merge(orig, blockFactory, blockIndex -> true); + } + + private SourceOperator merge(SourceOperator orig, BlockFactory blockFactory, Predicate 
shouldMergeBlockIndex) { return new PositionMergingSourceOperator(orig, blockFactory) { @Override protected Block merge(int blockIndex, Block block) { - // Merge positions for all blocks but the first. For the first just take the first position. - if (blockIndex != 0) { + if (shouldMergeBlockIndex.test(blockIndex)) { return super.merge(blockIndex, block); } Block.Builder builder = block.elementType().newBlockBuilder(block.getPositionCount() / 2, blockFactory); @@ -661,9 +681,9 @@ public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext BitArray seenGroupIds = new BitArray(0, nonBreakingBigArrays()); @Override - public AddInput prepareProcessPage(SeenGroupIds ignoredSeenGroupIds, Page page) { + public AddInput prepareProcessRawInputPage(SeenGroupIds ignoredSeenGroupIds, Page page) { return new AddInput() { - final AddInput delegateAddInput = delegate.prepareProcessPage(bigArrays -> { + final AddInput delegateAddInput = delegate.prepareProcessRawInputPage(bigArrays -> { BitArray seen = new BitArray(0, bigArrays); seen.or(seenGroupIds); return seen; @@ -739,22 +759,53 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { delegate.selectedMayContainUnseenGroups(seenGroupIds); } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds, page); + } + @Override public void addIntermediateInput(int positionOffset, IntVector groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds.asBlock(), page); + } + + public void addIntermediateInputInternal(int positionOffset, IntBlock groupIds, Page page) { + BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance(); int[] chunk = new int[emitChunkSize]; - for (int offset = 0; 
offset < groupIds.getPositionCount(); offset += emitChunkSize) { - int count = 0; - for (int i = offset; i < Math.min(groupIds.getPositionCount(), offset + emitChunkSize); i++) { - chunk[count++] = groupIds.getInt(i); + int chunkPosition = 0; + int offset = 0; + for (int position = 0; position < groupIds.getPositionCount(); position++) { + if (groupIds.isNull(position)) { + continue; + } + int firstValueIndex = groupIds.getFirstValueIndex(position); + int valueCount = groupIds.getValueCount(position); + assert valueCount == 1; // Multi-values make chunking more complex, and it's not a real case yet + + int groupId = groupIds.getInt(firstValueIndex); + chunk[chunkPosition++] = groupId; + if (chunkPosition == emitChunkSize) { + delegate.addIntermediateInput( + positionOffset + offset, + blockFactory.newIntArrayVector(chunk, chunkPosition), + page + ); + chunkPosition = 0; + offset = position + 1; } - BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance(); // TODO: just for compile - delegate.addIntermediateInput(positionOffset + offset, blockFactory.newIntArrayVector(chunk, count), page); } - } - - @Override - public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { - delegate.addIntermediateRowInput(groupId, input, position); + if (chunkPosition > 0) { + delegate.addIntermediateInput( + positionOffset + offset, + blockFactory.newIntArrayVector(chunk, chunkPosition), + page + ); + } } @Override @@ -831,9 +882,7 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { blockHash.add(page, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { - IntBlock newGroupIds = aggregatorMode.isInputPartial() - ? 
groupIds - : BlockTypeRandomizer.randomizeBlockType(groupIds); + IntBlock newGroupIds = BlockTypeRandomizer.randomizeBlockType(groupIds); addInput.add(positionOffset, newGroupIds); } @@ -861,7 +910,7 @@ public void close() { }; }; - return new HashAggregationOperator(aggregators, blockHashSupplier, driverContext); + return new HashAggregationOperator(groups, aggregators, blockHashSupplier, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 842952f9ef8bd..9ce086307acee 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -76,7 +76,13 @@ private void initAnalysisRegistry() throws IOException { ).getAnalysisRegistry(); } + private BlockHash.CategorizeDef getCategorizeDef() { + return new BlockHash.CategorizeDef(null, randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), 70); + } + public void testCategorizeRaw() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + final Page page; boolean withNull = randomBoolean(); final int positions = 7 + (withNull ? 
1 : 0); @@ -98,7 +104,7 @@ public void testCategorizeRaw() { page = new Page(builder.build()); } - try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -137,7 +143,10 @@ public void close() { } }); - assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected"); + } } } finally { page.releaseBlocks(); @@ -145,6 +154,8 @@ public void close() { } public void testCategorizeRawMultivalue() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + final Page page; boolean withNull = randomBoolean(); final int positions = 3 + (withNull ? 
1 : 0); @@ -170,7 +181,7 @@ public void testCategorizeRawMultivalue() { page = new Page(builder.build()); } - try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { + try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -216,7 +227,10 @@ public void close() { } }); - assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?"); + case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected"); + } } } finally { page.releaseBlocks(); @@ -224,6 +238,8 @@ public void close() { } public void testCategorizeIntermediate() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + Page page1; boolean withNull = randomBoolean(); int positions1 = 7 + (withNull ? 
1 : 0); @@ -259,8 +275,8 @@ public void testCategorizeIntermediate() { // Fill intermediatePages with the intermediate state from the raw hashes try ( - BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); - BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); + BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry); + BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry); ) { rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { @@ -335,7 +351,7 @@ public void close() { page2.releaseBlocks(); } - try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) { + try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, categorizeDef, null)) { intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() { private void addBlock(int positionOffset, IntBlock groupIds) { List values = IntStream.range(0, groupIds.getPositionCount()) @@ -403,14 +419,24 @@ public void close() { } }); - assertHashState( - intermediateHash, - withNull, - ".*?Connected.+?to.*?", - ".*?Connection.+?error.*?", - ".*?Disconnected.*?", - ".*?System.+?shutdown.*?" - ); + switch (categorizeDef.outputFormat()) { + case REGEX -> assertHashState( + intermediateHash, + withNull, + ".*?Connected.+?to.*?", + ".*?Connection.+?error.*?", + ".*?Disconnected.*?", + ".*?System.+?shutdown.*?" 
+ ); + case TOKENS -> assertHashState( + intermediateHash, + withNull, + "Connected to", + "Connection error", + "Disconnected", + "System shutdown" + ); + } } } finally { intermediatePage1.releaseBlocks(); @@ -419,6 +445,9 @@ public void close() { } public void testCategorize_withDriver() { + BlockHash.CategorizeDef categorizeDef = getCategorizeDef(); + BlockHash.GroupSpec groupSpec = new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef); + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking(); CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); @@ -477,7 +506,7 @@ public void testCategorize_withDriver() { new LocalSourceOperator(input1), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.INITIAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)), @@ -496,7 +525,7 @@ public void testCategorize_withDriver() { new LocalSourceOperator(input2), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.INITIAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)), @@ -517,7 +546,7 @@ public void testCategorize_withDriver() { new CannedSourceOperator(intermediateOutput.iterator()), List.of( new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(makeGroupSpec()), + List.of(groupSpec), AggregatorMode.FINAL, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(1, 2)), @@ -544,23 +573,36 @@ public void testCategorize_withDriver() { sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), 
outputSums.getLong(i)); maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i)); } + List keys = switch (categorizeDef.outputFormat()) { + case REGEX -> List.of( + ".*?aaazz.*?", + ".*?bbbzz.*?", + ".*?ccczz.*?", + ".*?dddzz.*?", + ".*?eeezz.*?", + ".*?words.+?words.+?words.+?goodbye.*?", + ".*?words.+?words.+?words.+?hello.*?" + ); + case TOKENS -> List.of("aaazz", "bbbzz", "ccczz", "dddzz", "eeezz", "words words words goodbye", "words words words hello"); + }; + assertThat( sums, equalTo( Map.of( - ".*?aaazz.*?", + keys.get(0), 1L, - ".*?bbbzz.*?", + keys.get(1), 2L, - ".*?ccczz.*?", + keys.get(2), 33L, - ".*?dddzz.*?", + keys.get(3), 44L, - ".*?eeezz.*?", + keys.get(4), 5L, - ".*?words.+?words.+?words.+?goodbye.*?", + keys.get(5), 8888L, - ".*?words.+?words.+?words.+?hello.*?", + keys.get(6), 999L ) ) @@ -569,19 +611,19 @@ public void testCategorize_withDriver() { maxs, equalTo( Map.of( - ".*?aaazz.*?", + keys.get(0), 1L, - ".*?bbbzz.*?", + keys.get(1), 2L, - ".*?ccczz.*?", + keys.get(2), 30L, - ".*?dddzz.*?", + keys.get(3), 40L, - ".*?eeezz.*?", + keys.get(4), 5L, - ".*?words.+?words.+?words.+?goodbye.*?", + keys.get(5), 8000L, - ".*?words.+?words.+?words.+?hello.*?", + keys.get(6), 900L ) ) @@ -589,10 +631,6 @@ public void testCategorize_withDriver() { Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); } - private BlockHash.GroupSpec makeGroupSpec() { - return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true); - } - private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... 
expectedKeys) { // Check the keys Block[] blocks = null; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java index 734b0660d24a3..d0eb89eafd841 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java @@ -74,10 +74,15 @@ public void testCategorize_withDriver() { DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); boolean withNull = randomBoolean(); boolean withMultivalues = randomBoolean(); + BlockHash.CategorizeDef categorizeDef = new BlockHash.CategorizeDef( + null, + randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), + 70 + ); List groupSpecs = List.of( - new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true), - new BlockHash.GroupSpec(1, ElementType.INT, false) + new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef), + new BlockHash.GroupSpec(1, ElementType.INT, null) ); LocalSourceOperator.BlockSupplier input1 = () -> { @@ -218,8 +223,12 @@ public void testCategorize_withDriver() { } Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); + List keys = switch (categorizeDef.outputFormat()) { + case REGEX -> List.of(".*?connected.+?to.*?", ".*?connection.+?error.*?", ".*?disconnected.*?"); + case TOKENS -> List.of("connected to", "connection error", "disconnected"); + }; Map>> expectedResult = Map.of( - ".*?connected.+?to.*?", + keys.get(0), Map.of( 7, Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"), @@ -228,9 +237,9 @@ public void testCategorize_withDriver() { 111, 
Set.of("connected to 2.1.1") ), - ".*?connection.+?error.*?", + keys.get(1), Map.of(7, Set.of("connection error"), 42, Set.of("connection error")), - ".*?disconnected.*?", + keys.get(2), Map.of(7, Set.of("disconnected")) ); if (withNull) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java index f96b9d26f075c..8d3b662cb0023 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java @@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer callback, Block[].. private BlockHash buildBlockHash(int emitBatchSize, Block... values) { List specs = new ArrayList<>(values.length); for (int c = 0; c < values.length; c++) { - specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c))); + specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c), null)); } assert forcePackedHash == false : "Packed TopN hash not implemented yet"; /*return forcePackedHash diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java index 9640cbb2b44ea..3856d85c6178a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java @@ -77,7 +77,7 @@ public static IntBlock randomizeBlockType(IntBlock block) { } int[] values = new int[totalValues]; - for (int i = 0; i < values.length; i++) { + for (int i = 0; i < totalValues; i++) { values[i] = block.getInt(i); } @@ -93,7 +93,7 @@ public 
static IntBlock randomizeBlockType(IntBlock block) { } var intArray = blockFactory.bigArrays().newIntArray(totalValues); - for (int i = 0; i < block.getPositionCount(); i++) { + for (int i = 0; i < totalValues; i++) { intArray.set(i, block.getInt(i)); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java index 655f7b54c61c0..2ef64623daa74 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.OperatorTests; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -201,6 +202,7 @@ private List runQuery(Set values, Query query, boolean shuffleDocs operators.add( new ValuesSourceReaderOperator( blockFactory, + ByteSizeValue.ofGb(1).getBytes(), List.of( new ValuesSourceReaderOperator.FieldInfo( FIELD, diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilderTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilderTests.java similarity index 76% rename from x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilderTests.java rename to x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilderTests.java index 9a338506c00d1..9d0a3a3439680 100644 --- 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/SingletonOrdinalsBuilderTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SingletonOrdinalsBuilderTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.compute.data; +package org.elasticsearch.compute.lucene.read; import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.index.DirectoryReader; @@ -17,20 +17,13 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.PageCacheRecycler; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.test.MockBlockFactory; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; import org.elasticsearch.indices.CrankyCircuitBreakerService; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -42,14 +35,14 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -public class SingletonOrdinalsBuilderTests extends ESTestCase { +public class SingletonOrdinalsBuilderTests extends ComputeTestCase { + public void testReader() throws IOException { - testRead(breakingDriverContext().blockFactory()); + testRead(blockFactory()); } public void testReadWithCranky() throws IOException { - BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new 
CrankyCircuitBreakerService()); - BlockFactory factory = new BlockFactory(bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST), bigArrays); + var factory = crankyBlockFactory(); try { testRead(factory); // If we made it this far cranky didn't fail us! @@ -112,23 +105,8 @@ private void assertCompactToUnique(int[] sortedOrds, List expected) { assertMap(Arrays.stream(sortedOrds).mapToObj(Integer::valueOf).limit(uniqueLength).toList(), matchesList(expected)); } - private final List breakers = new ArrayList<>(); - private final List blockFactories = new ArrayList<>(); - - /** - * A {@link DriverContext} with a breaking {@link BigArrays} and {@link BlockFactory}. - */ - protected DriverContext breakingDriverContext() { // TODO move this to driverContext once everyone supports breaking - BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofGb(1)).withCircuitBreaking(); - CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); - breakers.add(breaker); - BlockFactory factory = new MockBlockFactory(breaker, bigArrays); - blockFactories.add(factory); - return new DriverContext(bigArrays, factory); - } - public void testAllNull() throws IOException { - BlockFactory factory = breakingDriverContext().blockFactory(); + BlockFactory factory = blockFactory(); int count = 1000; try (Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { for (int i = 0; i < count; i++) { @@ -159,7 +137,7 @@ public void testAllNull() throws IOException { } public void testEmitOrdinalForHighCardinality() throws IOException { - BlockFactory factory = breakingDriverContext().blockFactory(); + BlockFactory factory = blockFactory(); int numOrds = between(50, 100); try (Directory directory = newDirectory(); IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig())) { for (int o = 0; o < numOrds; o++) { @@ -198,20 +176,4 @@ static BytesRefBlock 
buildOrdinalsBuilder(SingletonOrdinalsBuilder builder) { return builder.buildOrdinal(); } } - - @After - public void allBreakersEmpty() throws Exception { - // first check that all big arrays are released, which can affect breakers - MockBigArrays.ensureAllArraysAreReleased(); - - for (CircuitBreaker breaker : breakers) { - for (var factory : blockFactories) { - if (factory instanceof MockBlockFactory mockBlockFactory) { - mockBlockFactory.ensureAllBlocksAreReleased(); - } - } - assertThat("Unexpected used in breaker: " + breaker, breaker.getUsed(), equalTo(0L)); - } - } - } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilderTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilderTests.java new file mode 100644 index 0000000000000..745e949c6ae71 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/SortedSetOrdinalsBuilderTests.java @@ -0,0 +1,149 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.lucene.read; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.indices.CrankyCircuitBreakerService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class SortedSetOrdinalsBuilderTests extends ComputeTestCase { + + public void testReader() throws IOException { + testRead(blockFactory()); + } + + public void testReadWithCranky() throws IOException { + BlockFactory factory = crankyBlockFactory(); + try { + testRead(factory); + // If we made it this far cranky didn't fail us! 
+ } catch (CircuitBreakingException e) { + logger.info("cranky", e); + assertThat(e.getMessage(), equalTo(CrankyCircuitBreakerService.ERROR_MESSAGE)); + } + assertThat(factory.breaker().getUsed(), equalTo(0L)); + } + + private void testRead(BlockFactory factory) throws IOException { + int numDocs = between(1, 1000); + try (Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + Map> expectedValues = new HashMap<>(); + List allKeys = IntStream.range(0, between(1, 20)).mapToObj(n -> String.format(Locale.ROOT, "v%02d", n)).toList(); + for (int i = 0; i < numDocs; i++) { + List subs = randomSubsetOf(allKeys).stream().sorted().toList(); + expectedValues.put(i, subs); + Document doc = new Document(); + for (String v : subs) { + doc.add(new SortedSetDocValuesField("f", new BytesRef(v))); + } + doc.add(new NumericDocValuesField("k", i)); + indexWriter.addDocument(doc); + } + Map> actualValues = new HashMap<>(); + try (IndexReader reader = indexWriter.getReader()) { + for (LeafReaderContext ctx : reader.leaves()) { + var keysDV = ctx.reader().getNumericDocValues("k"); + var valuesDV = ctx.reader().getSortedSetDocValues("f"); + try ( + var valuesBuilder = new SortedSetOrdinalsBuilder(factory, valuesDV, ctx.reader().numDocs()); + var keysBuilder = factory.newIntVectorBuilder(ctx.reader().numDocs()) + ) { + for (int i = 0; i < ctx.reader().maxDoc(); i++) { + if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(i)) { + assertTrue(keysDV.advanceExact(i)); + keysBuilder.appendInt(Math.toIntExact(keysDV.longValue())); + if (valuesDV.advanceExact(i)) { + int valueCount = valuesDV.docValueCount(); + if (valueCount > 1) { + valuesBuilder.beginPositionEntry(); + } + for (int v = 0; v < valueCount; v++) { + valuesBuilder.appendOrd(Math.toIntExact(valuesDV.nextOrd())); + } + if (valueCount > 1) { + valuesBuilder.endPositionEntry(); + } + } else { + valuesBuilder.appendNull(); + } + } + } + BytesRef scratch 
= new BytesRef(); + try (BytesRefBlock valuesBlock = valuesBuilder.build(); IntVector counterVector = keysBuilder.build()) { + for (int p = 0; p < valuesBlock.getPositionCount(); p++) { + int key = counterVector.getInt(p); + ArrayList subs = new ArrayList<>(); + assertNull(actualValues.put(key, subs)); + int count = valuesBlock.getValueCount(p); + int first = valuesBlock.getFirstValueIndex(p); + int last = first + count; + for (int i = first; i < last; i++) { + String val = valuesBlock.getBytesRef(i, scratch).utf8ToString(); + subs.add(val); + } + } + } + } + } + } + assertThat(actualValues, equalTo(expectedValues)); + } + } + + public void testAllNull() throws IOException { + BlockFactory factory = blockFactory(); + int numDocs = between(1, 100); + try (Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new SortedSetDocValuesField("f", new BytesRef("empty"))); + indexWriter.addDocument(doc); + } + try (IndexReader reader = indexWriter.getReader()) { + for (LeafReaderContext ctx : reader.leaves()) { + var docValues = ctx.reader().getSortedSetDocValues("f"); + try (var builder = new SortedSetOrdinalsBuilder(factory, docValues, ctx.reader().numDocs())) { + for (int i = 0; i < ctx.reader().maxDoc(); i++) { + if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(i)) { + assertThat(docValues.advanceExact(i), equalTo(true)); + builder.appendNull(); + } + } + try (BytesRefBlock built = builder.build()) { + for (int p = 0; p < built.getPositionCount(); p++) { + assertThat(built.isNull(p), equalTo(true)); + } + assertThat(built.areAllValuesNull(), equalTo(true)); + } + } + } + } + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java index 2bd5cc95dd804..5a1f2ee7cc949 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.Block; @@ -241,12 +242,17 @@ private static Operator.OperatorFactory factory( ElementType elementType, BlockLoader loader ) { - return new ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { - if (shardIdx < 0 || shardIdx >= INDICES.size()) { - fail("unexpected shardIdx [" + shardIdx + "]"); - } - return loader; - })), shardContexts, 0); + return new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), + List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { + if (shardIdx < 0 || shardIdx >= INDICES.size()) { + fail("unexpected shardIdx [" + shardIdx + "]"); + } + return loader; + })), + shardContexts, + 0 + ); } protected SourceOperator simpleInput(DriverContext context, int size) { @@ -493,6 +499,7 @@ public void testManySingleDocPages() { // TODO: Add index2 operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(testCase.info, fieldInfo(mapperService(indexKey).fieldType("key"), ElementType.INT)), shardContexts, 0 @@ -600,6 +607,7 @@ private void loadSimpleAndAssert( List operators = new ArrayList<>(); operators.add( new ValuesSourceReaderOperator.Factory( + 
ByteSizeValue.ofGb(1), List.of( fieldInfo(mapperService("index1").fieldType("key"), ElementType.INT), fieldInfo(mapperService("index1").fieldType("indexKey"), ElementType.BYTES_REF) @@ -614,7 +622,9 @@ private void loadSimpleAndAssert( cases.removeAll(b); tests.addAll(b); operators.add( - new ValuesSourceReaderOperator.Factory(b.stream().map(i -> i.info).toList(), shardContexts, 0).get(driverContext) + new ValuesSourceReaderOperator.Factory(ByteSizeValue.ofGb(1), b.stream().map(i -> i.info).toList(), shardContexts, 0).get( + driverContext + ) ); } List results = drive(operators, input.iterator(), driverContext); @@ -718,7 +728,7 @@ private void testLoadAllStatus(boolean allInOnePage) { Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING ); List operators = cases.stream() - .map(i -> new ValuesSourceReaderOperator.Factory(List.of(i.info), shardContexts, 0).get(driverContext)) + .map(i -> new ValuesSourceReaderOperator.Factory(ByteSizeValue.ofGb(1), List.of(i.info), shardContexts, 0).get(driverContext)) .toList(); if (allInOnePage) { input = List.of(CannedSourceOperator.mergePages(input)); @@ -1390,6 +1400,7 @@ public void testNullsShared() { simpleInput(driverContext, 10), List.of( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS), new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS) @@ -1424,6 +1435,7 @@ public void testDescriptionOfMany() throws IOException { List cases = infoAndChecksForEachType(ordering, ordering); ValuesSourceReaderOperator.Factory factory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), cases.stream().map(c -> c.info).toList(), List.of(new ValuesSourceReaderOperator.ShardContext(reader(indexKey), () -> SourceLoader.FROM_STORED_SOURCE, 0.2)), 0 @@ -1469,6 +1481,7 @@ public void testManyShards() throws IOException { // TODO add index2 
MappedFieldType ft = mapperService(indexKey).fieldType("key"); var readerFactory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> { seenShards.add(shardIdx); return ft.blockLoader(blContext()); @@ -1676,8 +1689,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - Block block = reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + Block block = reader.read(factory, docs, offset); Page page = new Page((org.elasticsearch.compute.data.Block) block); return convertEvaluator.eval(page); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java index c9b46eb764580..19a645c146242 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java @@ -18,6 +18,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.TieredMergePolicy; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; @@ -29,6 +30,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.data.Block; import 
org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; @@ -36,6 +38,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.DocBlock; +import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; @@ -98,6 +101,7 @@ import static org.elasticsearch.test.MapMatcher.matchesMap; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -149,12 +153,14 @@ public static Operator.OperatorFactory factory(IndexReader reader, MappedFieldTy } static Operator.OperatorFactory factory(IndexReader reader, String name, ElementType elementType, BlockLoader loader) { - return new ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { - if (shardIdx != 0) { - fail("unexpected shardIdx [" + shardIdx + "]"); - } - return loader; - })), + return new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), + List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { + if (shardIdx != 0) { + fail("unexpected shardIdx [" + shardIdx + "]"); + } + return loader; + })), List.of( new ValuesSourceReaderOperator.ShardContext( reader, @@ -390,17 +396,17 @@ private IndexReader initIndex(Directory directory, int size, int commitEvery) th return DirectoryReader.open(directory); } - private IndexReader initIndexLongField(Directory directory, int size, int commitEvery) throws IOException { + private IndexReader initIndexLongField(Directory directory, int size, int commitEvery, boolean forceMerge) throws 
IOException { try ( IndexWriter writer = new IndexWriter( directory, - newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE).setMaxBufferedDocs(IndexWriterConfig.DISABLE_AUTO_FLUSH) + newIndexWriterConfig().setMergePolicy(new TieredMergePolicy()).setMaxBufferedDocs(IndexWriterConfig.DISABLE_AUTO_FLUSH) ) ) { for (int d = 0; d < size; d++) { XContentBuilder source = JsonXContent.contentBuilder(); source.startObject(); - source.field("long_source_text", Integer.toString(d).repeat(100 * 1024)); + source.field("long_source_text", d + "#" + "a".repeat(100 * 1024)); source.endObject(); ParsedDocument doc = mapperService.documentParser() .parseDocument( @@ -413,6 +419,10 @@ private IndexReader initIndexLongField(Directory directory, int size, int commit writer.commit(); } } + + if (forceMerge) { + writer.forceMerge(1); + } } return DirectoryReader.open(directory); } @@ -484,6 +494,7 @@ public void testManySingleDocPages() { ); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(testCase.info, fieldInfo(mapperService.fieldType("key"), ElementType.INT)), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -603,6 +614,7 @@ private void loadSimpleAndAssert( List operators = new ArrayList<>(); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(fieldInfo(mapperService.fieldType("key"), ElementType.INT)), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -621,6 +633,7 @@ private void loadSimpleAndAssert( tests.addAll(b); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), b.stream().map(i -> i.info).toList(), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -719,6 +732,7 @@ private void testLoadAllStatus(boolean allInOnePage) { List operators = cases.stream() .map( i -> new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(i.info), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -923,8 +937,7 @@ public void 
testLoadLongShuffledManySegments() throws IOException { private void testLoadLong(boolean shuffle, boolean manySegments) throws IOException { int numDocs = between(10, 500); initMapping(); - keyToTags.clear(); - reader = initIndexLongField(directory, numDocs, manySegments ? commitEvery(numDocs) : numDocs); + reader = initIndexLongField(directory, numDocs, manySegments ? commitEvery(numDocs) : numDocs, manySegments == false); DriverContext driverContext = driverContext(); List input = CannedSourceOperator.collectPages(sourceOperator(driverContext, numDocs)); @@ -936,6 +949,7 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept if (shuffle) { input = input.stream().map(this::shuffle).toList(); } + boolean willSplit = loadLongWillSplit(input); Checks checks = new Checks(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING, Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING); @@ -951,6 +965,7 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept List operators = cases.stream() .map( i -> new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(i.info), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -963,12 +978,55 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept ).get(driverContext) ) .toList(); - drive(operators, input.iterator(), driverContext); + List result = drive(operators, input.iterator(), driverContext); + + boolean[] found = new boolean[numDocs]; + for (Page page : result) { + BytesRefVector bytes = page.getBlock(1).asVector(); + BytesRef scratch = new BytesRef(); + for (int p = 0; p < bytes.getPositionCount(); p++) { + BytesRef v = bytes.getBytesRef(p, scratch); + int d = Integer.valueOf(v.utf8ToString().split("#")[0]); + assertFalse("found a duplicate " + d, found[d]); + found[d] = true; + } + } + List missing = new ArrayList<>(); + for (int d = 0; d < numDocs; d++) { + if (found[d] == false) { + missing.add(d); + } + } + 
assertThat(missing, hasSize(0)); + assertThat(result, hasSize(willSplit ? greaterThanOrEqualTo(input.size()) : equalTo(input.size()))); + for (int i = 0; i < cases.size(); i++) { ValuesSourceReaderOperatorStatus status = (ValuesSourceReaderOperatorStatus) operators.get(i).status(); assertThat(status.pagesReceived(), equalTo(input.size())); - assertThat(status.pagesEmitted(), equalTo(input.size())); + assertThat(status.pagesEmitted(), willSplit ? greaterThanOrEqualTo(input.size()) : equalTo(input.size())); + } + } + + private boolean loadLongWillSplit(List input) { + int nextDoc = -1; + for (Page page : input) { + DocVector doc = page.getBlock(0).asVector(); + for (int p = 0; p < doc.getPositionCount(); p++) { + if (doc.shards().getInt(p) != 0) { + return false; + } + if (doc.segments().getInt(p) != 0) { + return false; + } + if (nextDoc == -1) { + nextDoc = doc.docs().getInt(p); + } else if (doc.docs().getInt(p) != nextDoc) { + return false; + } + nextDoc++; + } } + return true; } record Checks(Block.MvOrdering booleanAndNumericalDocValuesMvOrdering, Block.MvOrdering bytesRefDocValuesMvOrdering) { @@ -1560,6 +1618,7 @@ public void testNullsShared() { simpleInput(driverContext.blockFactory(), 10), List.of( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS), new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS) @@ -1611,6 +1670,7 @@ private void testSequentialStoredFields(boolean sequential, int docCount) throws assertThat(source, hasSize(1)); // We want one page for simpler assertions, and we want them all in one segment assertTrue(source.get(0).getBlock(0).asVector().singleSegmentNonDecreasing()); Operator op = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( fieldInfo(mapperService.fieldType("key"), ElementType.INT), fieldInfo(storedTextField("stored_text"), 
ElementType.BYTES_REF) @@ -1648,6 +1708,7 @@ public void testDescriptionOfMany() throws IOException { List cases = infoAndChecksForEachType(ordering, ordering); ValuesSourceReaderOperator.Factory factory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), cases.stream().map(c -> c.info).toList(), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -1701,6 +1762,7 @@ public void testManyShards() throws IOException { ); MappedFieldType ft = mapperService.fieldType("key"); var readerFactory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> { seenShards.add(shardIdx); return ft.blockLoader(blContext()); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index 106b9613d7bb2..c8c6138eaece0 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -28,6 +28,7 @@ import java.util.Arrays; import java.util.Comparator; import java.util.List; +import java.util.function.Function; import java.util.stream.LongStream; import static java.util.stream.IntStream.range; @@ -113,7 +114,7 @@ public void testTopNNullsLast() { try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, false, 3))), + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3), null)), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), @@ -190,7 +191,7 @@ public void testTopNNullsFirst() { 
try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, false, new BlockHash.TopNDef(0, ascOrder, true, 3))), + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3), null)), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), @@ -254,4 +255,96 @@ public void testTopNNullsFirst() { outputPage.releaseBlocks(); } } + + /** + * When in intermediate/final mode, it will receive intermediate outputs that may have to be discarded + * (TopN in the datanode but not acceptable in the coordinator). + *

+ * This test ensures that such discarding works correctly. + *

+ */ + public void testTopNNullsIntermediateDiscards() { + boolean ascOrder = randomBoolean(); + var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L }; + if (ascOrder) { + Arrays.sort(groups, Comparator.reverseOrder()); + } + var groupChannel = 0; + + // Supplier of operators to ensure that they're identical, simulating a datanode/coordinator connection + Function makeAggWithMode = (mode) -> { + var sumAggregatorChannels = mode.isInputPartial() ? List.of(1, 2) : List.of(1); + var maxAggregatorChannels = mode.isInputPartial() ? List.of(3, 4) : List.of(1); + + return new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))), + mode, + List.of( + new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, sumAggregatorChannels), + new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, maxAggregatorChannels) + ), + randomPageSize(), + null + ).get(driverContext()); + }; + + // The operator that will collect all the results + try (var collectingOperator = makeAggWithMode.apply(AggregatorMode.FINAL)) { + // First datanode, sending a suitable TopN set of data + try (var datanodeOperator = makeAggWithMode.apply(AggregatorMode.INITIAL)) { + var page = new Page( + BlockUtils.fromList(blockFactory(), List.of(List.of(groups[4], 1L), List.of(groups[3], 2L), List.of(groups[2], 4L))) + ); + datanodeOperator.addInput(page); + datanodeOperator.finish(); + + var outputPage = datanodeOperator.getOutput(); + collectingOperator.addInput(outputPage); + } + + // Second datanode, sending an outdated TopN, as the coordinator has better top values already + try (var datanodeOperator = makeAggWithMode.apply(AggregatorMode.INITIAL)) { + var page = new Page( + BlockUtils.fromList( + blockFactory(), + List.of( + List.of(groups[5], 8L), + List.of(groups[3], 16L), + List.of(groups[1], 32L) // This group is worse than the worst group in the 
coordinator + ) + ) + ); + datanodeOperator.addInput(page); + datanodeOperator.finish(); + + var outputPage = datanodeOperator.getOutput(); + collectingOperator.addInput(outputPage); + } + + collectingOperator.finish(); + + var outputPage = collectingOperator.getOutput(); + + var groupsBlock = (LongBlock) outputPage.getBlock(0); + var sumBlock = (LongBlock) outputPage.getBlock(1); + var maxBlock = (LongBlock) outputPage.getBlock(2); + + assertThat(groupsBlock.getPositionCount(), equalTo(3)); + assertThat(sumBlock.getPositionCount(), equalTo(3)); + assertThat(maxBlock.getPositionCount(), equalTo(3)); + + assertThat(groupsBlock.getTotalValueCount(), equalTo(3)); + assertThat(sumBlock.getTotalValueCount(), equalTo(3)); + assertThat(maxBlock.getTotalValueCount(), equalTo(3)); + + assertThat( + BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3), + equalTo(Arrays.asList(List.of(groups[4]), List.of(groups[3]), List.of(groups[5]))) + ); + assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(1L), List.of(18L), List.of(8L)))); + assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(1L), List.of(16L), List.of(8L)))); + + outputPage.releaseBlocks(); + } + } } diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java index 705bfca2e903e..a943e917e0335 100644 --- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java +++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java @@ -126,7 +126,10 @@ public MultiClusterSpecIT( "NullifiedJoinKeyToPurgeTheJoin", "SortBeforeAndAfterJoin", "SortEvalBeforeLookup", - "SortBeforeAndAfterMultipleJoinAndMvExpand" + 
"SortBeforeAndAfterMultipleJoinAndMvExpand", + "LookupJoinAfterTopNAndRemoteEnrich", + // Lookup join after LIMIT is not supported in CCS yet + "LookupJoinAfterLimitAndRemoteEnrich" ); @Override diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java index 7fa6d789bc8fc..2fb107ddfe73c 100644 --- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java +++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClustersIT.java @@ -205,25 +205,24 @@ private Map runEsql(RestEsqlTestCase.RequestObjectBuilder reques } } - private void assertResultMapForLike( + private void assertResultMapWithCapabilities( boolean includeCCSMetadata, Map result, C columns, V values, boolean remoteOnly, - boolean requireLikeListCapability + List fullResultCapabilities ) throws IOException { - List requiredCapabilities = new ArrayList<>(List.of("like_on_index_fields")); - if (requireLikeListCapability) { - requiredCapabilities.add("like_list_on_index_fields"); - } // the feature is completely supported if both local and remote clusters support it - boolean isSupported = capabilitiesSupportedNewAndOld(requiredCapabilities); - + // otherwise we expect a partial result, and will not check the data + boolean isSupported = capabilitiesSupportedNewAndOld(fullResultCapabilities); if (isSupported) { assertResultMap(includeCCSMetadata, result, columns, values, remoteOnly); } else { - logger.info("--> skipping data check for like index test, cluster does not support like index feature"); + logger.info( + "--> skipping data check for a test, cluster does not support all of [{}] capabilities", + String.join(",", fullResultCapabilities) + ); // just verify that we did not get a partial result 
var clusters = result.get("_clusters"); var reason = "unexpected partial results" + (clusters != null ? ": _clusters=" + clusters : ""); @@ -526,7 +525,7 @@ public void testLikeIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = List.of(List.of(remoteDocs.size(), REMOTE_CLUSTER_NAME + ":" + remoteIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } public void testLikeIndexLegacySettingNoResults() throws Exception { @@ -548,7 +547,7 @@ public void testLikeIndexLegacySettingNoResults() throws Exception { var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); // we expect empty result, since the setting is false var values = List.of(); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } } @@ -572,7 +571,7 @@ public void testLikeIndexLegacySettingResults() throws Exception { var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); // we expect results, since the setting is false, but there is : in the LIKE query var values = List.of(List.of(remoteDocs.size(), REMOTE_CLUSTER_NAME + ":" + remoteIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } } @@ -586,7 +585,7 @@ public void testNotLikeIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = 
List.of(List.of(localDocs.size(), localIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } public void testLikeListIndex() throws Exception { @@ -601,7 +600,14 @@ public void testLikeListIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = List.of(List.of(remoteDocs.size(), REMOTE_CLUSTER_NAME + ":" + remoteIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, true); + assertResultMapWithCapabilities( + includeCCSMetadata, + result, + columns, + values, + false, + List.of("like_on_index_fields", "like_list_on_index_fields") + ); } public void testNotLikeListIndex() throws Exception { @@ -615,7 +621,14 @@ public void testNotLikeListIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = List.of(List.of(localDocs.size(), localIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, true); + assertResultMapWithCapabilities( + includeCCSMetadata, + result, + columns, + values, + false, + List.of("like_on_index_fields", "like_list_on_index_fields") + ); } public void testNotLikeListKeyword() throws Exception { @@ -629,11 +642,24 @@ public void testNotLikeListKeyword() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); Predicate filter = d -> false == (d.color.contains("blue") || d.color.contains("red")); - var values = List.of( - List.of((int) remoteDocs.stream().filter(filter).count(), REMOTE_CLUSTER_NAME + ":" + remoteIndex), - List.of((int) localDocs.stream().filter(filter).count(), localIndex) + + var values = new 
ArrayList<>(); + int remoteCount = (int) remoteDocs.stream().filter(filter).count(); + int localCount = (int) localDocs.stream().filter(filter).count(); + if (remoteCount > 0) { + values.add(List.of(remoteCount, REMOTE_CLUSTER_NAME + ":" + remoteIndex)); + } + if (localCount > 0) { + values.add(List.of(localCount, localIndex)); + } + assertResultMapWithCapabilities( + includeCCSMetadata, + result, + columns, + values, + false, + List.of("like_on_index_fields", "like_list_on_index_fields") ); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, true); } public void testRLikeIndex() throws Exception { @@ -646,7 +672,7 @@ public void testRLikeIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = List.of(List.of(remoteDocs.size(), REMOTE_CLUSTER_NAME + ":" + remoteIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } public void testNotRLikeIndex() throws Exception { @@ -659,7 +685,37 @@ public void testNotRLikeIndex() throws Exception { """, includeCCSMetadata); var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); var values = List.of(List.of(localDocs.size(), localIndex)); - assertResultMapForLike(includeCCSMetadata, result, columns, values, false, false); + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); + } + + public void testRLikeListIndex() throws Exception { + assumeTrue("not supported", capabilitiesSupportedNewAndOld(List.of("rlike_with_list_of_patterns"))); + boolean includeCCSMetadata = includeCCSMetadata(); + Map result = run(""" + FROM test-local-index,*:test-remote-index METADATA _index + | WHERE _index RLIKE (".*remote.*", 
".*not-exist.*") + | STATS c = COUNT(*) BY _index + | SORT _index ASC + """, includeCCSMetadata); + var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); + var values = List.of(List.of(remoteDocs.size(), REMOTE_CLUSTER_NAME + ":" + remoteIndex)); + // we depend on the code in like_on_index_fields to serialize an ExpressionQueryBuilder + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); + } + + public void testNotRLikeListIndex() throws Exception { + assumeTrue("not supported", capabilitiesSupportedNewAndOld(List.of("rlike_with_list_of_patterns"))); + boolean includeCCSMetadata = includeCCSMetadata(); + Map result = run(""" + FROM test-local-index,*:test-remote-index METADATA _index + | WHERE _index NOT RLIKE (".*remote.*", ".*not-exist.*") + | STATS c = COUNT(*) BY _index + | SORT _index ASC + """, includeCCSMetadata); + var columns = List.of(Map.of("name", "c", "type", "long"), Map.of("name", "_index", "type", "keyword")); + var values = List.of(List.of(localDocs.size(), localIndex)); + // we depend on the code in like_on_index_fields to serialize an ExpressionQueryBuilder + assertResultMapWithCapabilities(includeCCSMetadata, result, columns, values, false, List.of("like_on_index_fields")); } private RestClient remoteClusterClient() throws IOException { diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java index d01e1c9fb7f56..3484f19afa451 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java @@ -9,13 +9,21 @@ import 
com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.TestClustersThreadFilter; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.esql.CsvSpecReader.CsvTestCase; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.plugin.ComputeService; import org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase; +import org.junit.Before; import org.junit.ClassRule; +import java.io.IOException; + @ThreadLeakFilters(filters = TestClustersThreadFilter.class) public class EsqlSpecIT extends EsqlSpecTestCase { @ClassRule @@ -50,4 +58,14 @@ protected boolean enableRoundingDoubleValuesOnAsserting() { protected boolean supportsSourceFieldMapping() { return cluster.getNumNodes() == 1; } + + @Before + public void configureChunks() throws IOException { + boolean smallChunks = randomBoolean(); + Request request = new Request("PUT", "/_cluster/settings"); + XContentBuilder builder = JsonXContent.contentBuilder().startObject().startObject("persistent"); + builder.field(PhysicalSettings.VALUES_LOADING_JUMBO_SIZE.getKey(), smallChunks ? 
"1kb" : null); + request.setJsonEntity(Strings.toString(builder.endObject().endObject())); + assertOK(client().performRequest(request)); + } } diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushQueriesIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushQueriesIT.java index 27dd245121cd9..70452976ca14a 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushQueriesIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushQueriesIT.java @@ -275,6 +275,42 @@ public void testLikeList() throws IOException { testPushQuery(value, esqlQuery, List.of(luceneQuery), dataNodeSignature, true); } + public void testRLike() throws IOException { + String value = "v".repeat(between(1, 256)); + String esqlQuery = """ + FROM test + | WHERE test rlike "%value.*" + """; + String luceneQuery = switch (type) { + case KEYWORD -> "test:/%value.*/"; + case CONSTANT_KEYWORD, MATCH_ONLY_TEXT_WITH_KEYWORD, AUTO, TEXT_WITH_KEYWORD -> "*:*"; + case SEMANTIC_TEXT_WITH_KEYWORD -> "FieldExistsQuery [field=_primary_term]"; + }; + ComputeSignature dataNodeSignature = switch (type) { + case CONSTANT_KEYWORD, KEYWORD -> ComputeSignature.FILTER_IN_QUERY; + case AUTO, TEXT_WITH_KEYWORD, MATCH_ONLY_TEXT_WITH_KEYWORD, SEMANTIC_TEXT_WITH_KEYWORD -> ComputeSignature.FILTER_IN_COMPUTE; + }; + testPushQuery(value, esqlQuery, List.of(luceneQuery), dataNodeSignature, true); + } + + public void testRLikeList() throws IOException { + String value = "v".repeat(between(1, 256)); + String esqlQuery = """ + FROM test + | WHERE test rlike ("%value.*", "abc.*") + """; + String luceneQuery = switch (type) { + case CONSTANT_KEYWORD, MATCH_ONLY_TEXT_WITH_KEYWORD, AUTO, TEXT_WITH_KEYWORD -> "*:*"; + case SEMANTIC_TEXT_WITH_KEYWORD -> "FieldExistsQuery 
[field=_primary_term]"; + case KEYWORD -> "test:RLIKE(\"%value.*\", \"abc.*\"), caseInsensitive=false"; + }; + ComputeSignature dataNodeSignature = switch (type) { + case CONSTANT_KEYWORD, KEYWORD -> ComputeSignature.FILTER_IN_QUERY; + case AUTO, TEXT_WITH_KEYWORD, MATCH_ONLY_TEXT_WITH_KEYWORD, SEMANTIC_TEXT_WITH_KEYWORD -> ComputeSignature.FILTER_IN_COMPUTE; + }; + testPushQuery(value, esqlQuery, List.of(luceneQuery), dataNodeSignature, true); + } + enum ComputeSignature { FILTER_IN_COMPUTE( matchesList().item("LuceneSourceOperator") diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java index 68c606f2e3fa2..e8153673adb89 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java @@ -71,7 +71,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.oneOf; @@ -466,34 +465,6 @@ public void assertDriverData(Map driverMetadata, Map) driverSliceArgs.get("operators")), not(empty())); } - public void testProfileOrdinalsGroupingOperator() throws IOException { - assumeTrue("requires pragmas", Build.current().isSnapshot()); - indexTimestampData(1); - - RequestObjectBuilder builder = requestObjectBuilder().query(fromIndex() + " | STATS AVG(value) BY test.keyword"); - builder.profile(true); - // Lock to shard level partitioning, so we get consistent profile output - 
builder.pragmas(Settings.builder().put("data_partitioning", "shard").build()); - Map result = runEsql(builder); - - List> signatures = new ArrayList<>(); - @SuppressWarnings("unchecked") - List> profiles = (List>) ((Map) result.get("profile")).get("drivers"); - for (Map p : profiles) { - fixTypesOnProfile(p); - assertThat(p, commonProfile()); - List sig = new ArrayList<>(); - @SuppressWarnings("unchecked") - List> operators = (List>) p.get("operators"); - for (Map o : operators) { - sig.add((String) o.get("operator")); - } - signatures.add(sig); - } - - assertThat(signatures, hasItem(hasItem("OrdinalsGroupingOperator[aggregators=[\"sum of longs\", \"count\"]]"))); - } - @AwaitsFix(bugUrl = "disabled until JOIN infrastructrure properly lands") public void testInlineStatsProfile() throws IOException { assumeTrue("INLINESTATS only available on snapshots", Build.current().isSnapshot()); diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/StoredFieldsSequentialIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/StoredFieldsSequentialIT.java index df4444f5a1e47..027bf3313e661 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/StoredFieldsSequentialIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/StoredFieldsSequentialIT.java @@ -195,6 +195,15 @@ public void buildIndex() throws IOException { bulk.setJsonEntity(b.toString()); Response bulkResponse = client().performRequest(bulk); assertThat(entityToMap(bulkResponse.getEntity(), XContentType.JSON), matchesMap().entry("errors", false).extraOk()); + + // Forcemerge to one segment to get more consistent results. 
+ Request forcemerge = new Request("POST", "/_forcemerge"); + forcemerge.addParameter("max_num_segments", "1"); + Response forcemergeResponse = client().performRequest(forcemerge); + assertThat( + entityToMap(forcemergeResponse.getEntity(), XContentType.JSON), + matchesMap().entry("_shards", matchesMap().entry("failed", 0).entry("successful", greaterThanOrEqualTo(1)).extraOk()).extraOk() + ); } @Override diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java index 5687504da487e..ef02d4a1f8c98 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java @@ -107,10 +107,6 @@ public static List policiesOnKeyword(List Set.of("languages_policy").contains(x.policyName())).toList(); } - public static String randomNonVector(List previousOutput) { - return randomName(previousOutput.stream().filter(x -> x.type().contains("vector") == false).toList()); - } - public static String randomName(List previousOutput) { String result = randomRawName(previousOutput); if (result == null) { @@ -292,7 +288,9 @@ public static boolean fieldCanBeUsed(Column field) { // https://github.com/elastic/elasticsearch/issues/121741 field.name().equals("") // this is a known pathological case, no need to test it for now - || field.name().equals("")) == false; + || field.name().equals("") + // no dense vectors for now, they are not supported in most commands + || field.type().contains("vector")) == false; } public static String unquote(String colName) { diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java 
b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java index 56fc925ed8421..8fe9477e1714a 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java @@ -40,7 +40,7 @@ public abstract class GenerativeRestTest extends ESRestTestCase { "Cannot use field \\[.*\\] due to ambiguities", "cannot sort on .*", "argument of \\[count.*\\] must", - "Cannot use field \\[.*\\] with unsupported type \\[.*_range\\]", + "Cannot use field \\[.*\\] with unsupported type \\[.*\\]", "Unbounded sort not supported yet", "The field names are too complex to process", // field_caps problem "must be \\[any type except counter types\\]", // TODO refine the generation of count() @@ -48,12 +48,11 @@ public abstract class GenerativeRestTest extends ESRestTestCase { // Awaiting fixes for query failure "Unknown column \\[\\]", // https://github.com/elastic/elasticsearch/issues/121741, "Plan \\[ProjectExec\\[\\[.* optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/125866 - "optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/116781 "The incoming YAML document exceeds the limit:", // still to investigate, but it seems to be specific to the test framework "Data too large", // Circuit breaker exceptions eg. 
https://github.com/elastic/elasticsearch/issues/130072 + "optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/131509 // Awaiting fixes for correctness - "Expecting the following columns \\[.*\\], got", // https://github.com/elastic/elasticsearch/issues/129000 "Expecting at most \\[.*\\] columns, got \\[.*\\]" // https://github.com/elastic/elasticsearch/issues/129561 ); diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/MvExpandGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/MvExpandGenerator.java index 53cec0eabd652..317a2e459094e 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/MvExpandGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/MvExpandGenerator.java @@ -25,7 +25,7 @@ public CommandDescription generate( List previousOutput, QuerySchema schema ) { - String toExpand = EsqlQueryGenerator.randomNonVector(previousOutput); + String toExpand = EsqlQueryGenerator.randomName(previousOutput); if (toExpand == null) { return EMPTY_DESCRIPTION; // no columns to expand } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index e8acabe71ab41..ed15caa17ad3d 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -290,12 +290,12 @@ public long count(FieldName field, BytesRef value) { } @Override - public byte[] min(FieldName field, DataType dataType) { + public Object min(FieldName field) { return null; } @Override - public byte[] 
max(FieldName field, DataType dataType) { + public Object max(FieldName field) { return null; } @@ -381,6 +381,27 @@ public String toString() { } } + public static class TestSearchStatsWithMinMax extends TestSearchStats { + + private final Map minValues; + private final Map maxValues; + + public TestSearchStatsWithMinMax(Map minValues, Map maxValues) { + this.minValues = minValues; + this.maxValues = maxValues; + } + + @Override + public Object min(FieldName field) { + return minValues.get(field.string()); + } + + @Override + public Object max(FieldName field) { + return maxValues.get(field.string()); + } + } + public static final TestSearchStats TEST_SEARCH_STATS = new TestSearchStats(); private static final Map> TABLES = tables(); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec index 49b16baf30f58..334214016ae67 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec @@ -145,6 +145,302 @@ AVG(salary):double | bucket:date // end::bucket_in_agg-result[] ; +bucketMonthWithEmpty#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS BY bucket = BUCKET(hire_date, 20, "1985-01-01T00:00:00Z", "1986-01-01T00:00:00Z", true) +| SORT bucket +; + +bucket:datetime +1985-01-01T00:00:00.000Z +1985-02-01T00:00:00.000Z +1985-03-01T00:00:00.000Z +1985-04-01T00:00:00.000Z +1985-05-01T00:00:00.000Z +1985-06-01T00:00:00.000Z +1985-07-01T00:00:00.000Z +1985-08-01T00:00:00.000Z +1985-09-01T00:00:00.000Z +1985-10-01T00:00:00.000Z +1985-11-01T00:00:00.000Z +1985-12-01T00:00:00.000Z +; + +bucketHeightWithEmpty#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS BY heightBucket = 
ROUND(BUCKET(height, 10, 1.0, 2.0, true), 1) +| SORT heightBucket +; + +heightBucket:double +1.0 +1.1 +1.2 +1.3 +1.4 +1.5 +1.6 +1.7 +1.8 +1.9 +2.0 +; + +bucketsWithEmptyYear#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1980-01-01T00:00:00Z" AND hire_date < "1990-01-01T00:00:00Z" +| STATS BY + yearBucket = BUCKET(hire_date, 1 year, "1980-01-01T00:00:00Z", "1990-01-01T00:00:00Z", true), + heightBucket = ROUND(BUCKET(height, 10, 1.0, 2.0), 1) +| SORT yearBucket, heightBucket +; + +yearBucket:datetime | heightBucket:double +1980-01-01T00:00:00.000Z | null +1981-01-01T00:00:00.000Z | null +1982-01-01T00:00:00.000Z | null +1983-01-01T00:00:00.000Z | null +1984-01-01T00:00:00.000Z | null +1985-01-01T00:00:00.000Z | 1.4 +1985-01-01T00:00:00.000Z | 1.7 +1985-01-01T00:00:00.000Z | 1.8 +1985-01-01T00:00:00.000Z | 1.9 +1985-01-01T00:00:00.000Z | 2.0 +1985-01-01T00:00:00.000Z | null +1986-01-01T00:00:00.000Z | 1.4 +1986-01-01T00:00:00.000Z | 1.5 +1986-01-01T00:00:00.000Z | 1.7 +1986-01-01T00:00:00.000Z | 1.8 +1986-01-01T00:00:00.000Z | 2.0 +1986-01-01T00:00:00.000Z | 2.1 +1986-01-01T00:00:00.000Z | null +1987-01-01T00:00:00.000Z | 1.4 +1987-01-01T00:00:00.000Z | 1.5 +1987-01-01T00:00:00.000Z | 1.6 +1987-01-01T00:00:00.000Z | 1.7 +1987-01-01T00:00:00.000Z | 1.8 +1987-01-01T00:00:00.000Z | 1.9 +1987-01-01T00:00:00.000Z | 2.0 +1987-01-01T00:00:00.000Z | 2.1 +1987-01-01T00:00:00.000Z | null +1988-01-01T00:00:00.000Z | 1.4 +1988-01-01T00:00:00.000Z | 1.5 +1988-01-01T00:00:00.000Z | 1.7 +1988-01-01T00:00:00.000Z | 1.8 +1988-01-01T00:00:00.000Z | 1.9 +1988-01-01T00:00:00.000Z | null +1989-01-01T00:00:00.000Z | 1.5 +1989-01-01T00:00:00.000Z | 1.7 +1989-01-01T00:00:00.000Z | 2.0 +1989-01-01T00:00:00.000Z | null +; + +bucketsWithEmptyHeight#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1990-01-01T00:00:00Z" +| STATS BY + yearBucket = BUCKET(hire_date, 1 year, 
"1985-01-01T00:00:00Z", "1990-01-01T00:00:00Z"), + heightBucket = ROUND(BUCKET(height, 10, 1.0, 2.0, true), 1) +| SORT yearBucket, heightBucket +; + +yearBucket:datetime | heightBucket:double +1985-01-01T00:00:00.000Z | 1.4 +1985-01-01T00:00:00.000Z | 1.7 +1985-01-01T00:00:00.000Z | 1.8 +1985-01-01T00:00:00.000Z | 1.9 +1985-01-01T00:00:00.000Z | 2.0 +1986-01-01T00:00:00.000Z | 1.4 +1986-01-01T00:00:00.000Z | 1.5 +1986-01-01T00:00:00.000Z | 1.7 +1986-01-01T00:00:00.000Z | 1.8 +1986-01-01T00:00:00.000Z | 2.0 +1986-01-01T00:00:00.000Z | 2.1 +1987-01-01T00:00:00.000Z | 1.4 +1987-01-01T00:00:00.000Z | 1.5 +1987-01-01T00:00:00.000Z | 1.6 +1987-01-01T00:00:00.000Z | 1.7 +1987-01-01T00:00:00.000Z | 1.8 +1987-01-01T00:00:00.000Z | 1.9 +1987-01-01T00:00:00.000Z | 2.0 +1987-01-01T00:00:00.000Z | 2.1 +1988-01-01T00:00:00.000Z | 1.4 +1988-01-01T00:00:00.000Z | 1.5 +1988-01-01T00:00:00.000Z | 1.7 +1988-01-01T00:00:00.000Z | 1.8 +1988-01-01T00:00:00.000Z | 1.9 +1989-01-01T00:00:00.000Z | 1.5 +1989-01-01T00:00:00.000Z | 1.7 +1989-01-01T00:00:00.000Z | 2.0 +null | 1.0 +null | 1.1 +null | 1.2 +null | 1.3 +null | 1.4 +null | 1.5 +null | 1.6 +null | 1.7 +null | 1.8 +null | 1.9 +; + + + + +bucketMonthInAggWithEmpty#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS MAX(salary) BY bucket = BUCKET(hire_date, 20, "1985-01-01T00:00:00Z", "1986-01-01T00:00:00Z", true) +| SORT bucket +; + +MAX(salary):integer | bucket:datetime +null | 1985-01-01T00:00:00.000Z +66174 | 1985-02-01T00:00:00.000Z +null | 1985-03-01T00:00:00.000Z +null | 1985-04-01T00:00:00.000Z +44817 | 1985-05-01T00:00:00.000Z +null | 1985-06-01T00:00:00.000Z +62405 | 1985-07-01T00:00:00.000Z +null | 1985-08-01T00:00:00.000Z +49095 | 1985-09-01T00:00:00.000Z +54329 | 1985-10-01T00:00:00.000Z +74999 | 1985-11-01T00:00:00.000Z +null | 1985-12-01T00:00:00.000Z +; + +bucketMonthInAggWithEmpty2#[skip:-8.13.99, reason:BUCKET 
renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS MAX(salary) BY bucket = BUCKET(hire_date, 1 month, "1985-01-01T00:00:00Z", "1986-01-01T00:00:00Z", true) +| SORT bucket +; + +MAX(salary):integer | bucket:datetime +null | 1985-01-01T00:00:00.000Z +66174 | 1985-02-01T00:00:00.000Z +null | 1985-03-01T00:00:00.000Z +null | 1985-04-01T00:00:00.000Z +44817 | 1985-05-01T00:00:00.000Z +null | 1985-06-01T00:00:00.000Z +62405 | 1985-07-01T00:00:00.000Z +null | 1985-08-01T00:00:00.000Z +49095 | 1985-09-01T00:00:00.000Z +54329 | 1985-10-01T00:00:00.000Z +74999 | 1985-11-01T00:00:00.000Z +null | 1985-12-01T00:00:00.000Z +; + +bucketMonthInAggWithEmpty3#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS MAX(salary) BY bucket = BUCKET(hire_date, 1 month, "1985-01-01T00:00:00Z", "1987-01-01T00:00:00Z", true) +| SORT bucket +; + +MAX(salary):integer | bucket:datetime +null | 1985-01-01T00:00:00.000Z +66174 | 1985-02-01T00:00:00.000Z +null | 1985-03-01T00:00:00.000Z +null | 1985-04-01T00:00:00.000Z +44817 | 1985-05-01T00:00:00.000Z +null | 1985-06-01T00:00:00.000Z +62405 | 1985-07-01T00:00:00.000Z +null | 1985-08-01T00:00:00.000Z +49095 | 1985-09-01T00:00:00.000Z +54329 | 1985-10-01T00:00:00.000Z +74999 | 1985-11-01T00:00:00.000Z +null | 1985-12-01T00:00:00.000Z +null | 1986-01-01T00:00:00.000Z +null | 1986-02-01T00:00:00.000Z +null | 1986-03-01T00:00:00.000Z +null | 1986-04-01T00:00:00.000Z +null | 1986-05-01T00:00:00.000Z +null | 1986-06-01T00:00:00.000Z +null | 1986-07-01T00:00:00.000Z +null | 1986-08-01T00:00:00.000Z +null | 1986-09-01T00:00:00.000Z +null | 1986-10-01T00:00:00.000Z +null | 1986-11-01T00:00:00.000Z +null | 1986-12-01T00:00:00.000Z +; + +bucketMonthInAggWithEmpty4#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < 
"1989-01-01T00:00:00Z" +| STATS MAX(salary) BY bucket = BUCKET(hire_date, 1 month, "1985-01-01T00:00:00Z", "1987-01-01T00:00:00Z", true) +| SORT bucket +; + +MAX(salary):integer | bucket:datetime +null | 1985-01-01T00:00:00.000Z +66174 | 1985-02-01T00:00:00.000Z +null | 1985-03-01T00:00:00.000Z +null | 1985-04-01T00:00:00.000Z +44817 | 1985-05-01T00:00:00.000Z +null | 1985-06-01T00:00:00.000Z +62405 | 1985-07-01T00:00:00.000Z +null | 1985-08-01T00:00:00.000Z +49095 | 1985-09-01T00:00:00.000Z +54329 | 1985-10-01T00:00:00.000Z +74999 | 1985-11-01T00:00:00.000Z +null | 1985-12-01T00:00:00.000Z +null | 1986-01-01T00:00:00.000Z +54462 | 1986-02-01T00:00:00.000Z +44956 | 1986-03-01T00:00:00.000Z +null | 1986-04-01T00:00:00.000Z +null | 1986-05-01T00:00:00.000Z +57305 | 1986-06-01T00:00:00.000Z +37702 | 1986-07-01T00:00:00.000Z +61805 | 1986-08-01T00:00:00.000Z +32272 | 1986-09-01T00:00:00.000Z +50128 | 1986-10-01T00:00:00.000Z +null | 1986-11-01T00:00:00.000Z +36174 | 1986-12-01T00:00:00.000Z +70011 | 1987-03-01T00:00:00.000Z +66817 | 1987-04-01T00:00:00.000Z +69904 | 1987-05-01T00:00:00.000Z +25324 | 1987-07-01T00:00:00.000Z +47411 | 1987-08-01T00:00:00.000Z +68431 | 1987-09-01T00:00:00.000Z +40612 | 1987-10-01T00:00:00.000Z +29175 | 1987-11-01T00:00:00.000Z +36051 | 1988-01-01T00:00:00.000Z +60408 | 1988-02-01T00:00:00.000Z +55360 | 1988-05-01T00:00:00.000Z +54518 | 1988-07-01T00:00:00.000Z +39878 | 1988-09-01T00:00:00.000Z +73578 | 1988-10-01T00:00:00.000Z +; + +bucketMonthInAggsWithEmpty#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "1985-01-01T00:00:00Z" AND hire_date < "1986-01-01T00:00:00Z" +| STATS MIN(salary), MAX(salary), AVG(salary) BY bucket = BUCKET(hire_date, 20, "1985-01-01T00:00:00Z", "1986-01-01T00:00:00Z", true) +| SORT bucket +; +warningRegex:evaluation of \[AVG\(salary\)\] failed, treating result as null. 
Only first 20 failures recorded +warningRegex:java.lang.ArithmeticException: / by zero + +// tag::bucket_in_aggs_with_empty-result[] +MIN(salary):integer | MAX(salary):integer | AVG(salary):double | bucket:datetime +null | null | null | 1985-01-01T00:00:00.000Z +26436 | 66174 | 46305.0 | 1985-02-01T00:00:00.000Z +null | null | null | 1985-03-01T00:00:00.000Z +null | null | null | 1985-04-01T00:00:00.000Z +44817 | 44817 | 44817.0 | 1985-05-01T00:00:00.000Z +null | null | null | 1985-06-01T00:00:00.000Z +62405 | 62405 | 62405.0 | 1985-07-01T00:00:00.000Z +null | null | null | 1985-08-01T00:00:00.000Z +49095 | 49095 | 49095.0 | 1985-09-01T00:00:00.000Z +48735 | 54329 | 51532.0 | 1985-10-01T00:00:00.000Z +33956 | 74999 | 54539.75 | 1985-11-01T00:00:00.000Z +null | null | null | 1985-12-01T00:00:00.000Z +// end::bucket_in_aggs_with_empty-result[] +; + bucketWithOffset#[skip:-8.13.99, reason:BUCKET renamed in 8.14] // tag::bucketWithOffset[] FROM employees diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index 7168ca3dc398f..be46e68a8b08a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -397,7 +397,7 @@ FROM sample_data ; COUNT():long | SUM(event_duration):long | category:keyword - 7 | 23231327 | null + 7 | 23231327 | null ; on null row @@ -800,3 +800,82 @@ COUNT():long | VALUES(str):keyword | category:keyword | str:keyword 1 | [a, b, c] | null | b 1 | [a, b, c] | null | c ; + +with option output_format regex +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"output_format": "regex"}) + | SORT count DESC, category +; + +count:long | category:keyword + 3 | .*?Connected.+?to.*? + 3 | .*?Connection.+?error.*? + 1 | .*?Disconnected.*? 
+; + +with option output_format tokens +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"output_format": "tokens"}) + | SORT count DESC, category +; + +count:long | category:keyword + 3 | Connected to + 3 | Connection error + 1 | Disconnected +; + +with option similarity_threshold +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"similarity_threshold": 99}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | .*?Connection.+?error.*? +1 | .*?Connected.+?to.+?10\.1\.0\.1.*? +1 | .*?Connected.+?to.+?10\.1\.0\.2.*? +1 | .*?Connected.+?to.+?10\.1\.0\.3.*? +1 | .*?Disconnected.*? +; + +with option analyzer +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"analyzer": "stop"}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | .*?connected.*? +3 | .*?connection.+?error.*? +1 | .*?disconnected.*? 
+; + +with all options +required_capability: categorize_options + +FROM sample_data + | STATS count=COUNT() + BY category=CATEGORIZE(message, {"analyzer": "whitespace", "similarity_threshold": 100, "output_format": "tokens"}) + | SORT count DESC, category +; + +count:long | category:keyword +3 | Connection error +1 | Connected to 10.1.0.1 +1 | Connected to 10.1.0.2 +1 | Connected to 10.1.0.3 +1 | Disconnected +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec index 9bfb08eb82b45..2aa6189a957ec 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec @@ -661,3 +661,104 @@ from * author.keyword:keyword|book_no:keyword|scalerank:integer|street:keyword|bytes_in:ul|@timestamp:unsupported|abbrev:keyword|city_location:geo_point|distance:double|description:unsupported|birth_date:date|language_code:integer|intersects:boolean|client_ip:unsupported|event_duration:long|version:version|language_name:keyword Fyodor Dostoevsky |1211 |null |null |null |null |null |null |null |null |null |null |null |null |null |null |null ; + + +statsAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1", "Connected to 10.1.0.2") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| STATS messages = count_distinct(message) BY language_name +; + +messages:long | language_name:keyword +2 | English +; + + +enrichAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | 
language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +coordinatorEnrichAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _coordinator:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +doubleRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _remote:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +enrichAfterCoordinatorEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _coordinator:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +doubleCoordinatorEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _coordinator:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _coordinator:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; diff 
--git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index c6105b82f2300..ce8061534ddbb 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,7 +3,7 @@ # top-n query at the shard level knnSearch -required_capability: knn_function_v2 +required_capability: knn_function_v3 // tag::knn-function[] from colors metadata _score @@ -29,9 +29,8 @@ chartreuse | [127.0, 255.0, 0.0] // end::knn-function-result[] ; -# https://github.com/elastic/elasticsearch/issues/129550 - Add as an example to knn function documentation -knnSearchWithSimilarityOption-Ignore -required_capability: knn_function_v2 +knnSearchWithSimilarityOption +required_capability: knn_function_v3 from colors metadata _score | where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) @@ -47,14 +46,13 @@ wheat | [245.0, 222.0, 179.0] ; knnHybridSearch -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score -| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140) +| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10) | where primary == true | sort _score desc, color asc | keep color, rgb_vector -| limit 10 ; color:text | rgb_vector:dense_vector @@ -69,21 +67,45 @@ red | [255.0, 0.0, 0.0] yellow | [255.0, 255.0, 0.0] ; -knnWithMultipleFunctions -required_capability: knn_function_v2 +knnWithPrefilter +required_capability: knn_function_v3 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") +| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green")) | sort _score desc, color asc | keep color, rgb_vector ; color:text | rgb_vector:dense_vector olive | [128.0, 128.0, 0.0] +green | [0.0, 128.0, 0.0] +; + +knnWithNegatedPrefilter 
+required_capability: knn_function_v3 + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) +| sort _score desc, color asc +| keep color, rgb_vector +| LIMIT 10 +; + +color:text | rgb_vector:dense_vector +sienna | [160.0, 82.0, 45.0] +peru | [205.0, 133.0, 63.0] +golden rod | [218.0, 165.0, 32.0] +brown | [165.0, 42.0, 42.0] +firebrick | [178.0, 34.0, 34.0] +chartreuse | [127.0, 255.0, 0.0] +gray | [128.0, 128.0, 128.0] +green | [0.0, 128.0, 0.0] +maroon | [128.0, 0.0, 0.0] +orange | [255.0, 165.0, 0.0] ; knnAfterKeep -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | keep rgb_vector, color, _score @@ -102,7 +124,7 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | drop primary @@ -121,7 +143,7 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -140,14 +162,12 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 -# TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*" +| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" | sort _score desc, color asc | keep color, hex_code, rgb_vector -| limit 10 ; color:text | hex_code:keyword | rgb_vector:dense_vector @@ -161,11 +181,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 -# TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where (knn(rgb_vector, [0,255,255], 140) or 
knn(rgb_vector, [128, 0, 255], 140)) and primary == true +| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true | keep color, rgb_vector, _score | sort _score desc, color asc | drop _score @@ -184,8 +203,31 @@ red | [255.0, 0.0, 0.0] yellow | [255.0, 255.0, 0.0] ; +knnWithNegationsAndFiltersConjunction +required_capability: knn_function_v3 + +from colors metadata _score +| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +turquoise | [64.0, 224.0, 208.0] +aqua marine | [127.0, 255.0, 212.0] +teal | [0.0, 128.0, 128.0] +silver | [192.0, 192.0, 192.0] +gray | [128.0, 128.0, 128.0] +gainsboro | [220.0, 220.0, 220.0] +thistle | [216.0, 191.0, 216.0] +lavender | [230.0, 230.0, 250.0] +azure | [240.0, 255.0, 255.0] +; + knnWithNonPushableConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -208,9 +250,8 @@ green | false maroon | false ; -# https://github.com/elastic/elasticsearch/issues/129550 -testKnnWithNonPushableDisjunctions-Ignore -required_capability: knn_function_v2 +testKnnWithNonPushableDisjunctions +required_capability: knn_function_v3 from colors metadata _score | where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 @@ -225,9 +266,8 @@ lemon chiffon papaya whip ; -# https://github.com/elastic/elasticsearch/issues/129550 -testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore -required_capability: knn_function_v2 +testKnnWithNonPushableDisjunctionsOnComplexExpressions +required_capability: knn_function_v3 from colors metadata _score | where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == 
false) @@ -242,7 +282,7 @@ indigo | false ; testKnnInStatsNonPushable -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors | where length(color) < 10 @@ -254,7 +294,7 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function_v2 +required_capability: knn_function_v3 required_capability: full_text_functions_in_stats_where from colors diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec index bdf0413a03d02..c71bf34cafd1a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec @@ -4773,3 +4773,101 @@ FROM sample_data_ts_nanos 2023-10-23T12:27:28.948123456Z | 172.21.2.113 | 2764889 | Connected to 10.1.0.2 2023-10-23T12:15:03.360123456Z | 172.21.2.162 | 3450233 | Connected to 10.1.0.3 ; + +############################################### +# LOOKUP JOIN and ENRICH +############################################### + +enrichAfterLookupJoin +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| LOOKUP JOIN message_types_lookup ON message +| ENRICH languages_policy ON language_code +; + +message:keyword | language_code:keyword | type:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | Success | English +; + + +lookupJoinAfterEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| ENRICH languages_policy ON language_code +| LOOKUP JOIN message_types_lookup ON message +; + +message:keyword | language_code:keyword | language_name:keyword | type:keyword +Connected to 10.1.0.1 | 1 | English | Success +; + + +lookupJoinAfterRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| 
WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| LOOKUP JOIN message_types_lookup ON message +; + +message:keyword | language_code:keyword | language_name:keyword | type:keyword +Connected to 10.1.0.1 | 1 | English | Success +; + + +lookupJoinAfterLimitAndRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| LIMIT 1 +| ENRICH _remote:languages_policy ON language_code +| EVAL enrich_language_name = language_name, language_code = language_code::integer +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| KEEP message, enrich_language_name, language_name, country.keyword +| SORT language_name, country.keyword +; + +message:keyword | enrich_language_name:keyword | language_name:keyword | country.keyword:keyword +Connected to 10.1.0.1 | English | English | Canada +Connected to 10.1.0.1 | English | English | United States of America +Connected to 10.1.0.1 | English | English | null +Connected to 10.1.0.1 | English | null | United Kingdom +; + + +lookupJoinAfterTopNAndRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| SORT message +| LIMIT 1 +| ENRICH _remote:languages_policy ON language_code +| EVAL enrich_language_name = language_name, language_code = language_code::integer +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| KEEP message, enrich_language_name, language_name, country.keyword +| SORT language_name, country.keyword +; + +message:keyword | enrich_language_name:keyword | language_name:keyword | country.keyword:keyword +Connected to 10.1.0.1 | English | English | Canada +Connected to 10.1.0.1 | English | English | United States of America +Connected to 10.1.0.1 | English | English | null +Connected to 10.1.0.1 | English | null | United 
Kingdom +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_expand.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_expand.csv-spec index 20ce3ecc5a396..c7dbe01ef6f09 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_expand.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_expand.csv-spec @@ -419,3 +419,65 @@ emp_no:integer | job_positions:keyword 10001 | Accountant 10001 | Senior Python Developer ; + +testMvExpandInconsistentColumnOrder1 +required_capability: fix_mv_expand_inconsistent_column_order +from message_types +| eval foo_1 = 1, foo_2 = 2 +| sort message +| mv_expand foo_1 +; + +message:keyword | type:keyword | foo_1:integer | foo_2:integer +Connected to 10.1.0.1 | Success | 1 | 2 +Connected to 10.1.0.2 | Success | 1 | 2 +Connected to 10.1.0.3 | Success | 1 | 2 +Connection error | Error | 1 | 2 +Development environment | Development | 1 | 2 +Disconnected | Disconnected | 1 | 2 +Production environment | Production | 1 | 2 +; + +testMvExpandInconsistentColumnOrder2 +required_capability: fix_mv_expand_inconsistent_column_order +from message_types +| eval foo_1 = [1, 3], foo_2 = 2 +| sort message +| mv_expand foo_1 +; + +message:keyword | type:keyword | foo_1:integer | foo_2:integer +Connected to 10.1.0.1 | Success | 1 | 2 +Connected to 10.1.0.1 | Success | 3 | 2 +Connected to 10.1.0.2 | Success | 1 | 2 +Connected to 10.1.0.2 | Success | 3 | 2 +Connected to 10.1.0.3 | Success | 1 | 2 +Connected to 10.1.0.3 | Success | 3 | 2 +Connection error | Error | 1 | 2 +Connection error | Error | 3 | 2 +Development environment | Development | 1 | 2 +Development environment | Development | 3 | 2 +Disconnected | Disconnected | 1 | 2 +Disconnected | Disconnected | 3 | 2 +Production environment | Production | 1 | 2 +Production environment | Production | 3 | 2 +; + +testMvExpandInconsistentColumnOrder3 +required_capability: fix_mv_expand_inconsistent_column_order +from message_types +| sort type 
+| eval language_code = 1, `language_name` = false, message = true, foo_3 = 1, foo_2 = null +| eval foo_3 = "1", `foo_3` = -1, foo_1 = 1, `language_code` = null, `foo_2` = "1" +| mv_expand foo_1 +| limit 5 +; + +type:keyword | language_name:boolean | message:boolean | foo_3:integer | foo_1:integer | language_code:null | foo_2:keyword +Development | false | true | -1 | 1 | null | 1 +Disconnected | false | true | -1 | 1 | null | 1 +Error | false | true | -1 | 1 | null | 1 +Production | false | true | -1 | 1 | null | 1 +Success | false | true | -1 | 1 | null | 1 +; + diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/score-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/score-function.csv-spec new file mode 100644 index 0000000000000..1a39418e9a28d --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/score-function.csv-spec @@ -0,0 +1,127 @@ +############################################### +# Tests for Score function +# + +scoreSingle +required_capability: metadata_score +required_capability: score_function +required_capability: match_function + +// tag::score-function[] +FROM books METADATA _score +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL first_score = score(match(title, "Return")) +// end::score-function[] +| KEEP book_no, title, _score, first_score +| SORT book_no +| LIMIT 5 +; + +// tag::score-single-result[] +book_no:keyword | title:text | _score:double | first_score:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 3.1309072971343994 | 1.9245924949645996 +7350 | Return of the Shadow | 4.8434343338012695 | 3.5432329177856445 +// end::score-single-result[] +; + +scoreSingleNoMetadata +required_capability: score_function +required_capability: match_function + +FROM books +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL first_score = score(match(title, "Return")) +| KEEP book_no, title, first_score +| SORT book_no +| LIMIT 5 
+; + +book_no:keyword | title:text | first_score:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 1.9245924949645996 +7350 | Return of the Shadow | 3.5432329177856445 +; + +scoreAfterEval +required_capability: score_function +required_capability: metadata_score +required_capability: match_function + +FROM books METADATA _score +| EVAL stars = to_long(ratings / 2.0) +| EVAL s1 = score(match(author, "William")) +| WHERE match(author, "Faulkner") +| SORT book_no +| KEEP book_no, author, stars, s1 +| limit 5; + +book_no:keyword | author:text | stars:long | s1:double +2378 | [Carol Faulkner, Holly Byers Ochoa, Lucretia Mott] | 3 | 0.0 +2713 | William Faulkner | 2 | 1.9043500423431396 +2847 | Colleen Faulkner | 3 | 0.0 +2883 | William Faulkner | 2 | 1.9043500423431396 +3293 | Danny Faulkner | 2 | 0.0 +; + +scoreMatchWithFilterConjunction +required_capability: score_function +required_capability: match_function + +FROM books +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL s1 = score(match(title, "Rings") and ratings > 4.6) +| KEEP book_no, title, s1 +| SORT book_no +| LIMIT 5; + +book_no:keyword | title:text | s1:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 1.9245924949645996 +7350 | Return of the Shadow | 0.0 +; + +scoreMatchWithDisjunction +required_capability: score_function +required_capability: match_function + +FROM books +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL s1 = score(match(title, "Rings") or match(title, "Shadow")) +| KEEP book_no, title, s1 +| SORT book_no +| LIMIT 5; + +book_no:keyword | title:text | s1:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 1.9245924949645996 +7350 | Return of the Shadow | 3.5432329177856445 +; + +scoreMatchWithDisjunctionAndFilter +required_capability: score_function +required_capability: match_function + +FROM books +| WHERE match(title, "Return") AND match(author, "Tolkien") +| 
EVAL s1 = score(match(title, "Rings") or match(title, "Shadow") and ratings > 4.6) +| KEEP book_no, title, s1 +| SORT book_no +| LIMIT 5; + +book_no:keyword | title:text | s1:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 1.9245924949645996 +7350 | Return of the Shadow | 3.5432329177856445 +; + +scoreMatchDisjunctionNonPushable +required_capability: score_function +required_capability: match_function + +FROM books +| WHERE match(title, "Return") AND match(author, "Tolkien") +| EVAL s1 = score(match(title, "Rings") or ratings > 4.6) +| KEEP book_no, title, s1 +| SORT book_no +| LIMIT 5; + +book_no:keyword | title:text | s1:double +2714 | Return of the King Being the Third Part of The Lord of the Rings | 1.9245924949645996 +7350 | Return of the Shadow | 0.0 +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unmapped_fields.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unmapped_fields.csv-spec index a0828ff628a6d..c2d02d8d60d51 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unmapped_fields.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unmapped_fields.csv-spec @@ -495,8 +495,8 @@ required_capability: source_field_mapping required_capability: unmapped_fields FROM partial_mapping_sample_data,partial_mapping_excluded_source_sample_data METADATA _index | INSIST_🐔 message -| SORT message, @timestamp | STATS max(@timestamp), count(*) BY message +| SORT message NULLS FIRST ; max(@timestamp):date | count(*):long | message:keyword diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec new file mode 100644 index 0000000000000..d9e1ff408c739 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec @@ -0,0 +1,93 @@ + # Tests for cosine similarity function + + 
similarityWithVectorField + required_capability: cosine_vector_similarity_function + +// tag::vector-cosine-similarity[] + from colors + | where color != "black" + | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) + | sort similarity desc, color asc +// end::vector-cosine-similarity[] + | limit 10 + | keep color, similarity + ; + +// tag::vector-cosine-similarity-result[] +color:text | similarity:double +cyan | 1.0 +teal | 1.0 +turquoise | 0.9890533685684204 +aqua marine | 0.964962363243103 +azure | 0.916246771812439 +lavender | 0.9136701822280884 +mint cream | 0.9122757911682129 +honeydew | 0.9122424125671387 +gainsboro | 0.9082483053207397 +gray | 0.9082483053207397 +// end::vector-cosine-similarity-result[] +; + + similarityAsPartOfExpression + required_capability: cosine_vector_similarity_function + + from colors + | where color != "black" + | eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3) + | sort score desc, color asc + | limit 10 + | keep color, score + ; + +color:text | score:double +cyan | 1.5 +teal | 1.5 +turquoise | 1.495 +aqua marine | 1.482 +azure | 1.458 +lavender | 1.457 +honeydew | 1.456 +mint cream | 1.456 +gainsboro | 1.454 +gray | 1.454 +; + +similarityWithLiteralVectors +required_capability: cosine_vector_similarity_function + +row a = 1 +| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3) +| keep similarity +; + +similarity:double +0.978 +; + + similarityWithStats + required_capability: cosine_vector_similarity_function + + from colors + | where color != "black" + | eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3) + | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) + ; + +avg:double | min:double | max:double +0.832 | 0.5 | 1.0 +; + +# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector +similarityWithRow-Ignore +required_capability: cosine_vector_similarity_function + +row vector = [1, 2, 3] +| eval similarity = 
round(v_cosine(vector, [0, 1, 2]), 3) +| sort similarity desc, color asc +| limit 10 +| keep color, similarity +; + +similarity:double +0.978 +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec index 6c5f13603e72b..85d69ff60d3d6 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec @@ -534,6 +534,458 @@ emp_no:integer | first_name:keyword 10055 | Georgy ; +rlikeListEmptyArgWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +; + +rlikeListSingleArgWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name RLIKE ("Eberhardt.*") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +10013 | Eberhardt +; + +rlikeListTwoArgWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("Eberhardt.*", "testString.*") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +10013 | Eberhardt +; + +rlikeListDocExample +required_capability: rlike_with_list_of_patterns +// tag::rlikeListDocExample[] +ROW message = "foobar" +| WHERE message RLIKE ("foo.*", "bar.") +// end::rlikeListDocExample[] +; + +message:string +foobar +; + +rlikeListThreeArgWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("Eberhardt.*", "Ot.*", "Part.") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10003 | Parto +10013 | Eberhardt +10029 | Otmar +; + +rlikeListMultipleWhere +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name RLIKE ("Eberhardt.*", "Ot.*", "Part.") +| WHERE first_name RLIKE ("Eberhard.", "Otm.*") +| KEEP emp_no, first_name +| 
SORT emp_no; + +emp_no:integer | first_name:keyword +10013 | Eberhardt +10029 | Otmar +; + +rlikeListAllWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike (".*") +| KEEP emp_no, first_name +| SORT emp_no +| LIMIT 2; + +emp_no:integer | first_name:keyword +10001 | Georgi +10002 | Bezalel +; + +rlikeListOverlappingPatterns +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("Eber.*", "Eberhardt") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +10013 | Eberhardt +; + +rlikeListCaseSensitive +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name RLIKE ("eberhardt", "EBERHARDT") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +; + +rlikeListSpecialCharacters +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike (".*ar.*", ".*eor.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10001 | Georgi +10003 | Parto +10011 | Mary +10013 | Eberhardt +10029 | Otmar +10055 | Georgy +10058 | Berhard +10068 | Charlene +10069 | Margareta +10074 | Mokhtar +10082 | Parviz +10089 | Sudharsan +10095 | Hilari +; + +rlikeListEscapedWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("Eberhar\\*") +| KEEP emp_no, first_name; + +emp_no:integer | first_name:keyword +; + +rlikeListNineOrMoreLetters +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike (".{9,}.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10004 | Chirstian +10010 | Duangkaew +10013 | Eberhardt +10017 | Cristinel +10025 | Prasadram +10059 | Alejandro +10069 | Margareta +10089 | Sudharsan +10092 | Valdiodio +10098 | Sreekrishna +; + +notRlikeListThreeArgWildcardNotOtherFilter +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name not rlike 
("Eberhardt.*", "Ot.*", "Part.") and emp_no < 10010 +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10001 | Georgi +10002 | Bezalel +10004 | Chirstian +10005 | Kyoichi +10006 | Anneke +10007 | Tzvetan +10008 | Saniya +10009 | Sumant +; + +rlikeListBeginningWithWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name rlike ("A.*", "B.*", "C.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10002 | Bezalel +10004 | Chirstian +10006 | Anneke +10014 | Berni +10017 | Cristinel +10023 | Bojan +10049 | Basil +10056 | Brendon +10058 | Berhard +10059 | Alejandro +10060 | Breannda +10062 | Anoosh +10067 | Claudi +10068 | Charlene +10091 | Amabile +10094 | Arumugam +; + +notRlikeListThreeArgWildcardOtherFirst +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE emp_no < 10010 and first_name not rlike ("Eberhardt.*", "Ot.*", "Part.") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10001 | Georgi +10002 | Bezalel +10004 | Chirstian +10005 | Kyoichi +10006 | Anneke +10007 | Tzvetan +10008 | Saniya +10009 | Sumant +; + +notRlikeFiveOrLessLetters +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name not rlike (".{6,}.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10003 | Parto +10011 | Mary +10014 | Berni +10021 | Ramzi +10023 | Bojan +10029 | Otmar +10040 | Weiyi +10041 | Uri +10042 | Magy +10045 | Moss +10049 | Basil +10057 | Ebbe +10061 | Tse +10063 | Gino +10064 | Udi +10066 | Kwee +10071 | Hisao +10073 | Shir +10075 | Gao +10076 | Erez +10077 | Mona +10078 | Danel +10083 | Vishv +10084 | Tuval +10097 | Remzi +; + +notRlikeListMultipleWhere +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name NOT RLIKE ("Eberhardt.*", "Ot.*", "Part.", "A.*", "B.*", "C.*", "D.*") +| WHERE first_name NOT RLIKE ("Eberhard.", 
"Otm.*", "F.*", "G.*", "H.*", "I.*", "J.*", "K.*", "L.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10007 | Tzvetan +10008 | Saniya +10009 | Sumant +10011 | Mary +10012 | Patricio +10020 | Mayuko +10021 | Ramzi +10022 | Shahaf +10024 | Suzette +10025 | Prasadram +10026 | Yongqiao +10040 | Weiyi +10041 | Uri +10042 | Magy +10043 | Yishay +10044 | Mingsen +10045 | Moss +10047 | Zvonko +10050 | Yinghua +10053 | Sanjiv +10054 | Mayumi +10057 | Ebbe +10061 | Tse +10064 | Udi +10065 | Satosi +10069 | Margareta +10070 | Reuven +10073 | Shir +10074 | Mokhtar +10076 | Erez +10077 | Mona +10080 | Premal +10081 | Zhongwei +10082 | Parviz +10083 | Vishv +10084 | Tuval +10086 | Somnath +10087 | Xinglin +10089 | Sudharsan +10092 | Valdiodio +10093 | Sailaja +10097 | Remzi +10098 | Sreekrishna +10099 | Valter +; + +notRlikeListNotField +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE NOT first_name RLIKE ("Eberhardt.*", "Ot.*", "Part.", "A.*", "B.*", "C.*", "D.*") +| WHERE first_name NOT RLIKE ("Eberhard.", "Otm.*", "F.*", "G.*", "H.*", "I.*", "J.*", "K.*", "L.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10007 | Tzvetan +10008 | Saniya +10009 | Sumant +10011 | Mary +10012 | Patricio +10020 | Mayuko +10021 | Ramzi +10022 | Shahaf +10024 | Suzette +10025 | Prasadram +10026 | Yongqiao +10040 | Weiyi +10041 | Uri +10042 | Magy +10043 | Yishay +10044 | Mingsen +10045 | Moss +10047 | Zvonko +10050 | Yinghua +10053 | Sanjiv +10054 | Mayumi +10057 | Ebbe +10061 | Tse +10064 | Udi +10065 | Satosi +10069 | Margareta +10070 | Reuven +10073 | Shir +10074 | Mokhtar +10076 | Erez +10077 | Mona +10080 | Premal +10081 | Zhongwei +10082 | Parviz +10083 | Vishv +10084 | Tuval +10086 | Somnath +10087 | Xinglin +10089 | Sudharsan +10092 | Valdiodio +10093 | Sailaja +10097 | Remzi +10098 | Sreekrishna +10099 | Valter +; + +notRlikeListAllWildcard +required_capability: 
rlike_with_list_of_patterns +FROM employees +| WHERE first_name NOT RLIKE (".*") +| KEEP emp_no, first_name +| SORT emp_no +| LIMIT 2; + +emp_no:integer | first_name:keyword +10030 | null +10031 | null +; + +notRlikeListWildcard +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE first_name NOT RLIKE ("A.*","B.*", "C.*", "D.*","E.*", "F.*", "G.*", "H.*", "I.*", "J.*", "K.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10003 | Parto +10007 | Tzvetan +10008 | Saniya +10009 | Sumant +10011 | Mary +10012 | Patricio +10019 | Lillian +10020 | Mayuko +10021 | Ramzi +10022 | Shahaf +10024 | Suzette +10025 | Prasadram +10026 | Yongqiao +10029 | Otmar +10040 | Weiyi +10041 | Uri +10042 | Magy +10043 | Yishay +10044 | Mingsen +10045 | Moss +10046 | Lucien +10047 | Zvonko +10050 | Yinghua +10053 | Sanjiv +10054 | Mayumi +10061 | Tse +10064 | Udi +10065 | Satosi +10069 | Margareta +10070 | Reuven +10073 | Shir +10074 | Mokhtar +10077 | Mona +10080 | Premal +10081 | Zhongwei +10082 | Parviz +10083 | Vishv +10084 | Tuval +10086 | Somnath +10087 | Xinglin +10089 | Sudharsan +10092 | Valdiodio +10093 | Sailaja +10097 | Remzi +10098 | Sreekrishna +10099 | Valter +; + +rlikeListWithUpperTurnedInsensitive +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE TO_UPPER(first_name) RLIKE ("GEOR.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10001 | Georgi +10055 | Georgy +; + +rlikeListWithUpperTurnedInsensitiveMult +required_capability: rlike_with_list_of_patterns +FROM employees +| WHERE TO_UPPER(first_name) RLIKE ("GEOR.*", "WE.*") +| KEEP emp_no, first_name +| SORT emp_no; + +emp_no:integer | first_name:keyword +10001 | Georgi +10040 | Weiyi +10055 | Georgy +; + likeAll from employees | where first_name like "*" and emp_no > 10028 | sort emp_no | keep emp_no, first_name | limit 2; diff --git 
a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryStopIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryStopIT.java index 222ffb5c05b0d..a866b6047f0dc 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryStopIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryStopIT.java @@ -9,11 +9,15 @@ import org.elasticsearch.Build; import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.core.Tuple; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.async.AsyncStopRequest; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import java.util.Iterator; import java.util.List; @@ -125,21 +129,32 @@ public void testStopQuery() throws Exception { } public void testStopQueryLocal() throws Exception { + assumeTrue("Pragma does not work in release builds", Build.current().isSnapshot()); Map testClusterInfo = setupClusters(3); int remote1NumShards = (Integer) testClusterInfo.get("remote1.num_shards"); int remote2NumShards = (Integer) testClusterInfo.get("remote2.num_shards"); populateRuntimeIndex(LOCAL_CLUSTER, "pause", INDEX_WITH_BLOCKING_MAPPING); + // Gets random node client but ensure it's the same node for all operations + Client client = cluster(LOCAL_CLUSTER).client(); + Tuple includeCCSMetadata = randomIncludeCCSMetadata(); boolean responseExpectMeta = includeCCSMetadata.v2(); - + // By default, ES|QL uses all workers in the esql_worker threadpool to execute drivers on data nodes. 
+ // If a node is both data and coordinator, and all drivers are blocked by the allowEmitting latch, + // there are no workers left to execute the final driver or fetch pages from remote clusters. + // This can prevent remote clusters from being marked as successful on the coordinator, even if they + // have completed. To avoid this, we reserve at least one worker for the final driver and page fetching. + // A single worker is enough, as these two tasks can be paused and yielded. + var threadpool = cluster(LOCAL_CLUSTER).getInstance(TransportService.class).getThreadPool(); + int maxEsqlWorkers = threadpool.info(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME).getMax(); LOGGER.info("--> Launching async query"); - final String asyncExecutionId = startAsyncQuery( - client(), + final String asyncExecutionId = startAsyncQueryWithPragmas( + client, "FROM blocking,*:logs-* | STATS total=sum(coalesce(const,v)) | LIMIT 1", - includeCCSMetadata.v1() + includeCCSMetadata.v1(), + Map.of(QueryPragmas.TASK_CONCURRENCY.getKey(), between(1, maxEsqlWorkers - 1)) ); - try { // wait until we know that the local query against 'blocking' has started LOGGER.info("--> Waiting for {} to start", asyncExecutionId); @@ -147,9 +162,9 @@ public void testStopQueryLocal() throws Exception { // wait until the remotes are done LOGGER.info("--> Waiting for remotes", asyncExecutionId); - waitForCluster(client(), REMOTE_CLUSTER_1, asyncExecutionId); + waitForCluster(client, REMOTE_CLUSTER_1, asyncExecutionId); LOGGER.info("--> Remote 1 done", asyncExecutionId); - waitForCluster(client(), REMOTE_CLUSTER_2, asyncExecutionId); + waitForCluster(client, REMOTE_CLUSTER_2, asyncExecutionId); LOGGER.info("--> Remote 2 done", asyncExecutionId); /* at this point: @@ -159,10 +174,10 @@ public void testStopQueryLocal() throws Exception { // run the stop query AsyncStopRequest stopRequest = new AsyncStopRequest(asyncExecutionId); LOGGER.info("Launching stop for {}", asyncExecutionId); - ActionFuture stopAction = 
client().execute(EsqlAsyncStopAction.INSTANCE, stopRequest); + ActionFuture stopAction = client.execute(EsqlAsyncStopAction.INSTANCE, stopRequest); // ensure stop operation is running assertBusy(() -> { - try (EsqlQueryResponse asyncResponse = getAsyncResponse(client(), asyncExecutionId)) { + try (EsqlQueryResponse asyncResponse = getAsyncResponse(client, asyncExecutionId)) { EsqlExecutionInfo executionInfo = asyncResponse.getExecutionInfo(); LOGGER.info("--> Waiting for stop operation to start, current status: {}", executionInfo); assertNotNull(executionInfo); @@ -206,7 +221,7 @@ public void testStopQueryLocal() throws Exception { } } finally { SimplePauseFieldPlugin.allowEmitting.countDown(); - assertAcked(deleteAsyncId(client(), asyncExecutionId)); + assertAcked(deleteAsyncId(client, asyncExecutionId)); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryWithPartialResultsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryWithPartialResultsIT.java index 2c6b92655ba75..f16e9f448fa40 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryWithPartialResultsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryWithPartialResultsIT.java @@ -9,6 +9,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.set.Sets; @@ -37,6 +38,7 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.in; import static 
org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -58,12 +60,15 @@ private static class ClusterSetup { void populateIndices() throws Exception { local.okIds = populateIndex(LOCAL_CLUSTER, "ok-local", local.okShards, between(1, 100)); populateIndexWithFailingFields(LOCAL_CLUSTER, "fail-local", local.failingShards); + createUnavailableIndex(LOCAL_CLUSTER, "unavailable-local"); remote1.okIds = populateIndex(REMOTE_CLUSTER_1, "ok-cluster1", remote1.okShards, between(1, 100)); populateIndexWithFailingFields(REMOTE_CLUSTER_1, "fail-cluster1", remote1.failingShards); + createUnavailableIndex(REMOTE_CLUSTER_1, "unavailable-cluster1"); remote2.okIds = populateIndex(REMOTE_CLUSTER_2, "ok-cluster2", remote2.okShards, between(1, 100)); populateIndexWithFailingFields(REMOTE_CLUSTER_2, "fail-cluster2", remote2.failingShards); + createUnavailableIndex(REMOTE_CLUSTER_2, "unavailable-cluster2"); } private void assertClusterPartial(EsqlQueryResponse resp, String clusterAlias, ClusterSetup cluster) { @@ -342,6 +347,42 @@ public void testFailSearchShardsOnLocalCluster() throws Exception { } } + public void testResolutionFailures() throws Exception { + populateIndices(); + EsqlQueryRequest request = new EsqlQueryRequest(); + request.allowPartialResults(true); + request.query("FROM ok*,unavailable* | LIMIT 1000"); + try (var resp = runQuery(request)) { + assertThat(EsqlTestUtils.getValuesList(resp), hasSize(local.okIds.size())); + assertTrue(resp.isPartial()); + EsqlExecutionInfo executionInfo = resp.getExecutionInfo(); + var localCluster = executionInfo.getCluster(LOCAL_CLUSTER); + assertThat(localCluster.getFailures(), not(empty())); + assertThat(localCluster.getFailures().get(0).reason(), containsString("index [unavailable-local] has no active shard copy")); + } + request.query("FROM *:ok*,unavailable* | LIMIT 1000"); + try (var resp = runQuery(request)) { + assertThat(EsqlTestUtils.getValuesList(resp), hasSize(remote1.okIds.size() + remote2.okIds.size())); 
+ assertTrue(resp.isPartial()); + var executionInfo = resp.getExecutionInfo(); + var localCluster = executionInfo.getCluster(LOCAL_CLUSTER); + assertThat(localCluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SKIPPED)); + assertThat(localCluster.getFailures(), not(empty())); + assertThat(localCluster.getFailures().get(0).reason(), containsString("index [unavailable-local] has no active shard copy")); + assertThat(executionInfo.getCluster(REMOTE_CLUSTER_1).getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); + assertThat(executionInfo.getCluster(REMOTE_CLUSTER_2).getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); + } + request.query("FROM ok*,cluster-a:unavailable* | LIMIT 1000"); + try (var resp = runQuery(request)) { + assertThat(EsqlTestUtils.getValuesList(resp), hasSize(local.okIds.size())); + assertTrue(resp.isPartial()); + var remote1 = resp.getExecutionInfo().getCluster(REMOTE_CLUSTER_1); + assertThat(remote1.getFailures(), not(empty())); + assertThat(remote1.getFailures().get(0).reason(), containsString("index [unavailable-cluster1] has no active shard copy")); + assertThat(resp.getExecutionInfo().getCluster(LOCAL_CLUSTER).getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL)); + } + } + private Set populateIndexWithFailingFields(String clusterAlias, String indexName, int numShards) throws IOException { Client client = client(clusterAlias); XContentBuilder mapping = JsonXContent.contentBuilder().startObject(); @@ -384,4 +425,15 @@ private Set populateIndexWithFailingFields(String clusterAlias, String i } return ids; } + + private void createUnavailableIndex(String clusterAlias, String indexName) throws IOException { + Client client = client(clusterAlias); + assertAcked( + client.admin() + .indices() + .prepareCreate(indexName) + .setSettings(Settings.builder().put("index.routing.allocation.include._name", "no_such_node")) + .setWaitForActiveShards(ActiveShardCount.NONE) + ); + } } diff --git 
a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index efefde8871546..5d3586b689832 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -597,8 +597,8 @@ public void testTaskContentsForGroupingStatsQuery() throws Exception { equalTo( """ \\_LuceneSourceOperator[sourceStatus] - \\_ValuesSourceReaderOperator[fields = [foo]] - \\_OrdinalsGroupingOperator(aggs = max of longs) + \\_ValuesSourceReaderOperator[fields = [pause_me, foo]] + \\_HashAggregationOperator[mode = , aggs = max of longs] \\_ExchangeSinkOperator""".replace("sourceStatus", sourceStatus) ) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlNodeFailureIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlNodeFailureIT.java index 30b05f741ec82..7da333e12f7e6 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlNodeFailureIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlNodeFailureIT.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import java.util.ArrayList; import java.util.Collection; @@ -30,6 +31,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.lessThanOrEqualTo; 
import static org.hamcrest.Matchers.not; @@ -98,38 +100,30 @@ public void testFailureLoadingFields() throws Exception { public void testPartialResults() throws Exception { Set okIds = populateIndices(); - { - EsqlQueryRequest request = new EsqlQueryRequest(); - request.query("FROM fail,ok | LIMIT 100"); - request.allowPartialResults(true); - request.pragmas(randomPragmas()); - try (EsqlQueryResponse resp = run(request)) { - assertTrue(resp.isPartial()); - List> rows = EsqlTestUtils.getValuesList(resp); - assertThat(rows.size(), lessThanOrEqualTo(okIds.size())); - } - } - { - EsqlQueryRequest request = new EsqlQueryRequest(); - request.query("FROM fail,ok METADATA _id | KEEP _id, fail_me | LIMIT 100"); - request.allowPartialResults(true); - request.pragmas(randomPragmas()); - try (EsqlQueryResponse resp = run(request)) { - assertTrue(resp.isPartial()); - List> rows = EsqlTestUtils.getValuesList(resp); - assertThat(rows.size(), lessThanOrEqualTo(okIds.size())); - Set actualIds = new HashSet<>(); - for (List row : rows) { - assertThat(row.size(), equalTo(2)); - String id = (String) row.getFirst(); - assertThat(id, in(okIds)); - assertTrue(actualIds.add(id)); - } - EsqlExecutionInfo.Cluster localInfo = resp.getExecutionInfo().getCluster(RemoteClusterService.LOCAL_CLUSTER_GROUP_KEY); - assertThat(localInfo.getFailures(), not(empty())); - assertThat(localInfo.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.PARTIAL)); - assertThat(localInfo.getFailures().get(0).reason(), containsString("Accessing failing field")); + EsqlQueryRequest request = new EsqlQueryRequest(); + request.query("FROM fail,ok METADATA _id | KEEP _id, fail_me | LIMIT 100"); + request.allowPartialResults(true); + // have to run one shard at a time to avoid failing all shards + QueryPragmas pragma = new QueryPragmas( + Settings.builder().put(randomPragmas().getSettings()).put(QueryPragmas.MAX_CONCURRENT_SHARDS_PER_NODE.getKey(), 1).build() + ); + request.pragmas(pragma); + 
request.acceptedPragmaRisks(true); + try (EsqlQueryResponse resp = run(request)) { + assertTrue(resp.isPartial()); + List> rows = EsqlTestUtils.getValuesList(resp); + assertThat(rows.size(), equalTo(okIds.size())); + Set actualIds = new HashSet<>(); + for (List row : rows) { + assertThat(row.size(), equalTo(2)); + String id = (String) row.getFirst(); + assertThat(id, in(okIds)); + assertTrue(actualIds.add(id)); } + EsqlExecutionInfo.Cluster localInfo = resp.getExecutionInfo().getCluster(RemoteClusterService.LOCAL_CLUSTER_GROUP_KEY); + assertThat(localInfo.getFailures(), not(empty())); + assertThat(localInfo.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.PARTIAL)); + assertThat(localInfo.getFailures().get(0).reason(), containsString("Accessing failing field")); } } @@ -147,6 +141,15 @@ public void testDefaultPartialResults() throws Exception { EsqlQueryRequest request = new EsqlQueryRequest(); request.query("FROM fail,ok | LIMIT 100"); request.pragmas(randomPragmas()); + // have to run one shard at a time to avoid failing all shards + QueryPragmas pragma = new QueryPragmas( + Settings.builder() + .put(randomPragmas().getSettings()) + .put(QueryPragmas.MAX_CONCURRENT_SHARDS_PER_NODE.getKey(), 1) + .build() + ); + request.pragmas(pragma); + request.acceptedPragmaRisks(true); if (randomBoolean()) { request.allowPartialResults(true); } @@ -154,6 +157,7 @@ public void testDefaultPartialResults() throws Exception { assertTrue(resp.isPartial()); List> rows = EsqlTestUtils.getValuesList(resp); assertThat(rows.size(), lessThanOrEqualTo(okIds.size())); + assertThat(rows.size(), greaterThan(0)); } } // allow_partial_results = false diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRemoteErrorWrapIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRemoteErrorWrapIT.java index bc4d5d35ea71c..f6a9836929f35 100644 --- 
a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRemoteErrorWrapIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRemoteErrorWrapIT.java @@ -34,10 +34,9 @@ public void testThatRemoteErrorsAreWrapped() throws Exception { ); } - RemoteException wrappedError = expectThrows( - RemoteException.class, - () -> runQuery("FROM " + REMOTE_CLUSTER_1 + ":*," + REMOTE_CLUSTER_2 + ":* | LIMIT 100", false) - ); + RemoteException wrappedError = expectThrows(RemoteException.class, () -> { + try (EsqlQueryResponse ignored = runQuery("FROM " + REMOTE_CLUSTER_1 + ":*," + REMOTE_CLUSTER_2 + ":* | LIMIT 100", false)) {} + }); assertThat(wrappedError.getMessage(), is("Remote [cluster-a] encountered an error")); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java index 1d63a2bcf5373..e25cb82f29851 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java @@ -60,6 +60,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexOperator; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; @@ -198,6 +199,7 @@ private void runLookup(DataType keyType, PopulateIndices populateIndices) throws false // no scoring ); ValuesSourceReaderOperator.Factory reader = new ValuesSourceReaderOperator.Factory( + 
PhysicalSettings.VALUES_LOADING_JUMBO_SIZE.getDefault(Settings.EMPTY), List.of( new ValuesSourceReaderOperator.FieldInfo( "key", diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java index cfc8979360578..f7833b917b746 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java @@ -9,17 +9,15 @@ import org.elasticsearch.Build; import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Rounding; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator; import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.type.DataType; import org.junit.Before; -import java.time.ZoneOffset; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -59,7 +57,7 @@ public void testEmpty() { run("TS empty_index | LIMIT 1").close(); } - record Doc(String host, String cluster, long timestamp, int requestCount, double cpu) {} + record Doc(String host, String cluster, long timestamp, int requestCount, double cpu, ByteSizeValue memory) {} final List docs = new ArrayList<>(); @@ -87,7 +85,6 @@ static Double computeRate(List values) { @Before public void populateIndex() { - // this can be expensive, do one Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); client().admin() .indices() @@ -102,6 +99,8 @@ public void populateIndex() { 
"type=keyword,time_series_dimension=true", "cpu", "type=double,time_series_metric=gauge", + "memory", + "type=long,time_series_metric=gauge", "request_count", "type=integer,time_series_metric=counter" ) @@ -126,7 +125,8 @@ public void populateIndex() { } }); int cpu = randomIntBetween(0, 100); - docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, cpu)); + ByteSizeValue memory = ByteSizeValue.ofBytes(randomIntBetween(1024, 1024 * 1024)); + docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, cpu, memory)); } } Randomness.shuffle(docs); @@ -141,6 +141,8 @@ public void populateIndex() { doc.cluster, "cpu", doc.cpu, + "memory", + doc.memory.getBytes(), "request_count", doc.requestCount ) @@ -321,319 +323,6 @@ record RateKey(String cluster, String host) { } } - @AwaitsFix(bugUrl = "removed?") - public void testRateWithTimeBucket() { - var rounding = new Rounding.Builder(TimeValue.timeValueSeconds(60)).timeZone(ZoneOffset.UTC).build().prepareForUnknown(); - record RateKey(String host, String cluster, long interval) {} - Map> groups = new HashMap<>(); - for (Doc doc : docs) { - RateKey key = new RateKey(doc.host, doc.cluster, rounding.round(doc.timestamp)); - groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); - } - Map> bucketToRates = new HashMap<>(); - for (Map.Entry> e : groups.entrySet()) { - List values = bucketToRates.computeIfAbsent(e.getKey().interval, k -> new ArrayList<>()); - Double rate = computeRate(e.getValue()); - if (rate != null) { - values.add(rate); - } - } - List sortedKeys = bucketToRates.keySet().stream().sorted().limit(5).toList(); - try (var resp = run("TS hosts | STATS sum(rate(request_count)) BY ts=bucket(@timestamp, 1 minute) | SORT ts | LIMIT 5")) { - assertThat( - resp.columns(), - equalTo(List.of(new ColumnInfoImpl("sum(rate(request_count))", "double", null), new ColumnInfoImpl("ts", "date", null))) - ); - List> values = 
EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(2)); - long key = sortedKeys.get(i); - assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); - List bucketValues = bucketToRates.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - } else { - assertThat((double) row.get(0), closeTo(bucketValues.stream().mapToDouble(d -> d).sum(), 0.1)); - } - } - } - try (var resp = run("TS hosts | STATS avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute) | SORT ts | LIMIT 5")) { - assertThat( - resp.columns(), - equalTo(List.of(new ColumnInfoImpl("avg(rate(request_count))", "double", null), new ColumnInfoImpl("ts", "date", null))) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(2)); - long key = sortedKeys.get(i); - assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); - List bucketValues = bucketToRates.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - } else { - double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); - assertThat((double) row.get(0), closeTo(avg, 0.1)); - } - } - } - try (var resp = run(""" - TS hosts - | STATS avg(rate(request_count)), avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute) - | SORT ts - | LIMIT 5 - """)) { - assertThat( - resp.columns(), - equalTo( - List.of( - new ColumnInfoImpl("avg(rate(request_count))", "double", null), - new ColumnInfoImpl("avg(rate(request_count))", "double", null), - new ColumnInfoImpl("ts", "date", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - 
assertThat(row, hasSize(3)); - long key = sortedKeys.get(i); - assertThat(row.get(2), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); - List bucketValues = bucketToRates.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - assertNull(row.get(1)); - } else { - double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); - assertThat((double) row.get(0), closeTo(avg, 0.1)); - assertThat((double) row.get(1), closeTo(avg, 0.1)); - } - } - } - } - - @AwaitsFix(bugUrl = "removed?") - public void testRateWithTimeBucketAndCluster() { - var rounding = new Rounding.Builder(TimeValue.timeValueSeconds(60)).timeZone(ZoneOffset.UTC).build().prepareForUnknown(); - record RateKey(String host, String cluster, long interval) {} - Map> groups = new HashMap<>(); - for (Doc doc : docs) { - RateKey key = new RateKey(doc.host, doc.cluster, rounding.round(doc.timestamp)); - groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); - } - record GroupKey(String cluster, long interval) {} - Map> rateBuckets = new HashMap<>(); - for (Map.Entry> e : groups.entrySet()) { - RateKey key = e.getKey(); - List values = rateBuckets.computeIfAbsent(new GroupKey(key.cluster, key.interval), k -> new ArrayList<>()); - Double rate = computeRate(e.getValue()); - if (rate != null) { - values.add(rate); - } - } - Map> cpuBuckets = new HashMap<>(); - for (Doc doc : docs) { - GroupKey key = new GroupKey(doc.cluster, rounding.round(doc.timestamp)); - cpuBuckets.computeIfAbsent(key, k -> new ArrayList<>()).add(doc.cpu); - } - List sortedKeys = rateBuckets.keySet() - .stream() - .sorted(Comparator.comparing(GroupKey::interval).thenComparing(GroupKey::cluster)) - .limit(5) - .toList(); - try (var resp = run(""" - TS hosts - | STATS sum(rate(request_count)) BY ts=bucket(@timestamp, 1 minute), cluster - | SORT ts, cluster - | LIMIT 5""")) { - assertThat( - resp.columns(), - equalTo( - List.of( - new 
ColumnInfoImpl("sum(rate(request_count))", "double", null), - new ColumnInfoImpl("ts", "date", null), - new ColumnInfoImpl("cluster", "keyword", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(3)); - var key = sortedKeys.get(i); - assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); - assertThat(row.get(2), equalTo(key.cluster)); - List bucketValues = rateBuckets.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - } else { - assertThat((double) row.get(0), closeTo(bucketValues.stream().mapToDouble(d -> d).sum(), 0.1)); - } - } - } - try (var resp = run(""" - TS hosts - | STATS avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute), cluster - | SORT ts, cluster - | LIMIT 5""")) { - assertThat( - resp.columns(), - equalTo( - List.of( - new ColumnInfoImpl("avg(rate(request_count))", "double", null), - new ColumnInfoImpl("ts", "date", null), - new ColumnInfoImpl("cluster", "keyword", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(3)); - var key = sortedKeys.get(i); - assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); - assertThat(row.get(2), equalTo(key.cluster)); - List bucketValues = rateBuckets.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - } else { - double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); - assertThat((double) row.get(0), closeTo(avg, 0.1)); - } - } - } - - try (var resp = run(""" - TS hosts - | STATS - s = sum(rate(request_count)), - c = count(rate(request_count)), - max(rate(request_count)), - avg(rate(request_count)) - BY ts=bucket(@timestamp, 1minute), 
cluster - | SORT ts, cluster - | LIMIT 5 - | EVAL avg_rate= s/c - | KEEP avg_rate, `max(rate(request_count))`, `avg(rate(request_count))`, ts, cluster - """)) { - assertThat( - resp.columns(), - equalTo( - List.of( - new ColumnInfoImpl("avg_rate", "double", null), - new ColumnInfoImpl("max(rate(request_count))", "double", null), - new ColumnInfoImpl("avg(rate(request_count))", "double", null), - new ColumnInfoImpl("ts", "date", null), - new ColumnInfoImpl("cluster", "keyword", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(5)); - var key = sortedKeys.get(i); - assertThat(row.get(3), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); - assertThat(row.get(4), equalTo(key.cluster)); - List bucketValues = rateBuckets.get(key); - if (bucketValues.isEmpty()) { - assertNull(row.get(0)); - assertNull(row.get(1)); - } else { - double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); - assertThat((double) row.get(0), closeTo(avg, 0.1)); - double max = bucketValues.stream().mapToDouble(d -> d).max().orElse(0.0); - assertThat((double) row.get(1), closeTo(max, 0.1)); - } - assertEquals(row.get(0), row.get(2)); - } - } - try (var resp = run(""" - TS hosts - | STATS sum(rate(request_count)), max(cpu) BY ts=bucket(@timestamp, 1 minute), cluster - | SORT ts, cluster - | LIMIT 5""")) { - assertThat( - resp.columns(), - equalTo( - List.of( - new ColumnInfoImpl("sum(rate(request_count))", "double", null), - new ColumnInfoImpl("max(cpu)", "double", null), - new ColumnInfoImpl("ts", "date", null), - new ColumnInfoImpl("cluster", "keyword", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(4)); - var key 
= sortedKeys.get(i); - assertThat(row.get(2), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); - assertThat(row.get(3), equalTo(key.cluster)); - List rateBucket = rateBuckets.get(key); - if (rateBucket.isEmpty()) { - assertNull(row.get(0)); - } else { - assertThat((double) row.get(0), closeTo(rateBucket.stream().mapToDouble(d -> d).sum(), 0.1)); - } - List cpuBucket = cpuBuckets.get(key); - if (cpuBuckets.isEmpty()) { - assertNull(row.get(1)); - } else { - assertThat((double) row.get(1), closeTo(cpuBucket.stream().mapToDouble(d -> d).max().orElse(0.0), 0.1)); - } - } - } - try (var resp = run(""" - TS hosts - | STATS sum(rate(request_count)), avg(cpu) BY ts=bucket(@timestamp, 1 minute), cluster - | SORT ts, cluster - | LIMIT 5""")) { - assertThat( - resp.columns(), - equalTo( - List.of( - new ColumnInfoImpl("sum(rate(request_count))", "double", null), - new ColumnInfoImpl("avg(cpu)", "double", null), - new ColumnInfoImpl("ts", "date", null), - new ColumnInfoImpl("cluster", "keyword", null) - ) - ) - ); - List> values = EsqlTestUtils.getValuesList(resp); - assertThat(values, hasSize(sortedKeys.size())); - for (int i = 0; i < sortedKeys.size(); i++) { - List row = values.get(i); - assertThat(row, hasSize(4)); - var key = sortedKeys.get(i); - assertThat(row.get(2), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); - assertThat(row.get(3), equalTo(key.cluster)); - List rateBucket = rateBuckets.get(key); - if (rateBucket.isEmpty()) { - assertNull(row.get(0)); - } else { - assertThat((double) row.get(0), closeTo(rateBucket.stream().mapToDouble(d -> d).sum(), 0.1)); - } - List cpuBucket = cpuBuckets.get(key); - if (cpuBuckets.isEmpty()) { - assertNull(row.get(1)); - } else { - double avg = cpuBucket.stream().mapToDouble(d -> d).sum() / cpuBucket.size(); - assertThat((double) row.get(1), closeTo(avg, 0.1)); - } - } - } - } - public void testApplyRateBeforeFinalGrouping() { record RateKey(String cluster, String host) { @@ -733,6 +422,63 @@ 
public void testIndexMode() { assertThat(failure.getMessage(), containsString("Unknown index [hosts-old]")); } + public void testFieldDoesNotExist() { + // the old-hosts index doesn't have the cpu field + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); + client().admin() + .indices() + .prepareCreate("old-hosts") + .setSettings(settings) + .setMapping( + "@timestamp", + "type=date", + "host", + "type=keyword,time_series_dimension=true", + "cluster", + "type=keyword,time_series_dimension=true", + "memory", + "type=long,time_series_metric=gauge", + "request_count", + "type=integer,time_series_metric=counter" + ) + .get(); + Randomness.shuffle(docs); + for (Doc doc : docs) { + client().prepareIndex("old-hosts") + .setSource( + "@timestamp", + doc.timestamp, + "host", + doc.host, + "cluster", + doc.cluster, + "memory", + doc.memory.getBytes(), + "request_count", + doc.requestCount + ) + .get(); + } + client().admin().indices().prepareRefresh("old-hosts").get(); + try (var resp1 = run(""" + TS hosts,old-hosts + | STATS sum(rate(request_count)), max(last_over_time(cpu)), max(last_over_time(memory)) BY cluster, host + | SORT cluster, host + | DROP `sum(rate(request_count))` + """)) { + try (var resp2 = run(""" + TS hosts + | STATS sum(rate(request_count)), max(last_over_time(cpu)), max(last_over_time(memory)) BY cluster, host + | SORT cluster, host + | DROP `sum(rate(request_count))` + """)) { + List> values1 = EsqlTestUtils.getValuesList(resp1); + List> values2 = EsqlTestUtils.getValuesList(resp2); + assertThat(values1, equalTo(values2)); + } + } + } + public void testProfile() { EsqlQueryRequest request = new EsqlQueryRequest(); request.profile(true); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesRateIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesRateIT.java new file mode 
100644 index 0000000000000..92d8c55d66a72 --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesRateIT.java @@ -0,0 +1,477 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.action; + +import org.elasticsearch.Build; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.index.mapper.DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class TimeSeriesRateIT extends AbstractEsqlIntegTestCase { + + @Override + public EsqlQueryResponse run(EsqlQueryRequest request) { + assumeTrue("time series available in snapshot builds only", Build.current().isSnapshot()); + return super.run(request); + } + + record Doc(String host, String cluster, long timestamp, int requestCount, double cpu) {} + + final List docs = new ArrayList<>(); + + final Map hostToClusters = new HashMap<>(); + final Map hostToRate = new HashMap<>(); + final Map hostToCpu = new HashMap<>(); + + static final float DEVIATION_LIMIT = 0.30f; + // extra deviation tolerance for subgroups due to fewer samples + // at 0.35 deviation limit, we see 2/8000 failures. I am expanding to 0.4 + static final float SUBGROUP_DEVIATION_LIMIT = 0.45f; + // We expect a drop in the rate due to not covering window edges and not triggering + // extrapolation logic in the time series engine. 
+ static final float EXPECTED_DROP_RATE = 0.15f; + static final int LIMIT = 5; + static final int MAX_HOSTS = 5; + static final int PCT_CHANCE_OF_RESET = 15; // 15% chance of resetting the request count + + @Before + public void populateIndex() { + // this can be expensive, do one + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); + client().admin() + .indices() + .prepareCreate("hosts") + .setSettings(settings) + .setMapping( + "@timestamp", + "type=date", + "host", + "type=keyword,time_series_dimension=true", + "cluster", + "type=keyword,time_series_dimension=true", + "cpu", + "type=double,time_series_metric=gauge", + "request_count", + "type=integer,time_series_metric=counter" + ) + .get(); + final Map requestCounts = new HashMap<>(); + for (int i = 0; i < MAX_HOSTS; i++) { + hostToClusters.put("p" + i, randomFrom("qa", "prod")); + hostToRate.put("p" + i, randomIntBetween(10, 50)); + requestCounts.put("p" + i, randomIntBetween(0, 1000)); + hostToCpu.put("p" + i, randomIntBetween(0, 100)); + } + long timestamp = DEFAULT_DATE_TIME_FORMATTER.parseMillis("2024-04-15T00:00:00Z"); + int numDocs = between(100, 300); + docs.clear(); + // We want docs to span a 6-minute period, so we need to adapt their spacing accordingly. + var avgSamplingPeriod = 360.0 / numDocs; // 6 minutes divided by number of docs - then randomized below + + for (int i = 0; i < numDocs; i++) { + final var tsChange = randomDoubleBetween(avgSamplingPeriod - 1.0, avgSamplingPeriod + 1.0, true); + timestamp += Math.round(tsChange * 1000); + // We want a subset of hosts to have docs within a give time point. 
+ var hosts = Set.copyOf(randomSubsetOf(between(2, hostToClusters.size()), hostToClusters.keySet())); + for (String host : hostToClusters.keySet()) { + var requestCount = requestCounts.compute(host, (k, curr) -> { + if (randomInt(100) <= PCT_CHANCE_OF_RESET) { + return Math.toIntExact(Math.round(hostToRate.get(k) * tsChange)); + } else { + return Math.toIntExact(Math.round((curr == null ? 0 : curr) + hostToRate.get(k) * tsChange)); + } + }); + if (hosts.contains(host)) { + docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, hostToCpu.get(host))); + } + } + } + Randomness.shuffle(docs); + for (Doc doc : docs) { + client().prepareIndex("hosts") + .setSource( + "@timestamp", + doc.timestamp, + "host", + doc.host, + "cluster", + doc.cluster, + "cpu", + doc.cpu, + "request_count", + doc.requestCount + ) + .get(); + } + client().admin().indices().prepareRefresh("hosts").get(); + + } + + private String hostTable() { + StringBuilder sb = new StringBuilder(); + for (String host : hostToClusters.keySet()) { + var docsForHost = docs.stream().filter(d -> d.host().equals(host)).toList(); + sb.append(host) + .append(" -> ") + .append(hostToClusters.get(host)) + .append(", rate=") + .append(hostToRate.get(host)) + .append(", cpu=") + .append(hostToCpu.get(host)) + .append(", numDocs=") + .append(docsForHost.size()) + .append("\n"); + } + // Now we add total rate and total CPU used: + sb.append("Total rate: ").append(hostToRate.values().stream().mapToInt(a -> a).sum()).append("\n"); + sb.append("Average rate: ").append(hostToRate.values().stream().mapToInt(a -> a).average().orElseThrow()).append("\n"); + sb.append("Total CPU: ").append(hostToCpu.values().stream().mapToInt(a -> a).sum()).append("\n"); + sb.append("Average CPU: ").append(hostToCpu.values().stream().mapToInt(a -> a).average().orElseThrow()).append("\n"); + // Add global info + sb.append("Count of docs: ").append(docs.size()).append("\n"); + // Add docs per minute + sb.append("Docs in each 
minute:\n"); + Map docsPerMinute = new HashMap<>(); + for (Doc doc : docs) { + long minute = (doc.timestamp / 60000) % 1000; // convert to minutes + docsPerMinute.merge(minute, 1, Integer::sum); + } + for (Map.Entry entry : docsPerMinute.entrySet()) { + sb.append("Minute ").append(entry.getKey()).append(": ").append(entry.getValue()).append(" docs\n"); + } + return sb.toString(); + } + + private String valuesTable(List> values) { + StringBuilder sb = new StringBuilder(); + for (List row : values) { + sb.append(row).append("\n"); + } + return sb.toString(); + } + + public void testRateWithTimeBucketSumByMin() { + try ( + var resp = run( + "TS hosts | STATS sum(rate(request_count)) BY tbucket=bucket(@timestamp, 1 minute) | SORT tbucket | LIMIT " + LIMIT + ) + ) { + List> values = EsqlTestUtils.getValuesList(resp); + try { + assertThat( + resp.columns(), + equalTo( + List.of(new ColumnInfoImpl("sum(rate(request_count))", "double", null), new ColumnInfoImpl("tbucket", "date", null)) + ) + ); + assertThat(values, hasSize(LIMIT)); + for (int i = 0; i < values.size(); i++) { + List row = values.get(i); + assertThat(row, hasSize(2)); + var totalRate = hostToRate.values().stream().mapToDouble(a -> a + 0.0).sum(); + assertThat((double) row.get(0), closeTo(totalRate * (1 - EXPECTED_DROP_RATE), DEVIATION_LIMIT * totalRate)); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(values) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAvgByMin() { + try (var resp = run("TS hosts | STATS avg(rate(request_count)) BY tbucket=bucket(@timestamp, 1minute) | SORT tbucket | LIMIT 5")) { + try { + assertThat( + resp.columns(), + equalTo( + List.of(new ColumnInfoImpl("avg(rate(request_count))", "double", null), new ColumnInfoImpl("tbucket", "date", null)) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(LIMIT)); + for (int i = 0; i < values.size(); i++) { + List row = 
values.get(i); + assertThat(row, hasSize(2)); + var expectedRate = hostToRate.values().stream().mapToDouble(a -> a + 0.0).sum() / hostToRate.size(); + assertThat((double) row.get(0), closeTo(expectedRate * (1 - EXPECTED_DROP_RATE), DEVIATION_LIMIT * expectedRate)); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketSumByMinAndLimitAsParam() { + try (var resp = run(""" + TS hosts + | STATS avg(rate(request_count)) BY tbucket=bucket(@timestamp, 1minute) + | SORT tbucket + | LIMIT""" + " " + LIMIT)) { + try { + assertThat( + resp.columns(), + equalTo( + List.of(new ColumnInfoImpl("avg(rate(request_count))", "double", null), new ColumnInfoImpl("tbucket", "date", null)) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(LIMIT)); + for (int i = 0; i < LIMIT; i++) { + List row = values.get(i); + assertThat(row, hasSize(2)); + double expectedAvg = hostToRate.values().stream().mapToDouble(d -> d).sum() / hostToRate.size(); + assertThat((double) row.get(0), closeTo(expectedAvg * (1 - EXPECTED_DROP_RATE), DEVIATION_LIMIT * expectedAvg)); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAndClusterSumByMin() { + try (var resp = run(""" + TS hosts + | STATS sum(rate(request_count)) BY tbucket=bucket(@timestamp, 1 minute), cluster + | SORT tbucket, cluster + | LIMIT 5""")) { + try { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("sum(rate(request_count))", "double", null), + new ColumnInfoImpl("tbucket", "date", null), + new ColumnInfoImpl("cluster", "keyword", null) + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + // we have 2 clusters, so we expect 2 * limit rows + for (List row : 
values) { + assertThat(row, hasSize(3)); + String cluster = (String) row.get(2); + var expectedRate = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .sum(); + assertThat( + (double) row.get(0), + closeTo(expectedRate * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedRate) + ); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAndClusterAvgByMin() { + try (var resp = run(""" + TS hosts + | STATS avg(rate(request_count)) BY tbucket=bucket(@timestamp, 1minute), cluster + | SORT tbucket, cluster + | LIMIT 5""")) { + try { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("avg(rate(request_count))", "double", null), + new ColumnInfoImpl("tbucket", "date", null), + new ColumnInfoImpl("cluster", "keyword", null) + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + for (List row : values) { + assertThat(row, hasSize(3)); + String cluster = (String) row.get(2); + var expectedAvg = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .average() + .orElseThrow(); + assertThat( + (double) row.get(0), + closeTo(expectedAvg * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedAvg) + ); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAndClusterMultipleStatsByMin() { + try (var resp = run(""" + TS hosts + | STATS + s = sum(rate(request_count)), + c = count(rate(request_count)), + max(rate(request_count)), + avg(rate(request_count)) + BY tbucket=bucket(@timestamp, 1minute), cluster + | SORT tbucket, cluster + | LIMIT 5 + | EVAL 
avg_rate= s/c + | KEEP avg_rate, `max(rate(request_count))`, `avg(rate(request_count))`, tbucket, cluster + """)) { + try { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("avg_rate", "double", null), + new ColumnInfoImpl("max(rate(request_count))", "double", null), + new ColumnInfoImpl("avg(rate(request_count))", "double", null), + new ColumnInfoImpl("tbucket", "date", null), + new ColumnInfoImpl("cluster", "keyword", null) + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + for (List row : values) { + assertThat(row, hasSize(5)); + String cluster = (String) row.get(4); + var expectedAvg = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .average() + .orElseThrow(); + var expectedMax = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .max() + .orElseThrow(); + assertThat( + (double) row.get(0), + closeTo(expectedAvg * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedAvg) + ); + assertThat( + (double) row.get(2), + closeTo(expectedAvg * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedAvg) + ); + assertThat( + (double) row.get(1), + closeTo(expectedMax * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedMax) + ); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAndClusterMultipleMetricsByMin() { + try (var resp = run(""" + TS hosts + | STATS sum(rate(request_count)), max(cpu) BY tbucket=bucket(@timestamp, 1 minute), cluster + | SORT tbucket, cluster + | LIMIT 5""")) { + try { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("sum(rate(request_count))", "double", null), + new ColumnInfoImpl("max(cpu)", "double", null), + new 
ColumnInfoImpl("tbucket", "date", null), + new ColumnInfoImpl("cluster", "keyword", null) + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + for (List row : values) { + assertThat(row, hasSize(4)); + String cluster = (String) row.get(3); + var expectedRate = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .sum(); + assertThat( + (double) row.get(0), + closeTo(expectedRate * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedRate) + ); + var expectedCpu = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToCpu.get(e.getKey()) + 0.0) + .max() + .orElseThrow(); + assertThat( + (double) row.get(1), + closeTo(expectedCpu * (1 - EXPECTED_DROP_RATE), SUBGROUP_DEVIATION_LIMIT * expectedCpu) + ); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } + + public void testRateWithTimeBucketAndClusterMultipleMetricsAvgByMin() { + try (var resp = run(""" + TS hosts + | STATS sum(rate(request_count)), avg(cpu) BY tbucket=bucket(@timestamp, 1 minute), cluster + | SORT tbucket, cluster + | LIMIT 5""")) { + try { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfoImpl("sum(rate(request_count))", "double", null), + new ColumnInfoImpl("avg(cpu)", "double", null), + new ColumnInfoImpl("tbucket", "date", null), + new ColumnInfoImpl("cluster", "keyword", null) + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + for (List row : values) { + assertThat(row, hasSize(4)); + String cluster = (String) row.get(3); + var expectedRate = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToRate.get(e.getKey()) + 0.0) + .sum(); + assertThat( + (double) row.get(0), + closeTo(expectedRate * (1 - EXPECTED_DROP_RATE), 
SUBGROUP_DEVIATION_LIMIT * expectedRate) + ); + var expectedCpu = hostToClusters.entrySet() + .stream() + .filter(e -> e.getValue().equals(cluster)) + .mapToDouble(e -> hostToCpu.get(e.getKey()) + 0.0) + .average() + .orElseThrow(); + assertThat((double) row.get(1), closeTo(expectedCpu, SUBGROUP_DEVIATION_LIMIT * expectedCpu)); + } + } catch (AssertionError e) { + throw new AssertionError("Values:\n" + valuesTable(EsqlTestUtils.getValuesList(resp)) + "\n Hosts:\n" + hostTable(), e); + } + } + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/CanMatchIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/CanMatchIT.java index e1ef6730c1f05..541e6a1421946 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/CanMatchIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/CanMatchIT.java @@ -19,8 +19,10 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; import java.util.Collection; @@ -363,6 +365,10 @@ public void testFailOnUnavailableShards() throws Exception { syncEsqlQueryRequest().query("from events,logs | KEEP timestamp,message").allowPartialResults(true) ) ) { + assertTrue(resp.isPartial()); + EsqlExecutionInfo.Cluster local = resp.getExecutionInfo().getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); + assertThat(local.getFailures(), hasSize(1)); + assertThat(local.getFailures().get(0).reason(), containsString("index [logs] has no active shard copy")); assertThat(getValuesList(resp), 
hasSize(3)); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java index 9b3c2278c1bb8..c409f3c480ca3 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.compute.operator.exchange.ExchangeService; @@ -131,12 +132,8 @@ public void testRetryOnShardMovement() { (handler, request, channel, task) -> { // move index shard if (shouldMove.compareAndSet(true, false)) { - var currentShardNodeId = clusterService().state() - .routingTable() - .index("index-1") - .shard(0) - .primaryShard() - .currentNodeId(); + var shardRouting = clusterService().state().routingTable(ProjectId.DEFAULT).shardRoutingTable("index-1", 0); + var currentShardNodeId = shardRouting.primaryShard().currentNodeId(); assertAcked( client().admin() .indices() diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 61795addb1e79..9ae1c980337f1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -114,6 +114,29 @@ public void 
testKnnNonPushedDown() { } } + public void testKnnWithPrefilters() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, 5) AND id > 5 + | KEEP id, floats, _score, vector + | SORT _score DESC + | LIMIT 5 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + // K = 5, 1 more for every id > 10 + assertEquals(5, valuesList.size()); + } + } + public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; Arrays.fill(queryVector, 1.0f); @@ -136,7 +159,7 @@ public void testKnnWithLookupJoin() { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); var indexName = "test"; var client = client().admin().indices(); @@ -163,7 +186,7 @@ public void setup() throws IOException { var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(10, 20); + numDocs = randomIntBetween(15, 25); numDims = randomIntBetween(3, 10); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; float value = 0.0f; @@ -202,6 +225,5 @@ private void createAndPopulateLookupIndex(IndicesAdminClient client, String look var createRequest = client.prepareCreate(lookupIndexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - } } diff --git 
a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/ScoreFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/ScoreFunctionIT.java new file mode 100644 index 0000000000000..6c60c0334eddd --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/ScoreFunctionIT.java @@ -0,0 +1,494 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.xpack.esql.VerificationException; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.kql.KqlPlugin; +import org.junit.Before; + +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.CoreMatchers.containsString; + +//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug") +public class ScoreFunctionIT extends AbstractEsqlIntegTestCase { + + @Before + public void setupIndex() { + assumeTrue("can run this only when score() function is enabled", EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled()); + createAndPopulateIndex(); + } + + public void testScoreSingleNoMetadata() { + var query = """ + FROM test + | WHERE match(content, "fox") AND match(content, "brown") + | EVAL first_score = score(match(content, 
"fox")) + | KEEP id, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues(resp.values(), List.of(List.of(1, 1.156558871269226), List.of(6, 0.9114001989364624))); + } + } + + public void testScoreWithLimit() { + var query = """ + FROM test + | WHERE match(content, "fox") AND match(content, "brown") + | EVAL first_score = score(match(content, "fox")) + | KEEP id, first_score + | SORT id + | LIMIT 1 + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues(resp.values(), List.of(List.of(1, 1.156558871269226))); + } + } + + public void testScoreAfterLimit() { + var query = """ + FROM test + | WHERE match(content, "fox") AND match(content, "brown") + | LIMIT 1 + | EVAL first_score = score(match(content, "fox")) + | KEEP id, first_score + | SORT id + """; + + var error = expectThrows(VerificationException.class, () -> run(query)); + assertThat(error.getMessage(), containsString("[SCORE] function cannot be used after LIMIT")); + } + + public void testScoreQueryExpressions() { + var query = """ + FROM test METADATA _score + | WHERE match(content, "fox") AND match(content, "brown") + | EVAL first_score = score(match(content, CONCAT("brown ", " fox"))) + | KEEP id, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues(resp.values(), List.of(List.of(1, 1.4274532794952393), List.of(6, 1.1248724460601807))); + } + } + + public void testDisjunctionWithFiltersNoMetadata() { + var query = """ + FROM test + | EVAL first_score = score((match(content, "fox") OR match(content, "brown")) AND id > 1) + | WHERE match(content, "fox") AND 
match(content, "brown") + | KEEP id, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues(resp.values(), List.of(List.of(1, 1.4274532496929169), List.of(6, 1.1248724162578583))); + } + } + + public void testScoreDifferentWhereMatch() { + var query = """ + FROM test METADATA _score + | EVAL first_score = score(match(content, "brown")) + | WHERE match(content, "fox") + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.156558871269226, 0.2708943784236908), List.of(6, 0.9114001989364624, 0.21347221732139587)) + ); + } + } + + public void testScoreDifferentWhereMatchNoMetadata() { + var query = """ + FROM test + | EVAL first_score = score(match(content, "brown")) + | WHERE match(content, "fox") + | KEEP id, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues(resp.values(), List.of(List.of(1, 0.2708943784236908), List.of(6, 0.21347221732139587))); + } + } + + public void testScoreInWhereWithMatch() { + var query = """ + FROM test + | WHERE score(match(content, "brown")) + """; + + var error = expectThrows(VerificationException.class, () -> run(query)); + assertThat(error.getMessage(), containsString("[SCORE] function can't be used in WHERE")); + } + + public void testScoreInWhereWithFilter() { + var query = """ + FROM test + | WHERE score(id > 0) + """; + + var error = expectThrows(VerificationException.class, () -> run(query)); + assertThat(error.getMessage(), containsString("Condition expression needs to be 
boolean, found [DOUBLE]")); + } + + public void testScoreNonFullTextFunction() { + var query = """ + FROM test + | EVAL meaningless = score(abs(-0.1)) + | KEEP id, meaningless + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "meaningless")); + assertColumnTypes(resp.columns(), List.of("integer", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 0.0), List.of(2, 0.0), List.of(3, 0.0), List.of(4, 0.0), List.of(5, 0.0), List.of(6, 0.0)) + ); + } + } + + public void testScoreMultipleWhereMatch() { + var query = """ + FROM test METADATA _score + | WHERE match(content, "brown") + | WHERE match(content, "fox") + | EVAL first_score = score(match(content, "brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 0.2708943784236908), List.of(6, 1.1248724460601807, 0.21347221732139587)) + ); + } + } + + public void testScoreMultipleWhereKqlMatch() { + var query = """ + FROM test METADATA _score + | WHERE kql("brown") + | WHERE match(content, "fox") + | EVAL first_score = score(kql("brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 0.2708943784236908), List.of(6, 1.1248724460601807, 0.21347221732139587)) + ); + } + } + + public void testScoreMultipleWhereQstrMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") + | WHERE match(content, "fox") + | EVAL first_score = score(qstr("brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp 
= run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 0.2708943784236908), List.of(6, 1.1248724460601807, 0.21347221732139587)) + ); + } + } + + public void testScoreSameWhereQstrAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") AND match(content, "fox") + | EVAL first_score = score(qstr("brown") AND match(content, "fox")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 1.4274532496929169), List.of(6, 1.1248724460601807, 1.1248724162578583)) + ); + } + } + + public void testScoreSingleWhereQstrAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") AND match(content, "fox") + | EVAL first_score = score(qstr("brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 0.2708943784236908), List.of(6, 1.1248724460601807, 0.21347221732139587)) + ); + } + } + + public void testScoreBothWhereQstrAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") AND match(content, "fox") + | EVAL first_score = score(qstr("brown")) + | EVAL second_score = score(match(content, "fox")) + | KEEP id, _score, first_score, second_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score", "second_score")); + 
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 1.4274532794952393, 0.2708943784236908, 1.156558871269226), + List.of(6, 1.1248724460601807, 0.21347221732139587, 0.9114001989364624) + ) + ); + } + } + + public void testScoreSameWhereKqlAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE kql("brown") AND match(content, "fox") + | EVAL first_score = score(kql("brown") AND match(content, "fox")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 1.4274532496929169), List.of(6, 1.1248724460601807, 1.1248724162578583)) + ); + } + } + + public void testScoreSingleWhereKqlAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE kql("brown") AND match(content, "fox") + | EVAL first_score = score(kql("brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of(List.of(1, 1.4274532794952393, 0.2708943784236908), List.of(6, 1.1248724460601807, 0.21347221732139587)) + ); + } + } + + public void testScoreBothWhereKqlAndMatch() { + var query = """ + FROM test METADATA _score + | WHERE kql("brown") AND match(content, "fox") + | EVAL first_score = score(kql("brown")) + | EVAL second_score = score(match(content, "fox")) + | KEEP id, _score, first_score, second_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score", "second_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", 
"double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 1.4274532794952393, 0.2708943784236908, 1.156558871269226), + List.of(6, 1.1248724460601807, 0.21347221732139587, 0.9114001989364624) + ) + ); + } + } + + public void testScoreSameWhereQstrORMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") OR match(content, "fox") + | EVAL first_score = score(qstr("brown") OR match(content, "fox")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 1.4274532794952393, 1.4274532496929169), + List.of(2, 0.2708943784236908, 0.2708943784236908), + List.of(3, 0.2708943784236908, 0.2708943784236908), + List.of(4, 0.19301524758338928, 0.19301524758338928), + List.of(6, 1.1248724460601807, 1.1248724162578583) + ) + ); + } + } + + public void testScoreSingleWhereQstrORMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") OR match(content, "fox") + | EVAL first_score = score(qstr("brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 1.4274532794952393, 0.2708943784236908), + List.of(2, 0.2708943784236908, 0.2708943784236908), + List.of(3, 0.2708943784236908, 0.2708943784236908), + List.of(4, 0.19301524758338928, 0.19301524758338928), + List.of(6, 1.1248724460601807, 0.21347221732139587) + ) + ); + } + } + + public void testScoreBothWhereQstrORMatch() { + var query = """ + FROM test METADATA _score + | WHERE qstr("brown") OR match(content, "fox") + | EVAL first_score = score(qstr("brown")) + | EVAL second_score = 
score(match(content, "fox")) + | KEEP id, _score, first_score, second_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score", "second_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 1.4274532794952393, 0.2708943784236908, 1.156558871269226), + List.of(2, 0.2708943784236908, 0.2708943784236908, 0.0), + List.of(3, 0.2708943784236908, 0.2708943784236908, 0.0), + List.of(4, 0.19301524758338928, 0.19301524758338928, 0.0), + List.of(6, 1.1248724460601807, 0.21347221732139587, 0.9114001989364624) + ) + ); + } + } + + public void testSimpleScoreAlone() { + var query = """ + FROM test METADATA _score + | EVAL first_score = score(match(content, "brown")) + | KEEP id, _score, first_score + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "_score", "first_score")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double")); + assertValues( + resp.values(), + List.of( + List.of(1, 0.0, 0.2708943784236908), + List.of(2, 0.0, 0.2708943784236908), + List.of(3, 0.0, 0.2708943784236908), + List.of(4, 0.0, 0.19301524758338928), + List.of(5, 0.0, 0.0), + List.of(6, 0.0, 0.21347221732139587) + ) + ); + } + } + + private void createAndPopulateIndex() { + var indexName = "test"; + var client = client().admin().indices(); + var CreateRequest = client.prepareCreate(indexName) + .setSettings(Settings.builder().put("index.number_of_shards", 1)) + .setMapping("id", "type=integer", "content", "type=text"); + assertAcked(CreateRequest); + client().prepareBulk() + .add(new IndexRequest(indexName).id("1").source("id", 1, "content", "This is a brown fox")) + .add(new IndexRequest(indexName).id("2").source("id", 2, "content", "This is a brown dog")) + .add(new IndexRequest(indexName).id("3").source("id", 3, "content", "This dog is really brown")) + 
.add(new IndexRequest(indexName).id("4").source("id", 4, "content", "The dog is brown but this document is very very long")) + .add(new IndexRequest(indexName).id("5").source("id", 5, "content", "There is also a white cat")) + .add(new IndexRequest(indexName).id("6").source("id", 6, "content", "The quick brown fox jumps over the lazy dog")) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + ensureYellow(indexName); + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), KqlPlugin.class); + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java new file mode 100644 index 0000000000000..6a861746facfd --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.esql.EsqlClientException; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; + +public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase { + + @ParametersFactory + public static Iterable parameters() throws Exception { + List params = new ArrayList<>(); + + params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE }); + + return params; + } + + private final String functionName; + private final VectorSimilarityFunction similarityFunction; + private int numDims; + + public VectorSimilarityFunctionsIT( + @Name("functionName") String functionName, + @Name("similarityFunction") VectorSimilarityFunction similarityFunction + ) { + this.functionName = functionName; + this.similarityFunction = similarityFunction; + } + + @SuppressWarnings("unchecked") + public void testSimilarityBetweenVectors() { + var query = String.format(Locale.ROOT, """ + FROM test + | EVAL similarity = %s(left_vector, right_vector) + | KEEP left_vector, right_vector, similarity + """, functionName); + + try (var resp = run(query)) { + List> valuesList 
= EsqlTestUtils.getValuesList(resp); + valuesList.forEach(values -> { + float[] left = readVector((List) values.get(0)); + float[] right = readVector((List) values.get(1)); + Double similarity = (Double) values.get(2); + + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.compare(left, right); + assertEquals(expectedSimilarity, similarity, 0.0001); + }); + } + } + + @SuppressWarnings("unchecked") + public void testSimilarityBetweenConstantVectorAndField() { + var randomVector = randomVectorArray(); + var query = String.format(Locale.ROOT, """ + FROM test + | EVAL similarity = %s(left_vector, %s) + | KEEP left_vector, similarity + """, functionName, Arrays.toString(randomVector)); + + try (var resp = run(query)) { + List> valuesList = EsqlTestUtils.getValuesList(resp); + valuesList.forEach(values -> { + float[] left = readVector((List) values.get(0)); + Double similarity = (Double) values.get(1); + + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.compare(left, randomVector); + assertEquals(expectedSimilarity, similarity, 0.0001); + }); + } + } + + public void testDifferentDimensions() { + var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2)); + var query = String.format(Locale.ROOT, """ + FROM test + | EVAL similarity = %s(left_vector, %s) + | KEEP left_vector, similarity + """, functionName, Arrays.toString(randomVector)); + + EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); }); + assertTrue(iae.getMessage().contains("Vectors must have the same dimensions")); + } + + @SuppressWarnings("unchecked") + public void testSimilarityBetweenConstantVectors() { + var vectorLeft = randomVectorArray(); + var vectorRight = randomVectorArray(); + var query = String.format(Locale.ROOT, """ + ROW a = 1 + | EVAL similarity = %s(%s, %s) + | KEEP similarity + """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight)); + + try 
(var resp = run(query)) { + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(1, valuesList.size()); + + Double similarity = (Double) valuesList.get(0).get(0); + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight); + assertEquals(expectedSimilarity, similarity, 0.0001); + } + } + + private static float[] readVector(List leftVector) { + float[] leftScratch = new float[leftVector.size()]; + for (int i = 0; i < leftVector.size(); i++) { + leftScratch[i] = leftVector.get(i); + } + return leftScratch; + } + + @Before + public void setup() throws IOException { + assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); + + createIndexWithDenseVector("test"); + + numDims = randomIntBetween(32, 64) * 2; // min 64, even number + int numDocs = randomIntBetween(10, 100); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + for (int i = 0; i < numDocs; i++) { + List leftVector = randomVector(); + List rightVector = randomVector(); + docs[i] = prepareIndex("test").setId("" + i) + .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector); + } + + indexRandom(true, docs); + } + + private List randomVector() { + assert numDims != 0 : "numDims must be set before calling randomVector()"; + List vector = new ArrayList<>(numDims); + for (int j = 0; j < numDims; j++) { + vector.add(randomFloat()); + } + return vector; + } + + private float[] randomVectorArray() { + assert numDims != 0 : "numDims must be set before calling randomVectorArray()"; + return randomVectorArray(numDims); + } + + private static float[] randomVectorArray(int dimensions) { + float[] vector = new float[dimensions]; + for (int j = 0; j < dimensions; j++) { + vector[j] = randomFloat(); + } + return vector; + } + + private void createIndexWithDenseVector(String indexName) throws IOException { + var client = client().admin().indices(); + XContentBuilder 
mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("id") + .field("type", "integer") + .endObject(); + createDenseVectorField(mapping, "left_vector"); + createDenseVectorField(mapping, "right_vector"); + mapping.endObject().endObject(); + Settings.Builder settingsBuilder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)); + + var CreateRequest = client.prepareCreate(indexName) + .setSettings(Settings.builder().put("index.number_of_shards", 1)) + .setMapping(mapping) + .setSettings(settingsBuilder.build()); + assertAcked(CreateRequest); + } + + private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException { + mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine"); + mapping.endObject(); + } +} diff --git a/x-pack/plugin/esql/src/main/antlr/parser/Expression.g4 b/x-pack/plugin/esql/src/main/antlr/parser/Expression.g4 index abb8fe09164f5..0462b2d6a67ee 100644 --- a/x-pack/plugin/esql/src/main/antlr/parser/Expression.g4 +++ b/x-pack/plugin/esql/src/main/antlr/parser/Expression.g4 @@ -21,6 +21,7 @@ regexBooleanExpression : valueExpression (NOT)? LIKE string #likeExpression | valueExpression (NOT)? RLIKE string #rlikeExpression | valueExpression (NOT)? LIKE LP string (COMMA string )* RP #likeListExpression + | valueExpression (NOT)? 
RLIKE LP string (COMMA string )* RP #rlikeListExpression ; matchBooleanExpression diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 7ee953b9d1d9a..733f0cabb2e22 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -890,6 +890,31 @@ public enum Cap { */ AGGREGATE_METRIC_DOUBLE_PARTIAL_SUBMETRICS(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + /** + * Support for rendering aggregate_metric_double type + */ + AGGREGATE_METRIC_DOUBLE_RENDERING(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + + /** + * Support for to_aggregate_metric_double function + */ + AGGREGATE_METRIC_DOUBLE_CONVERT_TO(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + + /** + * Support for sorting when aggregate_metric_doubles are present + */ + AGGREGATE_METRIC_DOUBLE_SORTING(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + + /** + * Support avg with aggregate metric doubles + */ + AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + + /** + * Support for implicit casting of aggregate metric double when run in aggregations + */ + AGGREGATE_METRIC_DOUBLE_IMPLICIT_CASTING_IN_AGGS(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + /** * Support change point detection "CHANGE_POINT". */ @@ -913,11 +938,6 @@ public enum Cap { */ SUPPORT_PARTIAL_RESULTS, - /** - * Support for rendering aggregate_metric_double type - */ - AGGREGATE_METRIC_DOUBLE_RENDERING(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), - /** * Support for RERANK command */ @@ -964,11 +984,6 @@ public enum Cap { */ NON_FULL_TEXT_FUNCTIONS_SCORING, - /** - * Support for to_aggregate_metric_double function - */ - AGGREGATE_METRIC_DOUBLE_CONVERT_TO(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), - /** * The {@code _query} API now reports the original types. 
*/ @@ -995,11 +1010,6 @@ public enum Cap { */ MAKE_NUMBER_OF_CHANNELS_CONSISTENT_WITH_LAYOUT, - /** - * Support for sorting when aggregate_metric_doubles are present - */ - AGGREGATE_METRIC_DOUBLE_SORTING(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), - /** * Supercedes {@link Cap#MAKE_NUMBER_OF_CHANNELS_CONSISTENT_WITH_LAYOUT}. */ @@ -1077,6 +1087,11 @@ public enum Cap { */ LAST_OVER_TIME(Build.current().isSnapshot()), + /** + * score function + */ + SCORE_FUNCTION(Build.current().isSnapshot()), + /** * Support for the SAMPLE command */ @@ -1170,6 +1185,11 @@ public enum Cap { */ PARAMETER_FOR_LIMIT, + /** + * Changed and normalized the LIMIT error message. + */ + NORMALIZED_LIMIT_ERROR_MESSAGE, + /** * Dense vector field type support */ @@ -1203,8 +1223,11 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION_V2(Build.current().isSnapshot()), + KNN_FUNCTION_V3(Build.current().isSnapshot()), + /** + * Support for the LIKE operator with a list of wildcards. + */ LIKE_WITH_LIST_OF_PATTERNS, LIKE_LIST_ON_INDEX_FIELDS, @@ -1222,29 +1245,46 @@ public enum Cap { */ NO_PLAIN_STRINGS_IN_LITERALS, + /** + * Support for the mv_expand target attribute should be retained in its original position. + * see ES|QL: inconsistent column order #129000 + */ + FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER, + /** * (Re)Added EXPLAIN command */ EXPLAIN(Build.current().isSnapshot()), + /** + * Support for the RLIKE operator with a list of regexes. + */ + RLIKE_WITH_LIST_OF_PATTERNS, /** * FUSE command */ FUSE(Build.current().isSnapshot()), + /** * Support improved behavior for LIKE operator when used with index fields. 
*/ LIKE_ON_INDEX_FIELDS, - /** - * Support avg with aggregate metric doubles - */ - AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), /** * Forbid usage of brackets in unquoted index and enrich policy names * https://github.com/elastic/elasticsearch/issues/130378 */ - NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES; + NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES, + + /** + * Cosine vector similarity function + */ + COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()), + + /** + * Support for the options field of CATEGORIZE. + */ + CATEGORIZE_OPTIONS; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java index 55c36aa1cf353..61d0d3b0e1026 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java @@ -28,6 +28,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.EnumMap; import java.util.Iterator; @@ -562,8 +563,14 @@ public Cluster.Builder setFailedShards(int failedShards) { return this; } - public Cluster.Builder setFailures(List failures) { - this.failures = failures; + public Cluster.Builder addFailures(List failures) { + if (failures.isEmpty()) { + return this; + } + if (this.failures == null) { + this.failures = new ArrayList<>(original.failures); + } + this.failures.addAll(failures); return this; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index e4b8949af5bdb..44e4fd5a1bb3c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.logging.HeaderWarning; import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder; import org.elasticsearch.compute.data.Block; import org.elasticsearch.core.Strings; import org.elasticsearch.index.IndexMode; @@ -53,6 +54,17 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; +import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg; +import org.elasticsearch.xpack.esql.expression.function.aggregate.AvgOverTime; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.CountOverTime; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; +import org.elasticsearch.xpack.esql.expression.function.aggregate.MaxOverTime; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; +import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; @@ -61,6 +73,8 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import 
org.elasticsearch.xpack.esql.expression.function.scalar.convert.ConvertFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FoldablesConvertFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; @@ -135,6 +149,7 @@ import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.IMPLICIT_CASTING_DATE_AND_DATE_NANOS; +import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; @@ -182,7 +197,8 @@ public class Analyzer extends ParameterizedRuleExecutor("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new AddImplicitForkLimit(), new UnionTypesCleanup()) ); @@ -1400,15 +1416,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func if (f instanceof In in) { return processIn(in); } + if (f instanceof VectorFunction) { + return processVectorFunction(f); + } if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed return processScalarOrGroupingFunction(f, registry); } if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { return processBinaryOperator((BinaryOperator) f); } - if (f instanceof VectorFunction vectorFunction) { - return processVectorFunction(f); - } 
return f; } @@ -1613,6 +1629,7 @@ private static Expression castStringLiteral(Expression from, DataType target) { } } + @SuppressWarnings("unchecked") private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) { List args = vectorFunction.arguments(); List newArgs = new ArrayList<>(); @@ -1620,7 +1637,14 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { Object folded = arg.fold(FoldContext.small() /* TODO remove me */); if (folded instanceof List) { - Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR); + // Convert to floats so blocks are created accordingly + List floatVector; + if (arg.dataType() == FLOAT) { + floatVector = (List) folded; + } else { + floatVector = ((List) folded).stream().map(Number::floatValue).collect(Collectors.toList()); + } + Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR); newArgs.add(denseVector); continue; } @@ -1680,9 +1704,15 @@ private LogicalPlan doRule(LogicalPlan plan) { return plan; } - // And add generated fields to EsRelation, so these new attributes will appear in the OutputExec of the Fragment - // and thereby get used in FieldExtractExec - plan = plan.transformDown(EsRelation.class, esr -> { + return addGeneratedFieldsToEsRelations(plan, unionFieldAttributes); + } + + /** + * Add generated fields to EsRelation, so these new attributes will appear in the OutputExec of the Fragment + * and thereby get used in FieldExtractExec + */ + private static LogicalPlan addGeneratedFieldsToEsRelations(LogicalPlan plan, List unionFieldAttributes) { + return plan.transformDown(EsRelation.class, esr -> { List missing = new ArrayList<>(); for (FieldAttribute fa : unionFieldAttributes) { // Using outputSet().contains looks by NameId, resp. uses semanticEquals. 
@@ -1702,7 +1732,6 @@ private LogicalPlan doRule(LogicalPlan plan) { } return esr; }); - return plan; } private Expression resolveConvertFunction(ConvertFunction convert, List unionFieldAttributes) { @@ -1830,7 +1859,10 @@ private static Expression typeSpecificConvert(ConvertFunction convert, Source so originalFieldAttr.id(), true ); - Expression e = ((Expression) convert).replaceChildren(Collections.singletonList(resolvedAttr)); + Expression fn = (Expression) convert; + List children = new ArrayList<>(fn.children()); + children.set(0, resolvedAttr); + Expression e = ((Expression) convert).replaceChildren(children); /* * Resolve surrogates immediately because these type specific conversions are serialized * and SurrogateExpressions are expected to be resolved on the coordinating node. At least, @@ -1949,4 +1981,103 @@ private static void typeResolutions( var concreteConvert = ResolveUnionTypes.typeSpecificConvert(convert, fieldAttribute.source(), type, imf); typeResolutions.put(key, concreteConvert); } + + /** + * Take InvalidMappedFields in specific aggregations (min, max, sum, count, and avg) and if all original data types + * are aggregate metric double + any combination of numerics, implicitly cast them to the same type: aggregate metric + * double for count, and double for min, max, and sum. Avg gets replaced with its surrogate (Div(Sum, Count)) + */ + private static class ImplicitCastAggregateMetricDoubles extends Rule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + return plan.transformUp(Aggregate.class, p -> p.childrenResolved() == false ? 
p : doRule(p)); + } + + private LogicalPlan doRule(Aggregate plan) { + Map unionFields = new HashMap<>(); + Holder aborted = new Holder<>(Boolean.FALSE); + var newPlan = plan.transformExpressionsOnly(AggregateFunction.class, aggFunc -> { + if (aggFunc.field() instanceof FieldAttribute fa && fa.field() instanceof InvalidMappedField mtf) { + if (mtf.types().contains(AGGREGATE_METRIC_DOUBLE) == false + || mtf.types().stream().allMatch(f -> f == AGGREGATE_METRIC_DOUBLE || f.isNumeric()) == false) { + aborted.set(Boolean.TRUE); + return aggFunc; + } + Map typeConverters = typeConverters(aggFunc, fa, mtf); + if (typeConverters == null) { + aborted.set(Boolean.TRUE); + return aggFunc; + } + var newField = unionFields.computeIfAbsent( + Attribute.rawTemporaryName(fa.name(), aggFunc.functionName(), aggFunc.sourceText()), + newName -> new FieldAttribute( + fa.source(), + fa.parentName(), + newName, + MultiTypeEsField.resolveFrom(mtf, typeConverters), + fa.nullable(), + null, + true + ) + ); + List children = new ArrayList<>(aggFunc.children()); + children.set(0, newField); + return aggFunc.replaceChildren(children); + } + return aggFunc; + }); + if (unionFields.isEmpty() || aborted.get()) { + return plan; + } + return ResolveUnionTypes.addGeneratedFieldsToEsRelations(newPlan, unionFields.values().stream().toList()); + } + + private Map typeConverters(AggregateFunction aggFunc, FieldAttribute fa, InvalidMappedField mtf) { + var metric = getMetric(aggFunc); + if (metric == null) { + return null; + } + Map typeConverter = new HashMap<>(); + for (DataType type : mtf.types()) { + final ConvertFunction convert; + // Counting on aggregate metric double has unique behavior in that we cannot just provide the number of + // documents, instead we have to look inside the aggregate metric double's count field and sum those together. 
+ // Grabbing the count value with FromAggregateMetricDouble the same way we do with min/max/sum would result in + // a single Int field, and incorrectly be treated as 1 document (instead of however many originally went into + // the aggregate metric double). + if (metric == AggregateMetricDoubleBlockBuilder.Metric.COUNT) { + convert = new ToAggregateMetricDouble(fa.source(), fa); + } else if (type == AGGREGATE_METRIC_DOUBLE) { + convert = FromAggregateMetricDouble.withMetric(aggFunc.source(), fa, metric); + } else if (type.isNumeric()) { + convert = new ToDouble(fa.source(), fa); + } else { + return null; + } + Expression expression = ResolveUnionTypes.typeSpecificConvert(convert, fa.source(), type, mtf); + typeConverter.put(type.typeName(), expression); + } + return typeConverter; + } + + private static AggregateMetricDoubleBlockBuilder.Metric getMetric(AggregateFunction aggFunc) { + if (aggFunc instanceof Max || aggFunc instanceof MaxOverTime) { + return AggregateMetricDoubleBlockBuilder.Metric.MAX; + } + if (aggFunc instanceof Min || aggFunc instanceof MinOverTime) { + return AggregateMetricDoubleBlockBuilder.Metric.MIN; + } + if (aggFunc instanceof Sum || aggFunc instanceof SumOverTime) { + return AggregateMetricDoubleBlockBuilder.Metric.SUM; + } + if (aggFunc instanceof Count || aggFunc instanceof CountOverTime) { + return AggregateMetricDoubleBlockBuilder.Metric.COUNT; + } + if (aggFunc instanceof Avg || aggFunc instanceof AvgOverTime) { + return AggregateMetricDoubleBlockBuilder.Metric.COUNT; + } + return null; + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java index 6d5630b0e6581..dd305f09c12dc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java @@ -450,6 +450,7 @@ private static Operator extractFieldsOperator( } return new ValuesSourceReaderOperator( driverContext.blockFactory(), + Long.MAX_VALUE, fields, List.of( new ValuesSourceReaderOperator.ShardContext( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index a3f6d3a089d49..311f666581279 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; @@ -82,10 +81,11 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Space; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; -import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables; import 
org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; @@ -182,6 +182,7 @@ public static List unaryScalars() { entries.add(Neg.ENTRY); entries.add(Not.ENTRY); entries.add(RLike.ENTRY); + entries.add(RLikeList.ENTRY); entries.add(RTrim.ENTRY); entries.add(Scalb.ENTRY); entries.add(Signum.ENTRY); @@ -259,9 +260,6 @@ private static List fullText() { } private static List vector() { - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { - return List.of(Knn.ENTRY); - } - return List.of(); + return VectorWritables.getNamedWritables(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java new file mode 100644 index 0000000000000..f0401ae1d4f05 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.stats.SearchStats; + +/** + * Interface signaling to the local logical plan optimizer that the declaring expression + * has to be replaced by a different form. + * Implement this on {@code Function}s when: + *
    + *
  • The expression can be rewritten to another expression on data node, with the statistics available in SearchStats. + * Like {@code DateTrunc} and {@code Bucket} could be rewritten to {@code RoundTo} with the min/max values on the date field. + *
  • + *
+ */ +public interface LocalSurrogateExpression { + /** + * Returns the expression to be replaced by or {@code null} if this cannot be replaced. + */ + Expression surrogate(SearchStats searchStats); +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 630c9c2008a13..65e3f56d267a6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchPhrase; import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; @@ -179,6 +180,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; +import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -301,7 +303,7 @@ private static FunctionDefinition[][] functions() { return new FunctionDefinition[][] { // grouping functions new FunctionDefinition[] { - def(Bucket.class, Bucket::new, "bucket", "bin"), + def(Bucket.class, 
quin(Bucket::new), "bucket", "bin"), def(Categorize.class, Categorize::new, "categorize") }, // aggregate functions // since they declare two public constructors - one with filter (for nested where) and one without @@ -477,8 +479,9 @@ private static FunctionDefinition[][] snapshotFunctions() { def(AvgOverTime.class, uni(AvgOverTime::new), "avg_over_time"), def(LastOverTime.class, uni(LastOverTime::new), "last_over_time"), def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), + def(Score.class, uni(Score::new), Score.NAME), def(Term.class, bi(Term::new), "term"), - def(Knn.class, Knn::new, "knn"), + def(Knn.class, quad(Knn::new), "knn"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), @@ -487,7 +490,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"), def(StGeohex.class, StGeohex::new, "st_geohex"), def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"), - def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } }; + def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"), + def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } }; } public EsqlFunctionRegistry snapshotRegistry() { @@ -1010,6 +1014,39 @@ protected static FunctionDefinition def(Class function, return def(function, builder, names); } + /** + * Build a {@linkplain FunctionDefinition} for a quinary function. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, QuinaryBuilder ctorRef, String... 
names) { + FunctionBuilder builder = (source, children, cfg) -> { + if (OptionalArgument.class.isAssignableFrom(function)) { + if (children.size() > 5 || children.size() < 4) { + throw new QlIllegalArgumentException("expects four or five arguments"); + } + } else if (TwoOptionalArguments.class.isAssignableFrom(function)) { + if (children.size() > 5 || children.size() < 3) { + throw new QlIllegalArgumentException("expects minimum three, maximum five arguments"); + } + } else if (ThreeOptionalArguments.class.isAssignableFrom(function)) { + if (children.size() > 5 || children.size() < 2) { + throw new QlIllegalArgumentException("expects minimum two, maximum five arguments"); + } + } else if (children.size() != 5) { + throw new QlIllegalArgumentException("expects exactly five arguments"); + } + return ctorRef.build( + source, + children.get(0), + children.get(1), + children.size() > 2 ? children.get(2) : null, + children.size() > 3 ? children.get(3) : null, + children.size() > 4 ? children.get(4) : null + ); + }; + return def(function, builder, names); + } + protected interface QuaternaryBuilder { T build(Source source, Expression one, Expression two, Expression three, Expression four); } @@ -1204,4 +1241,11 @@ private static TernaryBuilder tri(TernaryBuilder func return function; } + private static QuaternaryBuilder quad(QuaternaryBuilder function) { + return function; + } + + private static QuinaryBuilder quin(QuinaryBuilder function) { + return function; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java new file mode 100644 index 0000000000000..891d8f1e6c264 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.EntryExpression; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; + +public class Options { + + public static Expression.TypeResolution resolve( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Map allowedOptions + ) { + return resolve(options, source, paramOrdinal, allowedOptions, null); + } + + public static Expression.TypeResolution resolve( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Map allowedOptions, + Consumer> verifyOptions + ) { + if (options != null) { + Expression.TypeResolution resolution = isNotNull(options, source.text(), paramOrdinal); + if (resolution.unresolved()) { + return resolution; + } + // MapExpression does not have a DataType associated with it + 
resolution = isMapExpression(options, source.text(), paramOrdinal); + if (resolution.unresolved()) { + return resolution; + } + try { + Map optionsMap = new HashMap<>(); + populateMap((MapExpression) options, optionsMap, source, paramOrdinal, allowedOptions); + if (verifyOptions != null) { + verifyOptions.accept(optionsMap); + } + } catch (InvalidArgumentException e) { + return new Expression.TypeResolution(e.getMessage()); + } + } + return Expression.TypeResolution.TYPE_RESOLVED; + } + + public static void populateMap( + final MapExpression options, + final Map optionsMap, + final Source source, + final TypeResolutions.ParamOrdinal paramOrdinal, + final Map allowedOptions + ) throws InvalidArgumentException { + for (EntryExpression entry : options.entryExpressions()) { + Expression optionExpr = entry.key(); + Expression valueExpr = entry.value(); + Expression.TypeResolution resolution = isFoldable(optionExpr, source.text(), paramOrdinal).and( + isFoldable(valueExpr, source.text(), paramOrdinal) + ); + if (resolution.unresolved()) { + throw new InvalidArgumentException(resolution.message()); + } + Object optionExprLiteral = ((Literal) optionExpr).value(); + Object valueExprLiteral = ((Literal) valueExpr).value(); + String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); + String optionValue = valueExprLiteral instanceof BytesRef br ? 
br.utf8ToString() : valueExprLiteral.toString(); + // validate the optionExpr is supported + DataType dataType = allowedOptions.get(optionName); + if (dataType == null) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, source.text(), allowedOptions.keySet()) + ); + } + try { + optionsMap.put(optionName, DataTypeConverter.convert(optionValue, dataType)); + } catch (InvalidArgumentException e) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], {}", optionName, source.text(), e.getMessage()) + ); + } + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/ThreeOptionalArguments.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/ThreeOptionalArguments.java new file mode 100644 index 0000000000000..be464c05ef6cb --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/ThreeOptionalArguments.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +/** + * Marker interface indicating that a function accepts three optional arguments (the last three). + * This is used by the {@link EsqlFunctionRegistry} to perform validation of function declaration. 
+ */ +public interface ThreeOptionalArguments { + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java index e469f16f8d5a2..5e8c3a9bcf104 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java @@ -14,12 +14,14 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.SurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import java.io.IOException; import java.util.List; @@ -29,7 +31,7 @@ /** * Similar to {@link Avg}, but it is used to calculate the average value over a time series of values from the given field. 
*/ -public class AvgOverTime extends TimeSeriesAggregateFunction { +public class AvgOverTime extends TimeSeriesAggregateFunction implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "AvgOverTime", @@ -93,6 +95,13 @@ public AvgOverTime withFilter(Expression filter) { return new AvgOverTime(source(), field(), filter); } + @Override + public Expression surrogate() { + Source s = source(); + Expression f = field(); + return new Div(s, new SumOverTime(s, f, filter()), new CountOverTime(s, f, filter()), dataType()); + } + @Override public AggregateFunction perTimeSeriesAggregation() { return new Avg(source(), field(), filter()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java index d353373453153..ce37e98b292f1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/package-info.java @@ -121,7 +121,7 @@ * It is also possible to declare any number of arbitrary arguments that must be provided via generated Supplier. * *
  • - * {@code combine, combineStates, combineIntermediate, evaluateFinal} methods (see below) could be generated automatically + * {@code combine, combineIntermediate, evaluateFinal} methods (see below) could be generated automatically * when both input type I and mutable accumulator state AggregatorState and GroupingAggregatorState are primitive (DOUBLE, INT). *
  • *
  • @@ -167,10 +167,6 @@ * of the grouping aggregation state *
  • *
  • - * {@code void combineStates(GroupingAggregatorState targetState, int targetGroupId, GS otherState, int otherGroupId)} - * merges other grouped aggregation state into the first one - *
  • - *
  • * {@code void combineIntermediate(GroupingAggregatorState current, int groupId, intermediate states)} adds serialized * aggregation state to the current grouped aggregation state (used to combine results across different nodes) *
  • diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index ec29b4b658c76..b5378db783f46 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.fulltext; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig; import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator; @@ -16,28 +15,25 @@ import org.elasticsearch.compute.operator.ScoreOperator; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; -import org.elasticsearch.xpack.esql.core.InvalidArgumentException; -import org.elasticsearch.xpack.esql.core.expression.EntryExpression; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; import 
org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -50,19 +46,15 @@ import org.elasticsearch.xpack.esql.querydsl.query.TranslationAwareExpressionQuery; import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper; +import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Predicate; -import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; @@ -163,20 +155,19 
@@ public boolean equals(Object obj) { @Override public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { - // In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions return Translatable.YES; } @Override public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { - return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler); + return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler); } public QueryBuilder queryBuilder() { return queryBuilder; } - protected abstract Query translate(TranslatorHandler handler); + protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler); public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder); @@ -195,6 +186,10 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu if (plan instanceof Filter f) { Expression condition = f.condition(); + if (condition instanceof Score) { + failures.add(fail(condition, "[SCORE] function can't be used in WHERE")); + } + List.of(QueryString.class, Kql.class).forEach(functionClass -> { // Check for limitations of QSTR and KQL function. 
checkCommandsBeforeExpression( @@ -219,12 +214,38 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Failures failu } else if (plan instanceof Aggregate agg) { checkFullTextFunctionsInAggs(agg, failures); } else { + List scoredFTFs = new ArrayList<>(); + plan.forEachExpression(Score.class, scoreFunction -> { + checkScoreFunction(plan, failures, scoreFunction); + plan.forEachExpression(FullTextFunction.class, scoredFTFs::add); + }); plan.forEachExpression(FullTextFunction.class, ftf -> { - failures.add(fail(ftf, "[{}] {} is only supported in WHERE and STATS commands", ftf.functionName(), ftf.functionType())); + if (scoredFTFs.remove(ftf) == false) { + failures.add( + fail( + ftf, + "[{}] {} is only supported in WHERE and STATS commands" + + (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled() ? ", or in EVAL within score(.) function" : ""), + ftf.functionName(), + ftf.functionType() + ) + ); + } }); } } + private static void checkScoreFunction(LogicalPlan plan, Failures failures, Score scoreFunction) { + checkCommandsBeforeExpression( + plan, + scoreFunction.canonical(), + Score.class, + lp -> (lp instanceof Limit == false) && (lp instanceof Aggregate == false), + m -> "[" + m.functionName() + "] function", + failures + ); + } + private static void checkFullTextFunctionsInAggs(Aggregate agg, Failures failures) { agg.groupings().forEach(exp -> { exp.forEachDown(e -> { @@ -281,6 +302,7 @@ private static void checkFullTextFunctionsParents(Expression condition, Failures forEachFullTextFunctionParent(condition, (ftf, parent) -> { if ((parent instanceof FullTextFunction == false) && (parent instanceof BinaryLogic == false) + && (parent instanceof EsqlBinaryComparison == false) && (parent instanceof Not == false)) { failures.add( fail( @@ -376,66 +398,6 @@ public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { return new LuceneQueryScoreEvaluator.Factory(shardConfigs); } - protected static void populateOptionsMap( - final MapExpression 
options, - final Map optionsMap, - final TypeResolutions.ParamOrdinal paramOrdinal, - final String sourceText, - final Map allowedOptions - ) throws InvalidArgumentException { - for (EntryExpression entry : options.entryExpressions()) { - Expression optionExpr = entry.key(); - Expression valueExpr = entry.value(); - TypeResolution resolution = isFoldable(optionExpr, sourceText, paramOrdinal).and( - isFoldable(valueExpr, sourceText, paramOrdinal) - ); - if (resolution.unresolved()) { - throw new InvalidArgumentException(resolution.message()); - } - Object optionExprLiteral = ((Literal) optionExpr).value(); - Object valueExprLiteral = ((Literal) valueExpr).value(); - String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); - String optionValue = valueExprLiteral instanceof BytesRef br ? br.utf8ToString() : valueExprLiteral.toString(); - // validate the optionExpr is supported - DataType dataType = allowedOptions.get(optionName); - if (dataType == null) { - throw new InvalidArgumentException( - format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, sourceText, allowedOptions.keySet()) - ); - } - try { - optionsMap.put(optionName, DataTypeConverter.convert(optionValue, dataType)); - } catch (InvalidArgumentException e) { - throw new InvalidArgumentException(format(null, "Invalid option [{}] in [{}], {}", optionName, sourceText, e.getMessage())); - } - } - } - - protected TypeResolution resolveOptions(Expression options, TypeResolutions.ParamOrdinal paramOrdinal) { - if (options != null) { - TypeResolution resolution = isNotNull(options, sourceText(), paramOrdinal); - if (resolution.unresolved()) { - return resolution; - } - // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options, sourceText(), paramOrdinal); - if (resolution.unresolved()) { - return resolution; - } - - try { - resolvedOptions(); - } catch (InvalidArgumentException e) { - return new 
TypeResolution(e.getMessage()); - } - } - return TypeResolution.TYPE_RESOLVED; - } - - protected Map resolvedOptions() throws InvalidArgumentException { - return Map.of(); - } - // TODO: this should likely be replaced by calls to FieldAttribute#fieldName; the MultiTypeEsField case looks // wrong if `fieldAttribute` is a subfield, e.g. `parent.child` - multiTypeEsField#getName will just return `child`. public static String getNameFromFieldAttribute(FieldAttribute fieldAttribute) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java index 18c0a22589baa..657017a76b1db 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java @@ -28,6 +28,9 @@ public static List getNamedWriteables() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { entries.add(Term.ENTRY); } + if (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled()) { + entries.add(Score.ENTRY); + } return Collections.unmodifiableList(entries); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java index b373becca9965..df3cf5af84232 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import 
org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery; @@ -93,7 +94,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { return new KqlQuery(source(), Objects.toString(queryAsObject())); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java index e6d99d158aaaf..5c5a46fd2f759 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java @@ -33,8 +33,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery; @@ -297,7 +299,9 @@ public final void writeTo(StreamOutput out) throws IOException { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions(options(), THIRD)).and(checkParamCompatibility()); + return resolveField().and(resolveQuery()) + 
.and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS)) + .and(checkParamCompatibility()); } private TypeResolution resolveField() { @@ -341,11 +345,6 @@ private TypeResolution checkParamCompatibility() { return new TypeResolution(formatIncompatibleTypesMessage(fieldType, queryType, sourceText())); } - @Override - protected Map resolvedOptions() { - return matchQueryOptions(); - } - private Map matchQueryOptions() throws InvalidArgumentException { if (options() == null) { return Map.of(LENIENT_FIELD.getPreferredName(), true); @@ -355,7 +354,7 @@ private Map matchQueryOptions() throws InvalidArgumentException // Match is lenient by default to avoid failing on incompatible types matchOptions.put(LENIENT_FIELD.getPreferredName(), true); - populateOptionsMap((MapExpression) options(), matchOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchOptions, source(), SECOND, ALLOWED_OPTIONS); return matchOptions; } @@ -423,7 +422,7 @@ public Object queryAsObject() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = fieldAsFieldAttribute(); Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java index 4a99227576611..4ed0e16ab5b4a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java @@ -30,8 +30,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import 
org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery; @@ -89,9 +91,7 @@ public class MatchPhrase extends FullTextFunction implements OptionalArgument, P @FunctionInfo( returnType = "boolean", - appliesTo = { - @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.UNAVAILABLE, version = "9.0"), - @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.GA, version = "9.1.0") }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.GA, version = "9.1.0") }, description = """ Use `MATCH_PHRASE` to perform a [`match_phrase`](/reference/query-languages/query-dsl/query-dsl-match-query-phrase.md) on the specified field. 
@@ -188,7 +188,7 @@ public final void writeTo(StreamOutput out) throws IOException { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions(options(), THIRD)); + return resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS)); } private TypeResolution resolveField() { @@ -201,18 +201,13 @@ private TypeResolution resolveQuery() { ); } - @Override - protected Map resolvedOptions() throws InvalidArgumentException { - return matchPhraseQueryOptions(); - } - private Map matchPhraseQueryOptions() throws InvalidArgumentException { if (options() == null) { return Map.of(); } Map matchPhraseOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchPhraseOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchPhraseOptions, source(), SECOND, ALLOWED_OPTIONS); return matchPhraseOptions; } @@ -278,7 +273,7 @@ public Object queryAsObject() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = fieldAsFieldAttribute(); Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java index 2c398c7f6c6f1..3e9fed6be850d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java @@ -29,8 +29,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import 
org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery; @@ -335,7 +337,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { Map fieldsWithBoost = new HashMap<>(); for (Expression field : fields) { var fieldAttribute = Match.fieldAsFieldAttribute(field); @@ -367,7 +369,7 @@ private Map getOptions() throws InvalidArgumentException { return options; } - Match.populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), OPTIONS); + Options.populateMap((MapExpression) options(), options, source(), THIRD, OPTIONS); return options; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java index 4e201a17a4aec..7285f19fc5aa7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java @@ -26,8 +26,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import 
org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import java.io.IOException; @@ -320,18 +322,13 @@ private Map queryStringOptions() throws InvalidArgumentException } Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, SECOND, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), matchOptions, source(), SECOND, ALLOWED_OPTIONS); return matchOptions; } - @Override - protected Map resolvedOptions() { - return queryStringOptions(); - } - @Override protected TypeResolution resolveParams() { - return resolveQuery().and(resolveOptions(options(), SECOND)); + return resolveQuery().and(Options.resolve(options(), source(), SECOND, ALLOWED_OPTIONS)); } @Override @@ -345,7 +342,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { return new QueryStringQuery(source(), Objects.toString(queryAsObject()), Map.of(), queryStringOptions()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Score.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Score.java new file mode 100644 index 0000000000000..1b471931eaa0e --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Score.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.ScoreOperator; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.score.ScoreMapper; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * A function to be used to score specific portions of an ES|QL query e.g., in conjunction with + * an {@link org.elasticsearch.xpack.esql.plan.logical.Eval}. 
+ */ +public class Score extends Function implements EvaluatorMapper { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "score", Score::readFrom); + + public static final String NAME = "score"; + + @FunctionInfo( + returnType = "double", + preview = true, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, + description = "Scores an expression. Only full text functions will be scored. Returns scores for all the resulting docs.", + examples = { @Example(file = "score-function", tag = "score-function") } + ) + public Score( + Source source, + @Param( + name = "query", + type = { "boolean" }, + description = "Boolean expression that contains full text function(s) to be scored." + ) Expression scorableQuery + ) { + this(source, List.of(scorableQuery)); + } + + protected Score(Source source, List children) { + super(source, children); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Score(source(), newChildren); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Score::new, children().getFirst()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { + ScoreOperator.ExpressionScorer.Factory scorerFactory = ScoreMapper.toScorer(children().getFirst(), toEvaluator.shardContexts()); + return driverContext -> new ScorerEvaluatorFactory(scorerFactory).get(driverContext); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteableCollection(this.children()); + } + + private static Expression readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression query = 
in.readOptionalNamedWriteable(Expression.class); + return new Score(source, query); + } + + private record ScorerEvaluatorFactory(ScoreOperator.ExpressionScorer.Factory scoreFactory) + implements + EvalOperator.ExpressionEvaluator.Factory { + + @Override + public EvalOperator.ExpressionEvaluator get(DriverContext context) { + return new EvalOperator.ExpressionEvaluator() { + + private final ScoreOperator.ExpressionScorer scorer = scoreFactory.get(context); + + @Override + public void close() { + scorer.close(); + } + + @Override + public Block eval(Page page) { + return scorer.score(page); + } + }; + } + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Score score = (Score) o; + return super.equals(o) && score.children().equals(children()); + } + + @Override + public int hashCode() { + return Objects.hash(children()); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java index 76188dc146ee6..cecef10a136f7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; @@ -130,7 +131,7 @@ protected TypeResolutions.ParamOrdinal queryParamOrdinal() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query 
translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { // Uses a term query that contributes to scoring return new TermQuery(source(), ((FieldAttribute) field()).name(), queryAsObject(), false, true); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index bb6633686fc7c..3f316c3a9b473 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -8,10 +8,15 @@ package org.elasticsearch.xpack.esql.expression.function.grouping; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Rounding; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; @@ -25,20 +30,22 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; import 
org.elasticsearch.xpack.esql.expression.function.Param; -import org.elasticsearch.xpack.esql.expression.function.TwoOptionalArguments; +import org.elasticsearch.xpack.esql.expression.function.ThreeOptionalArguments; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Floor; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.stats.SearchStats; import java.io.IOException; -import java.time.ZoneId; -import java.time.ZoneOffset; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; @@ -50,6 +57,8 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.expression.Validations.isFoldable; +import static org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc.maybeSubstituteWithRoundTo; +import static org.elasticsearch.xpack.esql.session.Configuration.DEFAULT_TZ; import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong; /** @@ -61,7 +70,9 @@ public class Bucket extends GroupingFunction.EvaluatableGroupingFunction implements PostOptimizationVerificationAware, - TwoOptionalArguments { + ThreeOptionalArguments, + LocalSurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Bucket", Bucket::new); // TODO maybe we should just cover the whole of representable dates here - like ten years, 100 years, 1000 years, all the way up. 
@@ -87,12 +98,11 @@ public class Bucket extends GroupingFunction.EvaluatableGroupingFunction Rounding.builder(TimeValue.timeValueMillis(10)).build(), Rounding.builder(TimeValue.timeValueMillis(1)).build(), }; - private static final ZoneId DEFAULT_TZ = ZoneOffset.UTC; // TODO: plug in the config - private final Expression field; private final Expression buckets; private final Expression from; private final Expression to; + private final Expression emitEmptyBuckets; @FunctionInfo( returnType = { "double", "date", "date_nanos" }, @@ -211,13 +221,20 @@ public Bucket( type = { "integer", "long", "double", "date", "keyword", "text" }, optional = true, description = "End of the range. Can be a number, a date or a date expressed as a string." - ) Expression to + ) Expression to, + @Param( + name = "emitEmptyBuckets", + type = { "boolean" }, + optional = true, + description = "Whether or not empty buckets should be emitted." + ) Expression emitEmptyBuckets ) { - super(source, fields(field, buckets, from, to)); + super(source, fields(field, buckets, from, to, emitEmptyBuckets)); this.field = field; this.buckets = buckets; this.from = from; this.to = to; + this.emitEmptyBuckets = emitEmptyBuckets; } private Bucket(StreamInput in) throws IOException { @@ -226,11 +243,20 @@ private Bucket(StreamInput in) throws IOException { in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class), in.readOptionalNamedWriteable(Expression.class), - in.readOptionalNamedWriteable(Expression.class) + in.readOptionalNamedWriteable(Expression.class), + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_EMIT_EMPTY_BUCKETS) + ? 
in.readOptionalNamedWriteable(Expression.class) + : null ); } - private static List fields(Expression field, Expression buckets, Expression from, Expression to) { + private static List fields( + Expression field, + Expression buckets, + Expression from, + Expression to, + Expression emitEmptyBuckets + ) { List list = new ArrayList<>(4); list.add(field); list.add(buckets); @@ -240,6 +266,9 @@ private static List fields(Expression field, Expression buckets, Exp list.add(to); } } + if (emitEmptyBuckets != null) { + list.add(emitEmptyBuckets); + } return list; } @@ -250,6 +279,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(buckets); out.writeOptionalNamedWriteable(from); out.writeOptionalNamedWriteable(to); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_EMIT_EMPTY_BUCKETS)) { + out.writeOptionalNamedWriteable(emitEmptyBuckets); + } } @Override @@ -259,25 +291,21 @@ public String getWriteableName() { @Override public boolean foldable() { - return field.foldable() && buckets.foldable() && (from == null || from.foldable()) && (to == null || to.foldable()); + return field.foldable() + && buckets.foldable() + && (from == null || from.foldable()) + && (to == null || to.foldable()) + && (emitEmptyBuckets == null || emitEmptyBuckets.foldable()); } @Override public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) { - Rounding.Prepared preparedRounding = getDateRounding(toEvaluator.foldCtx()); + Rounding.Prepared preparedRounding = getDateRounding(field, buckets, from, to, toEvaluator.foldCtx()); return DateTrunc.evaluator(field.dataType(), source(), toEvaluator.apply(field), preparedRounding); } if (field.dataType().isNumeric()) { - double roundTo; - if (from != null) { - int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue(); - double f = ((Number) from.fold(toEvaluator.foldCtx())).doubleValue(); - double 
t = ((Number) to.fold(toEvaluator.foldCtx())).doubleValue(); - roundTo = pickRounding(b, f, t); - } else { - roundTo = ((Number) buckets.fold(toEvaluator.foldCtx())).doubleValue(); - } + double roundTo = determineRounding(buckets, from, to, toEvaluator.foldCtx()); Literal rounding = new Literal(source(), roundTo, DataType.DOUBLE); // We could make this more efficient, either by generating the evaluators with byte code or hand rolling this one. @@ -289,27 +317,61 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { throw EsqlIllegalArgumentException.illegalDataType(field.dataType()); } + public BlockHash.EmptyBucketGenerator createEmptyBucketGenerator() { + assert emitEmptyBuckets() != null; + FoldContext foldContext = new FoldContext(128); + Boolean emit = (Boolean) emitEmptyBuckets.fold(foldContext); + if (Boolean.TRUE.equals(emit) == false) { + return null; + } else if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) { + return new DatetimeEmptyBucketGenerator(field, buckets, from, to, foldContext); + } else { + return new NumericEmptyBucketGenerator(buckets, from, to, foldContext); + } + } + /** * Returns the date rounding from this bucket function if the target field is a date type; otherwise, returns null. 
*/ public Rounding.Prepared getDateRoundingOrNull(FoldContext foldCtx) { if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) { - return getDateRounding(foldCtx); + return getDateRounding(field, buckets, from, to, foldCtx); } else { return null; } } - private Rounding.Prepared getDateRounding(FoldContext foldContext) { + private static Rounding.Prepared getDateRounding( + Expression field, + Expression buckets, + Expression from, + Expression to, + FoldContext foldContext + ) { + return getDateRounding(field, buckets, from, to, foldContext, null, null); + } + + private static Rounding.Prepared getDateRounding( + Expression field, + Expression buckets, + Expression from, + Expression to, + FoldContext foldContext, + Long min, + Long max + ) { assert field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS : "expected date type; got " + field; if (buckets.dataType().isWholeNumber()) { int b = ((Number) buckets.fold(foldContext)).intValue(); long f = foldToLong(foldContext, from); long t = foldToLong(foldContext, to); + if (min != null && max != null) { + return new DateRoundingPicker(b, f, t).pickRounding().prepare(min, max); + } return new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown(); } else { assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]"; - return DateTrunc.createRounding(buckets.fold(foldContext), DEFAULT_TZ); + return DateTrunc.createRounding(buckets.fold(foldContext), DEFAULT_TZ, min, max); } } @@ -344,7 +406,18 @@ boolean roundingIsOk(Rounding rounding) { } } - private double pickRounding(int buckets, double from, double to) { + static double determineRounding(Expression buckets, Expression from, Expression to, FoldContext foldContext) { + if (from != null) { + int b = ((Number) buckets.fold(foldContext)).intValue(); + double f = ((Number) from.fold(foldContext)).doubleValue(); + double t = ((Number) 
to.fold(foldContext)).doubleValue(); + return pickRounding(b, f, t); + } else { + return ((Number) buckets.fold(foldContext)).doubleValue(); + } + } + + private static double pickRounding(int buckets, double from, double to) { double precise = (to - from) / buckets; double nextPowerOfTen = Math.pow(10, Math.ceil(Math.log10(precise))); double halfPower = nextPowerOfTen / 2; @@ -402,9 +475,8 @@ protected TypeResolution resolveType() { private TypeResolution checkArgsCount(int expectedCount) { String expected = null; - if (expectedCount == 2 && (from != null || to != null)) { - expected = "two"; - } else if (expectedCount == 4 && (from == null || to == null)) { + + if (expectedCount == 4 && (from == null || to == null)) { expected = "four"; } else if ((from == null && to != null) || (from != null && to == null)) { expected = "two or four"; @@ -443,7 +515,7 @@ public void postOptimizationVerification(Failures failures) { .add(to != null ? isFoldable(to, operation, FOURTH) : null); } - private long foldToLong(FoldContext ctx, Expression e) { + private static long foldToLong(FoldContext ctx, Expression e) { Object value = Foldables.valueOf(ctx, e); return DataType.isDateTime(e.dataType()) ? ((Number) value).longValue() : dateTimeToLong(((BytesRef) value).utf8ToString()); } @@ -460,12 +532,13 @@ public DataType dataType() { public Expression replaceChildren(List newChildren) { Expression from = newChildren.size() > 2 ? newChildren.get(2) : null; Expression to = newChildren.size() > 3 ? newChildren.get(3) : null; - return new Bucket(source(), newChildren.get(0), newChildren.get(1), from, to); + Expression emitEmptyBuckets = newChildren.size() > 4 ? 
newChildren.get(4) : null; + return new Bucket(source(), newChildren.get(0), newChildren.get(1), from, to, emitEmptyBuckets); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Bucket::new, field, buckets, from, to); + return NodeInfo.create(this, Bucket::new, field, buckets, from, to, emitEmptyBuckets); } public Expression field() { @@ -484,8 +557,103 @@ public Expression to() { return to; } + public Expression emitEmptyBuckets() { + return emitEmptyBuckets; + } + @Override public String toString() { - return "Bucket{" + "field=" + field + ", buckets=" + buckets + ", from=" + from + ", to=" + to + '}'; + return "Bucket{" + + "field=" + + field + + ", buckets=" + + buckets + + ", from=" + + from + + ", to=" + + to + + ", emitEmptyBuckets=" + + emitEmptyBuckets + + '}'; + } + + @Override + public Expression surrogate(SearchStats searchStats) { + // LocalSubstituteSurrogateExpressions should make sure this doesn't happen + assert searchStats != null : "SearchStats cannot be null"; + return maybeSubstituteWithRoundTo( + source(), + field(), + buckets(), + searchStats, + (interval, minValue, maxValue) -> getDateRounding(field, buckets, from, to, FoldContext.small(), minValue, maxValue) + ); + } + + record DatetimeEmptyBucketGenerator(long from, long to, Rounding.Prepared rounding) implements BlockHash.EmptyBucketGenerator { + + DatetimeEmptyBucketGenerator(Expression field, Expression buckets, Expression from, Expression to, FoldContext foldContext) { + this(foldToLong(foldContext, from), foldToLong(foldContext, to), getDateRounding(field, buckets, from, to, foldContext)); + } + + @Override + public int getEmptyBucketCount() { + int i = 0; + for (long bucket = rounding.round(from); bucket < to; bucket = rounding.nextRoundingValue(bucket)) { + i++; + } + return i; + } + + @Override + public void generate(Block.Builder blockBuilder) { + for (long bucket = rounding.round(from); bucket < to; bucket = rounding.nextRoundingValue(bucket)) { + 
((LongBlock.Builder) blockBuilder).appendLong(bucket); + } + } + } + + record NumericEmptyBucketGenerator(double from, double to, double roundTo) implements BlockHash.EmptyBucketGenerator { + + NumericEmptyBucketGenerator(Expression buckets, Expression from, Expression to, FoldContext foldContext) { + this( + ((Number) from.fold(foldContext)).doubleValue(), + ((Number) to.fold(foldContext)).doubleValue(), + determineRounding(buckets, from, to, foldContext) + ); + } + + @Override + public int getEmptyBucketCount() { + int i = 0; + for (double bucket = round(Math.floor(from / roundTo) * roundTo, 2); bucket < to; bucket = round(bucket + roundTo, 2)) { + i++; + } + return i; + } + + @Override + public void generate(Block.Builder blockBuilder) { + for (double bucket = round(Math.floor(from / roundTo) * roundTo, 2); bucket < to; bucket = round(bucket + roundTo, 2)) { + ((DoubleBlock.Builder) blockBuilder).appendDouble(bucket); + } + } + + private static double round(double value, int n) { + return new BigDecimal(value).setScale(n, RoundingMode.HALF_UP).doubleValue(); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 15b4621589457..75918091f9ecd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -7,13 +7,18
@@ package org.elasticsearch.xpack.esql.expression.function.grouping; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.xpack.esql.LicenseAware; import org.elasticsearch.xpack.esql.SupportsObservabilityTier; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -21,16 +26,29 @@ import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TreeMap; +import static java.util.Map.entry; +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat.REGEX; import 
static org.elasticsearch.xpack.esql.SupportsObservabilityTier.ObservabilityTier.COMPLETE; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; /** * Categorizes text messages. @@ -42,14 +60,23 @@ *

    */ @SupportsObservabilityTier(tier = COMPLETE) -public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements LicenseAware { +public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction implements OptionalArgument, LicenseAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "Categorize", Categorize::new ); + private static final String ANALYZER = "analyzer"; + private static final String OUTPUT_FORMAT = "output_format"; + private static final String SIMILARITY_THRESHOLD = "similarity_threshold"; + + private static final Map ALLOWED_OPTIONS = new TreeMap<>( + Map.ofEntries(entry(ANALYZER, KEYWORD), entry(OUTPUT_FORMAT, KEYWORD), entry(SIMILARITY_THRESHOLD, INTEGER)) + ); + private final Expression field; + private final Expression options; @FunctionInfo( returnType = "keyword", @@ -70,21 +97,56 @@ public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction ) public Categorize( Source source, - @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field - + @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field, + @MapParam( + name = "options", + description = "(Optional) Categorize additional options as <>.", + params = { + @MapParam.MapParamEntry( + name = ANALYZER, + type = "keyword", + valueHint = { "standard" }, + description = "Analyzer used to convert the field into tokens for text categorization." + ), + @MapParam.MapParamEntry( + name = OUTPUT_FORMAT, + type = "keyword", + valueHint = { "regex", "tokens" }, + description = "The output format of the categories. Defaults to regex." + ), + @MapParam.MapParamEntry( + name = SIMILARITY_THRESHOLD, + type = "integer", + valueHint = { "70" }, + description = "The minimum percentage of token weight that must match for text to be added to the category bucket. 
" + + "Must be between 1 and 100. The larger the value the narrower the categories. " + + "Larger values will increase memory usage and create narrower categories. Defaults to 70." + ), }, + optional = true + ) Expression options ) { - super(source, List.of(field)); + super(source, options == null ? List.of(field) : List.of(field, options)); this.field = field; + this.options = options; } private Categorize(StreamInput in) throws IOException { - this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class)); + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS) + ? in.readOptionalNamedWriteable(Expression.class) + : null + ); } @Override public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); out.writeNamedWriteable(field); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)) { + out.writeOptionalNamedWriteable(options); + } } @Override @@ -107,7 +169,48 @@ public Nullability nullable() { @Override protected TypeResolution resolveType() { - return isString(field(), sourceText(), DEFAULT); + return isString(field(), sourceText(), DEFAULT).and( + Options.resolve(options, source(), SECOND, ALLOWED_OPTIONS, this::verifyOptions) + ); + } + + private void verifyOptions(Map optionsMap) { + if (options == null) { + return; + } + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); + if (similarityThreshold != null) { + if (similarityThreshold <= 0 || similarityThreshold > 100) { + throw new InvalidArgumentException( + format("invalid similarity threshold [{}], expecting a number between 1 and 100, inclusive", similarityThreshold) + ); + } + } + String outputFormat = (String) optionsMap.get(OUTPUT_FORMAT); + if (outputFormat != null) { + try { + OutputFormat.valueOf(outputFormat.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw 
new InvalidArgumentException( + format(null, "invalid output format [{}], expecting one of [REGEX, TOKENS]", outputFormat) + ); + } + } + } + + public CategorizeDef categorizeDef() { + Map optionsMap = new HashMap<>(); + if (options != null) { + Options.populateMap((MapExpression) options, optionsMap, source(), SECOND, ALLOWED_OPTIONS); + } + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); + String outputFormatString = (String) optionsMap.get(OUTPUT_FORMAT); + OutputFormat outputFormat = outputFormatString == null ? null : OutputFormat.valueOf(outputFormatString.toUpperCase(Locale.ROOT)); + return new CategorizeDef( + (String) optionsMap.get(ANALYZER), + outputFormat == null ? REGEX : outputFormat, + similarityThreshold == null ? 70 : similarityThreshold + ); } @Override @@ -117,12 +220,12 @@ public DataType dataType() { @Override public Categorize replaceChildren(List newChildren) { - return new Categorize(source(), newChildren.get(0)); + return new Categorize(source(), newChildren.get(0), newChildren.size() > 1 ?
newChildren.get(1) : null); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Categorize::new, field); + return NodeInfo.create(this, Categorize::new, field, options); } public Expression field() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FromAggregateMetricDouble.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FromAggregateMetricDouble.java index 61129df973a55..f9f01307e025f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FromAggregateMetricDouble.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FromAggregateMetricDouble.java @@ -32,14 +32,16 @@ import java.io.IOException; import java.util.List; +import java.util.Set; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.type.DataType.NULL; -public class FromAggregateMetricDouble extends EsqlScalarFunction { +public class FromAggregateMetricDouble extends EsqlScalarFunction implements ConvertFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "FromAggregateMetricDouble", @@ -169,4 +171,14 @@ public String toString() { } }; } + + @Override + public Expression field() { + return field; + } + + @Override + public Set supportedTypes() { + return Set.of(AGGREGATE_METRIC_DOUBLE); + } } diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/BinaryDateTimeFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/BinaryDateTimeFunction.java deleted file mode 100644 index 74f0dae76c425..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/BinaryDateTimeFunction.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.expression.function.scalar.date; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; - -import java.time.ZoneId; -import java.time.ZoneOffset; -import java.util.Objects; - -public abstract class BinaryDateTimeFunction extends BinaryScalarFunction { - - protected static final ZoneId DEFAULT_TZ = ZoneOffset.UTC; - - private final ZoneId zoneId; - - protected BinaryDateTimeFunction(Source source, Expression argument, Expression timestamp) { - super(source, argument, timestamp); - zoneId = DEFAULT_TZ; - } - - @Override - public DataType dataType() { - return DataType.DATETIME; - } - - public Expression timestampField() { - return right(); - } - - public ZoneId zoneId() { - return zoneId; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), zoneId()); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (super.equals(o) == false) { - return false; - } - 
BinaryDateTimeFunction that = (BinaryDateTimeFunction) o; - return zoneId().equals(that.zoneId()); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java index a2ec96d1e0b34..2437fdc307415 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java @@ -57,7 +57,7 @@ public class DateDiff extends EsqlScalarFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "DateDiff", DateDiff::new); - public static final ZoneId UTC = ZoneId.of("Z"); + public static final ZoneId UTC = org.elasticsearch.xpack.esql.core.util.DateUtils.UTC; private final Expression unit; private final Expression startTimestamp; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java index 6981c8e3b9d82..9b4d312e9df42 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.date; import org.elasticsearch.common.Rounding; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,38 +17,53 @@ import org.elasticsearch.compute.ann.Fixed; import 
org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; +import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.stats.SearchStats; import java.io.IOException; import java.time.Duration; import java.time.Period; import java.time.ZoneId; -import java.time.ZoneOffset; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; +import static 
org.elasticsearch.xpack.esql.core.type.DataType.isDateTime; +import static org.elasticsearch.xpack.esql.session.Configuration.DEFAULT_TZ; +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateWithTypeToString; -public class DateTrunc extends EsqlScalarFunction { +public class DateTrunc extends EsqlScalarFunction implements LocalSurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "DateTrunc", DateTrunc::new ); + private static final Logger logger = LogManager.getLogger(DateTrunc.class); + @FunctionalInterface public interface DateTruncFactoryProvider { ExpressionEvaluator.Factory apply(Source source, ExpressionEvaluator.Factory lhs, Rounding.Prepared rounding); @@ -59,7 +75,6 @@ public interface DateTruncFactoryProvider { ); private final Expression interval; private final Expression timestampField; - protected static final ZoneId DEFAULT_TZ = ZoneOffset.UTC; @FunctionInfo( returnType = { "date", "date_nanos" }, @@ -163,14 +178,23 @@ static Rounding.Prepared createRounding(final Object interval) { public static Rounding.Prepared createRounding(final Object interval, final ZoneId timeZone) { if (interval instanceof Period period) { - return createRounding(period, timeZone); + return createRounding(period, timeZone, null, null); + } else if (interval instanceof Duration duration) { + return createRounding(duration, timeZone, null, null); + } + throw new IllegalArgumentException("Time interval is not supported"); + } + + public static Rounding.Prepared createRounding(final Object interval, final ZoneId timeZone, Long min, Long max) { + if (interval instanceof Period period) { + return createRounding(period, timeZone, min, max); } else if (interval instanceof Duration duration) { - return createRounding(duration, timeZone); + return createRounding(duration, timeZone, min, max); } throw new IllegalArgumentException("Time interval is not supported"); } - private static 
Rounding.Prepared createRounding(final Period period, final ZoneId timeZone) { + private static Rounding.Prepared createRounding(final Period period, final ZoneId timeZone, Long min, Long max) { // Zero or negative intervals are not supported if (period == null || period.isNegative() || period.isZero()) { throw new IllegalArgumentException("Zero or negative time interval is not supported"); @@ -182,6 +206,7 @@ private static Rounding.Prepared createRounding(final Period period, final ZoneI } final Rounding.Builder rounding; + boolean tryPrepareWithMinMax = true; if (period.getDays() == 1) { rounding = new Rounding.Builder(Rounding.DateTimeUnit.DAY_OF_MONTH); } else if (period.getDays() == 7) { @@ -190,6 +215,7 @@ private static Rounding.Prepared createRounding(final Period period, final ZoneI rounding = new Rounding.Builder(Rounding.DateTimeUnit.WEEK_OF_WEEKYEAR); } else if (period.getDays() > 1) { rounding = new Rounding.Builder(new TimeValue(period.getDays(), TimeUnit.DAYS)); + tryPrepareWithMinMax = false; } else if (period.getMonths() == 3) { // java.time.Period does not have a QUARTERLY period, so a period of 3 months // returns a quarterly rounding @@ -198,19 +224,26 @@ private static Rounding.Prepared createRounding(final Period period, final ZoneI rounding = new Rounding.Builder(Rounding.DateTimeUnit.MONTH_OF_YEAR); } else if (period.getMonths() > 0) { rounding = new Rounding.Builder(Rounding.DateTimeUnit.MONTHS_OF_YEAR, period.getMonths()); + tryPrepareWithMinMax = false; } else if (period.getYears() == 1) { rounding = new Rounding.Builder(Rounding.DateTimeUnit.YEAR_OF_CENTURY); } else if (period.getYears() > 0) { rounding = new Rounding.Builder(Rounding.DateTimeUnit.YEARS_OF_CENTURY, period.getYears()); + tryPrepareWithMinMax = false; } else { throw new IllegalArgumentException("Time interval is not supported"); } rounding.timeZone(timeZone); + if (min != null && max != null && tryPrepareWithMinMax) { + // Multiple quantities calendar interval - 
day/week/month/quarter/year is not supported by PreparedRounding.maybeUseArray, + // which is called by prepare(min, max), as it may hit an assert. Call prepare(min, max) only for single calendar interval. + return rounding.build().prepare(min, max); + } return rounding.build().prepareForUnknown(); } - private static Rounding.Prepared createRounding(final Duration duration, final ZoneId timeZone) { + private static Rounding.Prepared createRounding(final Duration duration, final ZoneId timeZone, Long min, Long max) { // Zero or negative intervals are not supported if (duration == null || duration.isNegative() || duration.isZero()) { throw new IllegalArgumentException("Zero or negative time interval is not supported"); @@ -218,6 +251,9 @@ private static Rounding.Prepared createRounding(final Duration duration, final Z final Rounding.Builder rounding = new Rounding.Builder(TimeValue.timeValueMillis(duration.toMillis())); rounding.timeZone(timeZone); + if (min != null && max != null) { + return rounding.build().prepare(min, max); + } return rounding.build().prepareForUnknown(); } @@ -249,4 +285,56 @@ public static ExpressionEvaluator.Factory evaluator( ) { return evaluatorMap.get(forType).apply(source, fieldEvaluator, rounding); } + + @Override + public Expression surrogate(SearchStats searchStats) { + // LocalSubstituteSurrogateExpressions should make sure this doesn't happen + assert searchStats != null : "SearchStats cannot be null"; + return maybeSubstituteWithRoundTo( + source(), + field(), + interval(), + searchStats, + (interval, minValue, maxValue) -> createRounding(interval, DEFAULT_TZ, minValue, maxValue) + ); + } + + public static RoundTo maybeSubstituteWithRoundTo( + Source source, + Expression field, + Expression foldableTimeExpression, + SearchStats searchStats, + TriFunction roundingFunction + ) { + if (field instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField == false && isDateTime(fa.dataType())) { + // Extract min/max from 
SearchStats + DataType fieldType = fa.dataType(); + FieldAttribute.FieldName fieldName = fa.fieldName(); + var min = searchStats.min(fieldName); + var max = searchStats.max(fieldName); + // If min/max is available create rounding with them + if (min instanceof Long minValue && max instanceof Long maxValue && foldableTimeExpression.foldable()) { + Object foldedInterval = foldableTimeExpression.fold(FoldContext.small() /* TODO remove me */); + Rounding.Prepared rounding = roundingFunction.apply(foldedInterval, minValue, maxValue); + long[] roundingPoints = rounding.fixedRoundingPoints(); + if (roundingPoints == null) { + logger.trace( + "Fixed rounding point is null for field {}, minValue {} in string format {} and maxValue {} in string format {}", + fieldName, + minValue, + dateWithTypeToString(minValue, fieldType), + maxValue, + dateWithTypeToString(maxValue, fieldType) + ); + return null; + } + // Convert to round_to function with the roundings + List points = Arrays.stream(roundingPoints) + .mapToObj(l -> new Literal(Source.EMPTY, l, fieldType)) + .collect(Collectors.toList()); + return new RoundTo(source, field, points); + } + } + return null; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java index dea36bba2c4fb..e03af271b30d7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java @@ -43,6 +43,15 @@ Matching special characters (eg. `.`, `*`, `(`...) will require escaping. To reduce the overhead of escaping, we suggest using triple quotes strings `\"\"\"` <> + ```{applies_to} + stack: ga 9.2 + serverless: ga + ``` + + Both a single pattern and a list of patterns are supported.
If a list of patterns is provided, + the expression will return true if any of the patterns match. + + <> """, operator = NAME, examples = @Example(file = "docs", tag = "rlike")) public RLike( Source source, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLikeList.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLikeList.java new file mode 100644 index 0000000000000..3112cfcd47c93 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLikeList.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string.regex; + +import org.apache.lucene.search.MultiTermQuery; +import org.apache.lucene.util.automaton.Automaton; +import org.apache.lucene.util.automaton.CharacterRunAutomaton; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePatternList; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import 
org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.ExpressionQuery; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; + +import java.io.IOException; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +public class RLikeList extends RegexMatch { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "RLikeList", + RLikeList::new + ); + + Supplier automatonSupplier = new Supplier<>() { + Automaton cached; + + @Override + public Automaton get() { + if (cached == null) { + cached = pattern().createAutomaton(caseInsensitive()); + } + return cached; + } + }; + + Supplier characterRunAutomatonSupplier = new Supplier<>() { + CharacterRunAutomaton cached; + + @Override + public CharacterRunAutomaton get() { + if (cached == null) { + cached = new CharacterRunAutomaton(automatonSupplier.get()); + } + return cached; + } + }; + + /** + * The documentation for this function is in RLike, and shown to the users as `RLIKE` in the docs. 
+ */ + public RLikeList( + Source source, + @Param(name = "str", type = { "keyword", "text" }, description = "A literal value.") Expression value, + @Param(name = "patterns", type = { "keyword", "text" }, description = "A list of regular expressions.") RLikePatternList patterns + ) { + this(source, value, patterns, false); + } + + public RLikeList(Source source, Expression field, RLikePatternList rLikePattern, boolean caseInsensitive) { + super(source, field, rLikePattern, caseInsensitive); + } + + private RLikeList(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + new RLikePatternList(in), + deserializeCaseInsensitivity(in) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + pattern().writeTo(out); + serializeCaseInsensitivity(out); + } + + @Override + public String name() { + return ENTRY.name; + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected RLikeList replaceChild(Expression newChild) { + return new RLikeList(source(), newChild, pattern(), caseInsensitive()); + } + + @Override + public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { + return pushdownPredicates.isPushableAttribute(field()) ? Translatable.YES : Translatable.NO; + } + + /** + * Returns a {@link Query} that matches the field against the provided patterns. + * For now, we only support a single pattern in the list for pushdown. + */ + @Override + public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { + var field = field(); + LucenePushdownPredicates.checkIsPushableAttribute(field); + return translateField(handler.nameOf(field instanceof FieldAttribute fa ? 
fa.exactAttribute() : field)); + } + + private Query translateField(String targetFieldName) { + return new ExpressionQuery(source(), targetFieldName, this); + } + + @Override + public org.apache.lucene.search.Query asLuceneQuery( + MappedFieldType fieldType, + MultiTermQuery.RewriteMethod constantScoreRewrite, + SearchExecutionContext context + ) { + return fieldType.automatonQuery( + automatonSupplier, + characterRunAutomatonSupplier, + constantScoreRewrite, + context, + getLuceneQueryDescription() + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, RLikeList::new, field(), pattern(), caseInsensitive()); + } + + private String getLuceneQueryDescription() { + // we use the information used to create the automaton to describe the query here + String patternDesc = pattern().patternList().stream().map(RLikePattern::pattern).collect(Collectors.joining("\", \"")); + return "RLIKE(\"" + patternDesc + "\"), caseInsensitive=" + caseInsensitive(); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/WildcardLikeList.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/WildcardLikeList.java index 1155589dc8ab3..d38e315b58b4f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/WildcardLikeList.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/WildcardLikeList.java @@ -11,7 +11,6 @@ import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.CharacterRunAutomaton; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -35,6 +34,8 @@ import 
java.util.function.Supplier; import java.util.stream.Collectors; +import static org.elasticsearch.index.query.WildcardQueryBuilder.expressionTransportSupported; + public class WildcardLikeList extends RegexMatch { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, @@ -145,7 +146,7 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand } private boolean supportsPushdown(TransportVersion version) { - return version == null || version.onOrAfter(TransportVersions.ESQL_FIXED_INDEX_LIKE); + return version == null || expressionTransportSupported(version); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java new file mode 100644 index 0000000000000..a86eb5633f729 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; + +public class CosineSimilarity extends VectorSimilarityFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "CosineSimilarity", + CosineSimilarity::new + ); + static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = COSINE::compare; + + @FunctionInfo( + returnType = "double", + preview = true, + description = "Calculates the cosine similarity between two dense_vectors.", + examples = { @Example(file = "vector-cosine-similarity", tag = "vector-cosine-similarity") }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public CosineSimilarity( + Source source, + @Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left, + @Param( + name = "right", + type = { "dense_vector" }, + description = "second dense_vector to calculate cosine similarity" + ) Expression right + ) { + super(source, left, right); + } + + private CosineSimilarity(StreamInput in) throws 
IOException { + super(in); + } + + @Override + protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) { + return new CosineSimilarity(source(), newLeft, newRight); + } + + @Override + protected SimilarityEvaluatorFunction getSimilarityFunction() { + return SIMILARITY_FUNCTION; + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, CosineSimilarity::new, left(), right()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 63026fb9d7201..cab5ec862d7f5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -29,10 +30,12 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import 
org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; @@ -51,10 +54,10 @@ import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; @@ -70,6 +73,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes private final transient Expression k; private final Expression options; + // Expressions to be used as prefilters in knn query + private final List filterExpressions; public static final Map ALLOWED_OPTIONS = Map.ofEntries( entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), @@ -139,14 +144,23 @@ public Knn( optional = true ) Expression options ) { - this(source, field, query, k, options, null); + this(source, 
field, query, k, options, null, List.of()); } - private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) { + public Knn( + Source source, + Expression field, + Expression query, + Expression k, + Expression options, + QueryBuilder queryBuilder, + List filterExpressions + ) { super(source, query, expressionList(field, query, k, options), queryBuilder); this.field = field; this.k = k; this.options = options; + this.filterExpressions = filterExpressions; } private static List expressionList(Expression field, Expression query, Expression k, Expression options) { @@ -174,6 +188,10 @@ public Expression options() { return options; } + public List filterExpressions() { + return filterExpressions; + } + @Override public DataType dataType() { return DataType.BOOLEAN; @@ -181,7 +199,7 @@ public DataType dataType() { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions()); + return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS)); } private TypeResolution resolveField() { @@ -204,42 +222,27 @@ private TypeResolution resolveK() { .and(isNotNull(k(), sourceText(), THIRD)); } - private TypeResolution resolveOptions() { - if (options() != null) { - TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); - if (resolution.unresolved()) { - return resolution; - } - // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); - if (resolution.unresolved()) { - return resolution; - } - - try { - knnQueryOptions(); - } catch (InvalidArgumentException e) { - return new TypeResolution(e.getMessage()); - } - } - return TypeResolution.TYPE_RESOLVED; + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new 
Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); } - private Map knnQueryOptions() throws InvalidArgumentException { - if (options() == null) { - return Map.of(); + @Override + public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { + Translatable translatable = super.translatable(pushdownPredicates); + // We need to check whether filter expressions are translatable as well + for (Expression filterExpression : filterExpressions()) { + translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates)); } - Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); - return matchOptions; + return translatable; } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = Match.fieldAsFieldAttribute(field()); - Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); + Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); @SuppressWarnings("unchecked") List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); @@ -252,18 +255,29 @@ protected Query translate(TranslatorHandler handler) { Map opts = queryOptions(); opts.put(K_FIELD.getPreferredName(), kValue); - return new KnnQuery(source(), fieldName, queryAsFloats, opts); + List filterQueries = new ArrayList<>(); + for (Expression filterExpression : filterExpressions()) { + if (filterExpression instanceof TranslationAware translationAware) { + // We can only translate filter expressions that are translatable. 
In case any is not translatable, + // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them + // when creating an evaluator for the non-pushed down query + if (translationAware.translatable(pushdownPredicates) == Translatable.YES) { + filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); + } + } + } + + return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries); } - @Override - public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), k(), options(), queryBuilder); + public Expression withFilters(List filterExpressions) { + return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } private Map queryOptions() throws InvalidArgumentException { Map options = new HashMap<>(); if (options() != null) { - populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); + Options.populateMap((MapExpression) options(), options, source(), FOURTH, ALLOWED_OPTIONS); } return options; } @@ -284,13 +298,14 @@ public Expression replaceChildren(List newChildren) { newChildren.get(1), newChildren.get(2), newChildren.size() > 3 ? 
newChildren.get(3) : null, - queryBuilder() + queryBuilder(), + filterExpressions() ); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Knn::new, field(), query(), k(), options()); + return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions()); } @Override @@ -303,7 +318,8 @@ private static Knn readFrom(StreamInput in) throws IOException { Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - return new Knn(source, field, query, null, null, queryBuilder); + List filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class); + return new Knn(source, field, query, null, null, queryBuilder, filterExpressions); } @Override @@ -312,6 +328,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(field()); out.writeNamedWriteable(query()); out.writeOptionalNamedWriteable(queryBuilder()); + out.writeNamedWriteableCollection(filterExpressions()); } @Override @@ -322,12 +339,13 @@ public boolean equals(Object o) { Knn knn = (Knn) o; return Objects.equals(field(), knn.field()) && Objects.equals(query(), knn.query()) - && Objects.equals(queryBuilder(), knn.queryBuilder()); + && Objects.equals(queryBuilder(), knn.queryBuilder()) + && Objects.equals(filterExpressions(), knn.filterExpressions()); } @Override public int hashCode() { - return Objects.hash(field(), query(), queryBuilder()); + return Objects.hash(field(), query(), queryBuilder(), filterExpressions()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java new file mode 100644 index 0000000000000..fc27ae2d876e8 --- /dev/null +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.xpack.esql.EsqlClientException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; + +import java.io.IOException; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; + +/** + * Base class for vector similarity functions, which compute a similarity score between two dense vectors + */ +public abstract class 
VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction { + + protected VectorSimilarityFunction(Source source, Expression left, Expression right) { + super(source, left, right); + } + + protected VectorSimilarityFunction(StreamInput in) throws IOException { + super(in); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + return checkDenseVectorParam(left(), FIRST).and(checkDenseVectorParam(right(), SECOND)); + } + + private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) { + return isNotNull(param, sourceText(), paramOrdinal).and( + isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector") + ); + } + + /** + * Functional interface for evaluating the similarity between two float arrays + */ + @FunctionalInterface + public interface SimilarityEvaluatorFunction { + float calculateSimilarity(float[] leftScratch, float[] rightScratch); + } + + @Override + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); + } + + @Override + public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { + return new SimilarityEvaluatorFactory( + toEvaluator.apply(left()), + toEvaluator.apply(right()), + getSimilarityFunction(), + getClass().getSimpleName() + "Evaluator" + ); + } + + /** + * Returns the similarity function to be used for evaluating the similarity between two vectors. 
+ */ + protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); + + private record SimilarityEvaluatorFactory( + EvalOperator.ExpressionEvaluator.Factory left, + EvalOperator.ExpressionEvaluator.Factory right, + SimilarityEvaluatorFunction similarityFunction, + String evaluatorName + ) implements EvalOperator.ExpressionEvaluator.Factory { + + @Override + public EvalOperator.ExpressionEvaluator get(DriverContext context) { + // TODO check whether to use this custom evaluator or reuse / define an existing one + return new EvalOperator.ExpressionEvaluator() { + @Override + public Block eval(Page page) { + try ( + FloatBlock leftBlock = (FloatBlock) left.get(context).eval(page); + FloatBlock rightBlock = (FloatBlock) right.get(context).eval(page) + ) { + int positionCount = page.getPositionCount(); + int dimensions = 0; + // Get the first non-empty vector to calculate the dimension + for (int p = 0; p < positionCount; p++) { + if (leftBlock.getValueCount(p) != 0) { + dimensions = leftBlock.getValueCount(p); + break; + } + } + if (dimensions == 0) { + return context.blockFactory().newConstantFloatBlockWith(0F, 0); + } + + float[] leftScratch = new float[dimensions]; + float[] rightScratch = new float[dimensions]; + try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) { + for (int p = 0; p < positionCount; p++) { + int dimsLeft = leftBlock.getValueCount(p); + int dimsRight = rightBlock.getValueCount(p); + + if (dimsLeft == 0 || dimsRight == 0) { + // A null value on the left or right vector. 
Similarity is 0 + builder.appendDouble(0.0); + continue; + } else if (dimsLeft != dimsRight) { + throw new EsqlClientException( + "Vectors must have the same dimensions; first vector has {}, and second has {}", + dimsLeft, + dimsRight + ); + } + readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch); + readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch); + float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch); + builder.appendDouble(result); + } + return builder.build().asBlock(); + } + } + } + + @Override + public String toString() { + return evaluatorName() + "[left=" + left + ", right=" + right + "]"; + } + + @Override + public void close() {} + }; + } + + private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { + for (int i = 0; i < dimensions; i++) { + scratch[i] = block.getFloat(position + i); + } + } + + @Override + public String toString() { + return evaluatorName() + "[left=" + left + ", right=" + right + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java new file mode 100644 index 0000000000000..a4274bf28de4b --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Defines the named writables for vector functions in ESQL. + */ +public final class VectorWritables { + + private VectorWritables() { + // Utility class + throw new UnsupportedOperationException(); + } + + public static List getNamedWritables() { + List entries = new ArrayList<>(); + + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + entries.add(Knn.ENTRY); + } + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + entries.add(CosineSimilarity.ENTRY); + } + + return Collections.unmodifiableList(entries); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/IndexResolution.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/IndexResolution.java index b9040d2ef40d6..4d31f48da77de 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/IndexResolution.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/IndexResolution.java @@ -6,10 +6,10 @@ */ package org.elasticsearch.xpack.esql.index; -import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesFailure; import org.elasticsearch.core.Nullable; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -19,33 +19,26 @@ public final class IndexResolution { /** * @param index EsIndex encapsulating requested index expression, resolved mappings and index modes from field-caps. * @param resolvedIndices Set of concrete indices resolved by field-caps. (This information is not always present in the EsIndex). 
- * @param unavailableShards Set of shards that were unavailable during index resolution - * @param unavailableClusters Remote clusters that could not be contacted during planning + * @param failures failures occurred during field-caps. * @return valid IndexResolution */ - public static IndexResolution valid( - EsIndex index, - Set resolvedIndices, - Set unavailableShards, - Map unavailableClusters - ) { + public static IndexResolution valid(EsIndex index, Set resolvedIndices, Map> failures) { Objects.requireNonNull(index, "index must not be null if it was found"); Objects.requireNonNull(resolvedIndices, "resolvedIndices must not be null"); - Objects.requireNonNull(unavailableShards, "unavailableShards must not be null"); - Objects.requireNonNull(unavailableClusters, "unavailableClusters must not be null"); - return new IndexResolution(index, null, resolvedIndices, unavailableShards, unavailableClusters); + Objects.requireNonNull(failures, "failures must not be null"); + return new IndexResolution(index, null, resolvedIndices, failures); } /** * Use this method only if the set of concrete resolved indices is the same as EsIndex#concreteIndices(). 
*/ public static IndexResolution valid(EsIndex index) { - return valid(index, index.concreteIndices(), Set.of(), Map.of()); + return valid(index, index.concreteIndices(), Map.of()); } public static IndexResolution invalid(String invalid) { Objects.requireNonNull(invalid, "invalid must not be null to signal that the index is invalid"); - return new IndexResolution(null, invalid, Set.of(), Set.of(), Map.of()); + return new IndexResolution(null, invalid, Set.of(), Map.of()); } public static IndexResolution notFound(String name) { @@ -59,22 +52,19 @@ public static IndexResolution notFound(String name) { // all indices found by field-caps private final Set resolvedIndices; - private final Set unavailableShards; - // remote clusters included in the user's index expression that could not be connected to - private final Map unavailableClusters; + // map from cluster alias to failures that occurred during field-caps. + private final Map> failures; private IndexResolution( EsIndex index, @Nullable String invalid, Set resolvedIndices, - Set unavailableShards, - Map unavailableClusters + Map> failures ) { this.index = index; this.invalid = invalid; this.resolvedIndices = resolvedIndices; - this.unavailableShards = unavailableShards; - this.unavailableClusters = unavailableClusters; + this.failures = failures; } public boolean matches(String indexName) { @@ -101,11 +91,10 @@ public boolean isValid() { } /** - * @return Map of unavailable clusters (could not be connected to during field-caps query). Key of map is cluster alias, - * value is the {@link FieldCapabilitiesFailure} describing the issue. + * @return Map from cluster alias to failures that occurred during field-caps. 
*/ - public Map unavailableClusters() { - return unavailableClusters; + public Map> failures() { + return failures; } /** @@ -115,13 +104,6 @@ public Set resolvedIndices() { return resolvedIndices; } - /** - * @return set of unavailable shards during index resolution - */ - public Set getUnavailableShards() { - return unavailableShards; - } - @Override public boolean equals(Object obj) { if (obj == null || obj.getClass() != getClass()) { @@ -131,12 +113,12 @@ public boolean equals(Object obj) { return Objects.equals(index, other.index) && Objects.equals(invalid, other.invalid) && Objects.equals(resolvedIndices, other.resolvedIndices) - && Objects.equals(unavailableClusters, other.unavailableClusters); + && Objects.equals(failures, other.failures); } @Override public int hashCode() { - return Objects.hash(index, invalid, resolvedIndices, unavailableClusters); + return Objects.hash(index, invalid, resolvedIndices, failures); } @Override @@ -152,7 +134,7 @@ public String toString() { + ", resolvedIndices=" + resolvedIndices + ", unavailableClusters=" - + unavailableClusters + + failures + '}'; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamWrapperQueryBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamWrapperQueryBuilder.java index bb9ca136f6d66..9cdf193c56fef 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamWrapperQueryBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamWrapperQueryBuilder.java @@ -20,6 +20,8 @@ import java.io.IOException; +import static org.elasticsearch.index.query.WildcardQueryBuilder.expressionTransportSupported; + /** * A {@link QueryBuilder} that wraps another {@linkplain QueryBuilder} * so it read with a {@link PlanStreamInput}. 
@@ -56,6 +58,11 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ESQL_FIXED_INDEX_LIKE; } + @Override + public boolean supportsVersion(TransportVersion version) { + return expressionTransportSupported(version); + } + @Override public Query toQuery(SearchExecutionContext context) throws IOException { return next.toQuery(context); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index b9d85d191f1d2..3749aef7488ad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferIsNotNull; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferNonNullAggConstraint; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalSubstituteSurrogateExpressions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceFieldWithConstantOrNull; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -46,7 +47,8 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor localOperators() { public LogicalPlan localOptimize(LogicalPlan plan) { LogicalPlan optimized = execute(plan); - Failures failures = verifier.verify(optimized, true); + Failures failures = verifier.verify(optimized, true, plan.output()); if (failures.hasFailures()) { throw new VerificationException(failures); } diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 836eab9bb9590..af36963ac54a3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.EnableSpatialDistancePushdown; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.InsertFieldExtraction; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.ParallelizeTimeSeriesSource; @@ -42,15 +43,15 @@ public LocalPhysicalPlanOptimizer(LocalPhysicalOptimizerContext context) { } public PhysicalPlan localOptimize(PhysicalPlan plan) { - return verify(execute(plan)); + return verify(execute(plan), plan.output()); } - PhysicalPlan verify(PhysicalPlan plan) { - Failures failures = verifier.verify(plan, true); + PhysicalPlan verify(PhysicalPlan optimizedPlan, List expectedOutputAttributes) { + Failures failures = verifier.verify(optimizedPlan, true, expectedOutputAttributes); if (failures.hasFailures()) { throw new VerificationException(failures); } - return plan; + return optimizedPlan; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index eed6a6b57b68f..dac533f872022 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineSample; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownConjunctionsToKnnPrefilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan; @@ -112,7 +113,7 @@ public LogicalPlanOptimizer(LogicalOptimizerContext optimizerContext) { public LogicalPlan optimize(LogicalPlan verified) { var optimized = execute(verified); - Failures failures = verifier.verify(optimized, false); + Failures failures = verifier.verify(optimized, false, verified.output()); if (failures.hasFailures()) { throw new VerificationException(failures); } @@ -192,6 +193,7 @@ protected static Batch operators(boolean local) { new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), + new PushDownConjunctionsToKnnPrefilters(), new PushDownAndCombineSample(), new PushDownInferencePlan(), new PushDownEval(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 6751ae4cd2d80..4a04b46be295a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -13,27 +13,28 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -public final class 
LogicalVerifier { +public final class LogicalVerifier extends PostOptimizationPhasePlanVerifier { public static final LogicalVerifier INSTANCE = new LogicalVerifier(); private LogicalVerifier() {} - /** Verifies the optimized logical plan. */ - public Failures verify(LogicalPlan plan, boolean skipRemoteEnrichVerification) { - Failures failures = new Failures(); - Failures dependencyFailures = new Failures(); - + @Override + boolean skipVerification(LogicalPlan optimizedPlan, boolean skipRemoteEnrichVerification) { if (skipRemoteEnrichVerification) { // AwaitsFix https://github.com/elastic/elasticsearch/issues/118531 - var enriches = plan.collectFirstChildren(Enrich.class::isInstance); + var enriches = optimizedPlan.collectFirstChildren(Enrich.class::isInstance); if (enriches.isEmpty() == false && ((Enrich) enriches.get(0)).mode() == Enrich.Mode.REMOTE) { - return failures; + return true; } } + return false; + } - plan.forEachUp(p -> { - PlanConsistencyChecker.checkPlan(p, dependencyFailures); + @Override + void checkPlanConsistency(LogicalPlan optimizedPlan, Failures failures, Failures depFailures) { + optimizedPlan.forEachUp(p -> { + PlanConsistencyChecker.checkPlan(p, depFailures); if (failures.hasFailures() == false) { if (p instanceof PostOptimizationVerificationAware pova) { @@ -46,11 +47,5 @@ public Failures verify(LogicalPlan plan, boolean skipRemoteEnrichVerification) { }); } }); - - if (dependencyFailures.hasFailures()) { - throw new IllegalStateException(dependencyFailures.toString()); - } - - return failures; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java index ab6bea5ffddac..6d60c547f47d6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; import org.elasticsearch.xpack.esql.plan.physical.FragmentExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -34,15 +35,15 @@ public PhysicalPlanOptimizer(PhysicalOptimizerContext context) { } public PhysicalPlan optimize(PhysicalPlan plan) { - return verify(execute(plan)); + return verify(execute(plan), plan.output()); } - PhysicalPlan verify(PhysicalPlan plan) { - Failures failures = verifier.verify(plan, false); + PhysicalPlan verify(PhysicalPlan optimizedPlan, List expectedOutputAttributes) { + Failures failures = verifier.verify(optimizedPlan, false, expectedOutputAttributes); if (failures.hasFailures()) { throw new VerificationException(failures); } - return plan; + return optimizedPlan; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java index 607aa11575bcb..781a8f5263c1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java @@ -20,26 +20,27 @@ import static org.elasticsearch.xpack.esql.common.Failure.fail; /** Physical plan verifier. */ -public final class PhysicalVerifier { +public final class PhysicalVerifier extends PostOptimizationPhasePlanVerifier { public static final PhysicalVerifier INSTANCE = new PhysicalVerifier(); private PhysicalVerifier() {} - /** Verifies the physical plan. 
*/ - public Failures verify(PhysicalPlan plan, boolean skipRemoteEnrichVerification) { - Failures failures = new Failures(); - Failures depFailures = new Failures(); - + @Override + boolean skipVerification(PhysicalPlan optimizedPlan, boolean skipRemoteEnrichVerification) { if (skipRemoteEnrichVerification) { // AwaitsFix https://github.com/elastic/elasticsearch/issues/118531 - var enriches = plan.collectFirstChildren(EnrichExec.class::isInstance); + var enriches = optimizedPlan.collectFirstChildren(EnrichExec.class::isInstance); if (enriches.isEmpty() == false && ((EnrichExec) enriches.get(0)).mode() == Enrich.Mode.REMOTE) { - return failures; + return true; } } + return false; + } - plan.forEachDown(p -> { + @Override + void checkPlanConsistency(PhysicalPlan optimizedPlan, Failures failures, Failures depFailures) { + optimizedPlan.forEachDown(p -> { if (p instanceof FieldExtractExec fieldExtractExec) { Attribute sourceAttribute = fieldExtractExec.sourceAttribute(); if (sourceAttribute == null) { @@ -66,11 +67,5 @@ public Failures verify(PhysicalPlan plan, boolean skipRemoteEnrichVerification) }); } }); - - if (depFailures.hasFailures()) { - throw new IllegalStateException(depFailures.toString()); - } - - return failures; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java new file mode 100644 index 0000000000000..647dafe649984 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; +import org.elasticsearch.xpack.esql.plan.QueryPlan; +import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; + +import java.util.List; + +import static org.elasticsearch.index.IndexMode.LOOKUP; +import static org.elasticsearch.xpack.esql.common.Failure.fail; +import static org.elasticsearch.xpack.esql.core.expression.Attribute.dataTypeEquals; + +/** + * Verifies the plan after optimization. + * This is invoked immediately after a Plan Optimizer completes its work. + * Currently, it is called after LogicalPlanOptimizer, PhysicalPlanOptimizer, + * LocalLogicalPlanOptimizer, and LocalPhysicalPlanOptimizer. + * Note: Logical and Physical optimizers may override methods in this class to perform different checks. + */ +public abstract class PostOptimizationPhasePlanVerifier

    > { + + /** Verifies the optimized plan */ + public Failures verify(P optimizedPlan, boolean skipRemoteEnrichVerification, List expectedOutputAttributes) { + Failures failures = new Failures(); + Failures depFailures = new Failures(); + if (skipVerification(optimizedPlan, skipRemoteEnrichVerification)) { + return failures; + } + + checkPlanConsistency(optimizedPlan, failures, depFailures); + + verifyOutputNotChanged(optimizedPlan, expectedOutputAttributes, failures); + + if (depFailures.hasFailures()) { + throw new IllegalStateException(depFailures.toString()); + } + + return failures; + } + + abstract boolean skipVerification(P optimizedPlan, boolean skipRemoteEnrichVerification); + + abstract void checkPlanConsistency(P optimizedPlan, Failures failures, Failures depFailures); + + private static void verifyOutputNotChanged(QueryPlan optimizedPlan, List expectedOutputAttributes, Failures failures) { + if (dataTypeEquals(expectedOutputAttributes, optimizedPlan.output()) == false) { + // If the output level is empty we add a column called ProjectAwayColumns.ALL_FIELDS_PROJECTED + // We will ignore such cases for output verification + // TODO: this special casing is required due to https://github.com/elastic/elasticsearch/issues/121741, remove when fixed. + boolean hasProjectAwayColumns = optimizedPlan.output() + .stream() + .anyMatch(x -> x.name().equals(ProjectAwayColumns.ALL_FIELDS_PROJECTED)); + // LookupJoinExec represents the lookup index with EsSourceExec and this is turned into EsQueryExec by + // ReplaceSourceAttributes. Because InsertFieldExtractions doesn't apply to lookup indices, the + // right hand side will only have the EsQueryExec providing the _doc attribute and nothing else. + // We perform an optimizer run on every fragment. LookupJoinExec also contains such a fragment, + // and currently it only contains an EsQueryExec after optimization. 
+ boolean hasLookupJoinExec = optimizedPlan instanceof EsQueryExec esQueryExec && esQueryExec.indexMode() == LOOKUP; + boolean ignoreError = hasProjectAwayColumns || hasLookupJoinExec; + if (ignoreError == false) { + failures.add( + fail( + optimizedPlan, + "Output has changed from [{}] to [{}]. ", + expectedOutputAttributes.toString(), + optimizedPlan.output().toString() + ) + ); + } + } + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java new file mode 100644 index 0000000000000..aa4bb203b4346 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Stack; + +/** + * Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions. + * knn functions won't contain other knn functions as a prefilter, to avoid circular dependencies. 
+ * Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to + * {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))} +*/ +public class PushDownConjunctionsToKnnPrefilters extends OptimizerRules.OptimizerRule { + + @Override + protected LogicalPlan rule(Filter filter) { + Stack filters = new Stack<>(); + Expression condition = filter.condition(); + Expression newCondition = pushConjunctionsToKnn(condition, filters, null); + + return condition.equals(newCondition) ? filter : filter.with(newCondition); + } + + /** + * Updates knn function prefilters. This method processes conjunctions so knn functions on one side of the conjunction receive + * the other side of the conjunction as a prefilter + * + * @param expression expression to process recursively + * @param filters current filters to apply to the expression. They contain expressions on the other side of the traversed conjunctions + * @param addedFilter a new filter to add to the list of filters for the processing + * @return the updated expression, or the original expression if it doesn't need to be updated + */ + private static Expression pushConjunctionsToKnn(Expression expression, Stack filters, Expression addedFilter) { + if (addedFilter != null) { + filters.push(addedFilter); + } + Expression result = switch (expression) { + case And and: + // Traverse both sides of the And, using the other side as the added filter + Expression newLeft = pushConjunctionsToKnn(and.left(), filters, and.right()); + Expression newRight = pushConjunctionsToKnn(and.right(), filters, and.left()); + if (newLeft.equals(and.left()) && newRight.equals(and.right())) { + yield and; + } + yield and.replaceChildrenSameSize(List.of(newLeft, newRight)); + case Knn knn: + // We don't want knn expressions to have other knn expressions as a prefilter to avoid circular dependencies + List newFilters = filters.stream() + .map(PushDownConjunctionsToKnnPrefilters::removeKnn) + 
.filter(Objects::nonNull) + .toList(); + if (newFilters.equals(knn.filterExpressions())) { + yield knn; + } + yield knn.withFilters(newFilters); + default: + List children = expression.children(); + boolean childrenChanged = false; + + // This copies transformChildren algorithm to avoid unnecessary changes + List transformedChildren = null; + + for (int i = 0, s = children.size(); i < s; i++) { + Expression child = children.get(i); + Expression next = pushConjunctionsToKnn(child, filters, null); + if (child.equals(next) == false) { + // lazy copy + replacement in place + if (childrenChanged == false) { + childrenChanged = true; + transformedChildren = new ArrayList<>(children); + } + transformedChildren.set(i, next); + } + } + + yield (childrenChanged ? expression.replaceChildrenSameSize(transformedChildren) : expression); + }; + + if (addedFilter != null) { + filters.pop(); + } + + return result; + } + + /** + * Removes knn functions from the expression tree + * @param expression expression to process + * @return expression without knn functions, or null if the expression is a knn function + */ + private static Expression removeKnn(Expression expression) { + if (expression.children().isEmpty()) { + return expression; + } + if (expression instanceof Knn) { + return null; + } + + List filteredChildren = expression.children() + .stream() + .map(PushDownConjunctionsToKnnPrefilters::removeKnn) + .filter(Objects::nonNull) + .toList(); + if (filteredChildren.equals(expression.children())) { + return expression; + } else if (filteredChildren.isEmpty()) { + return null; + } else if (expression instanceof BinaryLogic && filteredChildren.size() == 1) { + // Simplify an AND / OR expression to a single child + return filteredChildren.getFirst(); + } else { + return expression.replaceChildrenSameSize(filteredChildren); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java index dd7ee26aa84bd..8fe9ccc18c006 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; @@ -137,13 +138,13 @@ private static Expression transformNonEvaluatableGroupingFunction( List newChildren = new ArrayList<>(gf.children().size()); for (Expression ex : gf.children()) { - if (ex instanceof Attribute == false) { // TODO: foldables shouldn't require eval'ing either + if (ex instanceof Attribute || ex instanceof MapExpression) { + newChildren.add(ex); + } else { // TODO: foldables shouldn't require eval'ing either var alias = new Alias(ex.source(), syntheticName(ex, gf, counter++), ex, null, true); evals.add(alias); newChildren.add(alias.toAttribute()); childrenChanged = true; - } else { - newChildren.add(ex); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveRegexMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveRegexMatch.java index fa43d51634efd..33de97cc9d08e 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveRegexMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveRegexMatch.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePatternList; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.StringPattern; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPatternList; @@ -29,8 +30,8 @@ public ReplaceStringCasingWithInsensitiveRegexMatch() { @Override protected Expression rule(RegexMatch regexMatch, LogicalOptimizerContext unused) { Expression e = regexMatch; - if (regexMatch.pattern() instanceof WildcardPatternList) { - // This optimization is not supported for WildcardPatternList for now + if (regexMatch.pattern() instanceof WildcardPatternList || regexMatch.pattern() instanceof RLikePatternList) { + // This optimization is not supported for WildcardPatternList and RLikePatternList for now return e; } if (regexMatch.field() instanceof ChangeCase changeCase) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java new file mode 100644 index 0000000000000..ff25be0c85258 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; +import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.rule.ParameterizedRule; +import org.elasticsearch.xpack.esql.stats.SearchStats; + +public class LocalSubstituteSurrogateExpressions extends ParameterizedRule { + + @Override + public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) { + return context.searchStats() != null + ? plan.transformUp(Eval.class, eval -> eval.transformExpressionsOnly(Function.class, f -> substitute(f, context.searchStats()))) + : plan; + } + + /** + * Perform the actual substitution. + */ + private static Expression substitute(Expression e, SearchStats searchStats) { + if (e instanceof LocalSurrogateExpression s) { + Expression surrogate = s.surrogate(searchStats); + if (surrogate != null) { + return surrogate; + } + } + return e; + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java index 26cfbf40eb7ff..189fc5e4c7415 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java @@ -36,6 +36,7 @@ * extraction. 
*/ public class ProjectAwayColumns extends Rule { + public static String ALL_FIELDS_PROJECTED = ""; @Override public PhysicalPlan apply(PhysicalPlan plan) { @@ -76,13 +77,25 @@ public PhysicalPlan apply(PhysicalPlan plan) { // no need for projection when dealing with aggs if (logicalFragment instanceof Aggregate == false) { - List output = new ArrayList<>(requiredAttrBuilder.build()); + // we should respect the order of the attributes + List output = new ArrayList<>(); + for (Attribute attribute : logicalFragment.output()) { + if (requiredAttrBuilder.contains(attribute)) { + output.add(attribute); + requiredAttrBuilder.remove(attribute); + } + } + // requiredAttrBuilder should be empty unless the plan is inconsistent due to a bug. + // This can happen in case of remote ENRICH, see https://github.com/elastic/elasticsearch/issues/118531 + // TODO: stop adding the remaining required attributes once remote ENRICH is fixed. + output.addAll(requiredAttrBuilder.build()); + // if all the fields are filtered out, it's only the count that matters // however until a proper fix (see https://github.com/elastic/elasticsearch/issues/98703) // add a synthetic field (so it doesn't clash with the user defined one) to return a constant // to avoid the block from being trimmed if (output.isEmpty()) { - var alias = new Alias(logicalFragment.source(), "", Literal.NULL, null, true); + var alias = new Alias(logicalFragment.source(), ALL_FIELDS_PROJECTED, Literal.NULL, null, true); List fields = singletonList(alias); logicalFragment = new Eval(logicalFragment.source(), logicalFragment, fields); output = Expressions.asAttributes(fields); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java index a4ec64b004a0c..69e20e4895e48 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java @@ -8,13 +8,11 @@ package org.elasticsearch.xpack.esql.optimizer.rules.physical.local; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; -import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; import org.elasticsearch.xpack.esql.plan.physical.LeafExec; @@ -48,16 +46,6 @@ public PhysicalPlan rule(PhysicalPlan plan, LocalPhysicalOptimizerContext contex var missing = missingAttributes(p); - /* - * If there is a single grouping then we'll try to use ords. Either way - * it loads the field lazily. If we have more than one field we need to - * make sure the fields are loaded for the standard hash aggregator. 
- */ - if (p instanceof AggregateExec agg) { - var ordinalAttributes = agg.ordinalAttributes(); - missing.removeAll(Expressions.references(ordinalAttributes)); - } - // add extractor if (missing.isEmpty() == false) { // identify child (for binary nodes) that exports _doc and place the field extractor there diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp index b80d38a7f4a39..fe807dc62d367 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp @@ -373,4 +373,4 @@ joinPredicate atn: -[4, 1, 139, 815, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 2, 20, 7, 20, 2, 21, 7, 21, 2, 22, 7, 22, 2, 23, 7, 23, 2, 24, 7, 24, 2, 25, 7, 25, 2, 26, 7, 26, 2, 27, 7, 27, 2, 28, 7, 28, 2, 29, 7, 29, 2, 30, 7, 30, 2, 31, 7, 31, 2, 32, 7, 32, 2, 33, 7, 33, 2, 34, 7, 34, 2, 35, 7, 35, 2, 36, 7, 36, 2, 37, 7, 37, 2, 38, 7, 38, 2, 39, 7, 39, 2, 40, 7, 40, 2, 41, 7, 41, 2, 42, 7, 42, 2, 43, 7, 43, 2, 44, 7, 44, 2, 45, 7, 45, 2, 46, 7, 46, 2, 47, 7, 47, 2, 48, 7, 48, 2, 49, 7, 49, 2, 50, 7, 50, 2, 51, 7, 51, 2, 52, 7, 52, 2, 53, 7, 53, 2, 54, 7, 54, 2, 55, 7, 55, 2, 56, 7, 56, 2, 57, 7, 57, 2, 58, 7, 58, 2, 59, 7, 59, 2, 60, 7, 60, 2, 61, 7, 61, 2, 62, 7, 62, 2, 63, 7, 63, 2, 64, 7, 64, 2, 65, 7, 65, 2, 66, 7, 66, 2, 67, 7, 67, 2, 68, 7, 68, 2, 69, 7, 69, 2, 70, 7, 70, 2, 71, 7, 71, 2, 72, 7, 72, 2, 73, 7, 73, 2, 74, 7, 74, 2, 75, 7, 75, 2, 76, 7, 76, 2, 77, 7, 77, 2, 78, 7, 78, 2, 79, 7, 79, 2, 80, 7, 80, 2, 81, 7, 81, 2, 82, 7, 82, 2, 83, 7, 83, 2, 84, 7, 84, 2, 85, 7, 85, 2, 86, 7, 86, 1, 
0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 184, 8, 1, 10, 1, 12, 1, 187, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 196, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 3, 3, 225, 8, 3, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 1, 7, 5, 7, 238, 8, 7, 10, 7, 12, 7, 241, 9, 7, 1, 8, 1, 8, 1, 8, 3, 8, 246, 8, 8, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 5, 9, 253, 8, 9, 10, 9, 12, 9, 256, 9, 9, 1, 10, 1, 10, 1, 10, 3, 10, 261, 8, 10, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 5, 13, 272, 8, 13, 10, 13, 12, 13, 275, 9, 13, 1, 13, 3, 13, 278, 8, 13, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 3, 14, 289, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 1, 17, 1, 17, 1, 18, 1, 18, 1, 19, 1, 19, 1, 19, 1, 19, 5, 19, 303, 8, 19, 10, 19, 12, 19, 306, 9, 19, 1, 20, 1, 20, 1, 20, 1, 21, 1, 21, 3, 21, 313, 8, 21, 1, 21, 1, 21, 3, 21, 317, 8, 21, 1, 22, 1, 22, 1, 22, 5, 22, 322, 8, 22, 10, 22, 12, 22, 325, 9, 22, 1, 23, 1, 23, 1, 23, 3, 23, 330, 8, 23, 1, 24, 1, 24, 1, 24, 5, 24, 335, 8, 24, 10, 24, 12, 24, 338, 9, 24, 1, 25, 1, 25, 1, 25, 5, 25, 343, 8, 25, 10, 25, 12, 25, 346, 9, 25, 1, 26, 1, 26, 1, 26, 5, 26, 351, 8, 26, 10, 26, 12, 26, 354, 9, 26, 1, 27, 1, 27, 1, 28, 1, 28, 1, 28, 3, 28, 361, 8, 28, 1, 29, 1, 29, 3, 29, 365, 8, 29, 1, 30, 1, 30, 3, 30, 369, 8, 30, 1, 31, 1, 31, 1, 31, 3, 31, 374, 8, 31, 1, 32, 1, 32, 1, 32, 1, 33, 1, 33, 1, 33, 1, 33, 5, 33, 383, 8, 33, 10, 33, 12, 33, 386, 9, 33, 1, 34, 1, 34, 3, 34, 390, 8, 34, 1, 34, 1, 34, 3, 34, 394, 8, 34, 1, 35, 1, 35, 1, 35, 1, 36, 1, 36, 1, 36, 1, 37, 1, 37, 1, 37, 1, 37, 5, 37, 406, 8, 37, 10, 37, 12, 37, 409, 9, 37, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 3, 38, 419, 8, 38, 1, 39, 1, 39, 1, 39, 1, 39, 3, 39, 425, 8, 39, 1, 40, 1, 40, 1, 40, 1, 40, 1, 41, 1, 41, 1, 41, 1, 42, 1, 42, 1, 42, 5, 42, 437, 8, 42, 10, 42, 12, 42, 
440, 9, 42, 1, 43, 1, 43, 1, 43, 1, 43, 1, 44, 1, 44, 1, 44, 1, 45, 1, 45, 1, 45, 1, 45, 1, 46, 1, 46, 1, 46, 1, 47, 1, 47, 1, 47, 1, 47, 3, 47, 460, 8, 47, 1, 47, 1, 47, 1, 47, 1, 47, 5, 47, 466, 8, 47, 10, 47, 12, 47, 469, 9, 47, 3, 47, 471, 8, 47, 1, 48, 1, 48, 1, 49, 1, 49, 1, 49, 3, 49, 478, 8, 49, 1, 49, 1, 49, 1, 50, 1, 50, 1, 50, 1, 51, 1, 51, 1, 51, 1, 51, 3, 51, 489, 8, 51, 1, 51, 1, 51, 1, 51, 1, 51, 1, 51, 3, 51, 496, 8, 51, 1, 52, 1, 52, 1, 52, 1, 53, 4, 53, 502, 8, 53, 11, 53, 12, 53, 503, 1, 54, 1, 54, 1, 54, 1, 54, 1, 55, 1, 55, 1, 55, 1, 55, 1, 55, 1, 55, 5, 55, 516, 8, 55, 10, 55, 12, 55, 519, 9, 55, 1, 56, 1, 56, 1, 57, 1, 57, 1, 57, 1, 57, 3, 57, 527, 8, 57, 1, 57, 1, 57, 1, 57, 1, 57, 1, 58, 1, 58, 1, 58, 1, 58, 1, 58, 1, 59, 1, 59, 1, 59, 1, 59, 3, 59, 542, 8, 59, 1, 60, 1, 60, 1, 60, 1, 61, 1, 61, 1, 62, 1, 62, 1, 62, 5, 62, 552, 8, 62, 10, 62, 12, 62, 555, 9, 62, 1, 63, 1, 63, 1, 63, 1, 63, 1, 64, 1, 64, 3, 64, 563, 8, 64, 1, 65, 1, 65, 1, 65, 1, 65, 1, 65, 1, 65, 3, 65, 571, 8, 65, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 3, 66, 580, 8, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 5, 66, 587, 8, 66, 10, 66, 12, 66, 590, 9, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 3, 66, 597, 8, 66, 1, 66, 1, 66, 1, 66, 3, 66, 602, 8, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 5, 66, 610, 8, 66, 10, 66, 12, 66, 613, 9, 66, 1, 67, 1, 67, 3, 67, 617, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 3, 67, 624, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 3, 67, 631, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 5, 67, 638, 8, 67, 10, 67, 12, 67, 641, 9, 67, 1, 67, 1, 67, 3, 67, 645, 8, 67, 1, 68, 1, 68, 1, 68, 3, 68, 650, 8, 68, 1, 68, 1, 68, 1, 68, 1, 69, 1, 69, 1, 69, 1, 69, 1, 69, 3, 69, 660, 8, 69, 1, 70, 1, 70, 1, 70, 1, 70, 3, 70, 666, 8, 70, 1, 70, 1, 70, 1, 70, 1, 70, 1, 70, 1, 70, 5, 70, 674, 8, 70, 10, 70, 12, 70, 677, 9, 70, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 3, 71, 687, 8, 71, 1, 71, 1, 71, 1, 71, 5, 71, 692, 8, 71, 10, 71, 12, 71, 695, 
9, 71, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 5, 72, 703, 8, 72, 10, 72, 12, 72, 706, 9, 72, 1, 72, 1, 72, 3, 72, 710, 8, 72, 3, 72, 712, 8, 72, 1, 72, 1, 72, 1, 73, 1, 73, 1, 74, 1, 74, 1, 74, 1, 74, 5, 74, 722, 8, 74, 10, 74, 12, 74, 725, 9, 74, 1, 74, 1, 74, 1, 75, 1, 75, 1, 75, 1, 75, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 746, 8, 76, 10, 76, 12, 76, 749, 9, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 757, 8, 76, 10, 76, 12, 76, 760, 9, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 768, 8, 76, 10, 76, 12, 76, 771, 9, 76, 1, 76, 1, 76, 3, 76, 775, 8, 76, 1, 77, 1, 77, 1, 78, 1, 78, 3, 78, 781, 8, 78, 1, 79, 3, 79, 784, 8, 79, 1, 79, 1, 79, 1, 80, 3, 80, 789, 8, 80, 1, 80, 1, 80, 1, 81, 1, 81, 1, 82, 1, 82, 1, 83, 1, 83, 1, 83, 1, 83, 1, 83, 1, 84, 1, 84, 1, 85, 1, 85, 1, 85, 1, 85, 5, 85, 808, 8, 85, 10, 85, 12, 85, 811, 9, 85, 1, 86, 1, 86, 1, 86, 0, 5, 2, 110, 132, 140, 142, 87, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138, 140, 142, 144, 146, 148, 150, 152, 154, 156, 158, 160, 162, 164, 166, 168, 170, 172, 0, 10, 2, 0, 53, 53, 107, 107, 1, 0, 101, 102, 2, 0, 57, 57, 63, 63, 2, 0, 66, 66, 69, 69, 2, 0, 38, 38, 53, 53, 1, 0, 87, 88, 1, 0, 89, 91, 2, 0, 65, 65, 78, 78, 2, 0, 80, 80, 82, 86, 2, 0, 23, 23, 25, 26, 841, 0, 174, 1, 0, 0, 0, 2, 177, 1, 0, 0, 0, 4, 195, 1, 0, 0, 0, 6, 224, 1, 0, 0, 0, 8, 226, 1, 0, 0, 0, 10, 229, 1, 0, 0, 0, 12, 231, 1, 0, 0, 0, 14, 234, 1, 0, 0, 0, 16, 245, 1, 0, 0, 0, 18, 249, 1, 0, 0, 0, 20, 257, 1, 0, 0, 0, 22, 262, 1, 0, 0, 0, 24, 265, 1, 0, 0, 0, 26, 268, 1, 0, 0, 0, 28, 288, 1, 0, 0, 0, 30, 290, 1, 0, 0, 0, 32, 292, 1, 0, 0, 0, 34, 294, 1, 0, 0, 0, 36, 296, 1, 0, 0, 0, 38, 298, 1, 0, 0, 0, 40, 307, 1, 0, 0, 
0, 42, 310, 1, 0, 0, 0, 44, 318, 1, 0, 0, 0, 46, 326, 1, 0, 0, 0, 48, 331, 1, 0, 0, 0, 50, 339, 1, 0, 0, 0, 52, 347, 1, 0, 0, 0, 54, 355, 1, 0, 0, 0, 56, 360, 1, 0, 0, 0, 58, 364, 1, 0, 0, 0, 60, 368, 1, 0, 0, 0, 62, 373, 1, 0, 0, 0, 64, 375, 1, 0, 0, 0, 66, 378, 1, 0, 0, 0, 68, 387, 1, 0, 0, 0, 70, 395, 1, 0, 0, 0, 72, 398, 1, 0, 0, 0, 74, 401, 1, 0, 0, 0, 76, 418, 1, 0, 0, 0, 78, 420, 1, 0, 0, 0, 80, 426, 1, 0, 0, 0, 82, 430, 1, 0, 0, 0, 84, 433, 1, 0, 0, 0, 86, 441, 1, 0, 0, 0, 88, 445, 1, 0, 0, 0, 90, 448, 1, 0, 0, 0, 92, 452, 1, 0, 0, 0, 94, 455, 1, 0, 0, 0, 96, 472, 1, 0, 0, 0, 98, 477, 1, 0, 0, 0, 100, 481, 1, 0, 0, 0, 102, 484, 1, 0, 0, 0, 104, 497, 1, 0, 0, 0, 106, 501, 1, 0, 0, 0, 108, 505, 1, 0, 0, 0, 110, 509, 1, 0, 0, 0, 112, 520, 1, 0, 0, 0, 114, 522, 1, 0, 0, 0, 116, 532, 1, 0, 0, 0, 118, 537, 1, 0, 0, 0, 120, 543, 1, 0, 0, 0, 122, 546, 1, 0, 0, 0, 124, 548, 1, 0, 0, 0, 126, 556, 1, 0, 0, 0, 128, 562, 1, 0, 0, 0, 130, 564, 1, 0, 0, 0, 132, 601, 1, 0, 0, 0, 134, 644, 1, 0, 0, 0, 136, 646, 1, 0, 0, 0, 138, 659, 1, 0, 0, 0, 140, 665, 1, 0, 0, 0, 142, 686, 1, 0, 0, 0, 144, 696, 1, 0, 0, 0, 146, 715, 1, 0, 0, 0, 148, 717, 1, 0, 0, 0, 150, 728, 1, 0, 0, 0, 152, 774, 1, 0, 0, 0, 154, 776, 1, 0, 0, 0, 156, 780, 1, 0, 0, 0, 158, 783, 1, 0, 0, 0, 160, 788, 1, 0, 0, 0, 162, 792, 1, 0, 0, 0, 164, 794, 1, 0, 0, 0, 166, 796, 1, 0, 0, 0, 168, 801, 1, 0, 0, 0, 170, 803, 1, 0, 0, 0, 172, 812, 1, 0, 0, 0, 174, 175, 3, 2, 1, 0, 175, 176, 5, 0, 0, 1, 176, 1, 1, 0, 0, 0, 177, 178, 6, 1, -1, 0, 178, 179, 3, 4, 2, 0, 179, 185, 1, 0, 0, 0, 180, 181, 10, 1, 0, 0, 181, 182, 5, 52, 0, 0, 182, 184, 3, 6, 3, 0, 183, 180, 1, 0, 0, 0, 184, 187, 1, 0, 0, 0, 185, 183, 1, 0, 0, 0, 185, 186, 1, 0, 0, 0, 186, 3, 1, 0, 0, 0, 187, 185, 1, 0, 0, 0, 188, 196, 3, 22, 11, 0, 189, 196, 3, 12, 6, 0, 190, 196, 3, 92, 46, 0, 191, 192, 4, 2, 1, 0, 192, 196, 3, 24, 12, 0, 193, 194, 4, 2, 2, 0, 194, 196, 3, 88, 44, 0, 195, 188, 1, 0, 0, 0, 195, 189, 1, 0, 0, 0, 195, 190, 1, 0, 0, 0, 195, 191, 1, 0, 
0, 0, 195, 193, 1, 0, 0, 0, 196, 5, 1, 0, 0, 0, 197, 225, 3, 40, 20, 0, 198, 225, 3, 8, 4, 0, 199, 225, 3, 70, 35, 0, 200, 225, 3, 64, 32, 0, 201, 225, 3, 42, 21, 0, 202, 225, 3, 66, 33, 0, 203, 225, 3, 72, 36, 0, 204, 225, 3, 74, 37, 0, 205, 225, 3, 78, 39, 0, 206, 225, 3, 80, 40, 0, 207, 225, 3, 94, 47, 0, 208, 225, 3, 82, 41, 0, 209, 225, 3, 166, 83, 0, 210, 225, 3, 102, 51, 0, 211, 225, 3, 114, 57, 0, 212, 225, 3, 100, 50, 0, 213, 225, 3, 104, 52, 0, 214, 215, 4, 3, 3, 0, 215, 225, 3, 118, 59, 0, 216, 217, 4, 3, 4, 0, 217, 225, 3, 116, 58, 0, 218, 219, 4, 3, 5, 0, 219, 225, 3, 120, 60, 0, 220, 221, 4, 3, 6, 0, 221, 225, 3, 130, 65, 0, 222, 223, 4, 3, 7, 0, 223, 225, 3, 122, 61, 0, 224, 197, 1, 0, 0, 0, 224, 198, 1, 0, 0, 0, 224, 199, 1, 0, 0, 0, 224, 200, 1, 0, 0, 0, 224, 201, 1, 0, 0, 0, 224, 202, 1, 0, 0, 0, 224, 203, 1, 0, 0, 0, 224, 204, 1, 0, 0, 0, 224, 205, 1, 0, 0, 0, 224, 206, 1, 0, 0, 0, 224, 207, 1, 0, 0, 0, 224, 208, 1, 0, 0, 0, 224, 209, 1, 0, 0, 0, 224, 210, 1, 0, 0, 0, 224, 211, 1, 0, 0, 0, 224, 212, 1, 0, 0, 0, 224, 213, 1, 0, 0, 0, 224, 214, 1, 0, 0, 0, 224, 216, 1, 0, 0, 0, 224, 218, 1, 0, 0, 0, 224, 220, 1, 0, 0, 0, 224, 222, 1, 0, 0, 0, 225, 7, 1, 0, 0, 0, 226, 227, 5, 16, 0, 0, 227, 228, 3, 132, 66, 0, 228, 9, 1, 0, 0, 0, 229, 230, 3, 54, 27, 0, 230, 11, 1, 0, 0, 0, 231, 232, 5, 12, 0, 0, 232, 233, 3, 14, 7, 0, 233, 13, 1, 0, 0, 0, 234, 239, 3, 16, 8, 0, 235, 236, 5, 62, 0, 0, 236, 238, 3, 16, 8, 0, 237, 235, 1, 0, 0, 0, 238, 241, 1, 0, 0, 0, 239, 237, 1, 0, 0, 0, 239, 240, 1, 0, 0, 0, 240, 15, 1, 0, 0, 0, 241, 239, 1, 0, 0, 0, 242, 243, 3, 48, 24, 0, 243, 244, 5, 58, 0, 0, 244, 246, 1, 0, 0, 0, 245, 242, 1, 0, 0, 0, 245, 246, 1, 0, 0, 0, 246, 247, 1, 0, 0, 0, 247, 248, 3, 132, 66, 0, 248, 17, 1, 0, 0, 0, 249, 254, 3, 20, 10, 0, 250, 251, 5, 62, 0, 0, 251, 253, 3, 20, 10, 0, 252, 250, 1, 0, 0, 0, 253, 256, 1, 0, 0, 0, 254, 252, 1, 0, 0, 0, 254, 255, 1, 0, 0, 0, 255, 19, 1, 0, 0, 0, 256, 254, 1, 0, 0, 0, 257, 260, 3, 48, 24, 0, 258, 259, 5, 
58, 0, 0, 259, 261, 3, 132, 66, 0, 260, 258, 1, 0, 0, 0, 260, 261, 1, 0, 0, 0, 261, 21, 1, 0, 0, 0, 262, 263, 5, 19, 0, 0, 263, 264, 3, 26, 13, 0, 264, 23, 1, 0, 0, 0, 265, 266, 5, 20, 0, 0, 266, 267, 3, 26, 13, 0, 267, 25, 1, 0, 0, 0, 268, 273, 3, 28, 14, 0, 269, 270, 5, 62, 0, 0, 270, 272, 3, 28, 14, 0, 271, 269, 1, 0, 0, 0, 272, 275, 1, 0, 0, 0, 273, 271, 1, 0, 0, 0, 273, 274, 1, 0, 0, 0, 274, 277, 1, 0, 0, 0, 275, 273, 1, 0, 0, 0, 276, 278, 3, 38, 19, 0, 277, 276, 1, 0, 0, 0, 277, 278, 1, 0, 0, 0, 278, 27, 1, 0, 0, 0, 279, 280, 3, 30, 15, 0, 280, 281, 5, 61, 0, 0, 281, 282, 3, 34, 17, 0, 282, 289, 1, 0, 0, 0, 283, 284, 3, 34, 17, 0, 284, 285, 5, 60, 0, 0, 285, 286, 3, 32, 16, 0, 286, 289, 1, 0, 0, 0, 287, 289, 3, 36, 18, 0, 288, 279, 1, 0, 0, 0, 288, 283, 1, 0, 0, 0, 288, 287, 1, 0, 0, 0, 289, 29, 1, 0, 0, 0, 290, 291, 5, 107, 0, 0, 291, 31, 1, 0, 0, 0, 292, 293, 5, 107, 0, 0, 293, 33, 1, 0, 0, 0, 294, 295, 5, 107, 0, 0, 295, 35, 1, 0, 0, 0, 296, 297, 7, 0, 0, 0, 297, 37, 1, 0, 0, 0, 298, 299, 5, 106, 0, 0, 299, 304, 5, 107, 0, 0, 300, 301, 5, 62, 0, 0, 301, 303, 5, 107, 0, 0, 302, 300, 1, 0, 0, 0, 303, 306, 1, 0, 0, 0, 304, 302, 1, 0, 0, 0, 304, 305, 1, 0, 0, 0, 305, 39, 1, 0, 0, 0, 306, 304, 1, 0, 0, 0, 307, 308, 5, 9, 0, 0, 308, 309, 3, 14, 7, 0, 309, 41, 1, 0, 0, 0, 310, 312, 5, 15, 0, 0, 311, 313, 3, 44, 22, 0, 312, 311, 1, 0, 0, 0, 312, 313, 1, 0, 0, 0, 313, 316, 1, 0, 0, 0, 314, 315, 5, 59, 0, 0, 315, 317, 3, 14, 7, 0, 316, 314, 1, 0, 0, 0, 316, 317, 1, 0, 0, 0, 317, 43, 1, 0, 0, 0, 318, 323, 3, 46, 23, 0, 319, 320, 5, 62, 0, 0, 320, 322, 3, 46, 23, 0, 321, 319, 1, 0, 0, 0, 322, 325, 1, 0, 0, 0, 323, 321, 1, 0, 0, 0, 323, 324, 1, 0, 0, 0, 324, 45, 1, 0, 0, 0, 325, 323, 1, 0, 0, 0, 326, 329, 3, 16, 8, 0, 327, 328, 5, 16, 0, 0, 328, 330, 3, 132, 66, 0, 329, 327, 1, 0, 0, 0, 329, 330, 1, 0, 0, 0, 330, 47, 1, 0, 0, 0, 331, 336, 3, 62, 31, 0, 332, 333, 5, 64, 0, 0, 333, 335, 3, 62, 31, 0, 334, 332, 1, 0, 0, 0, 335, 338, 1, 0, 0, 0, 336, 334, 1, 0, 0, 0, 336, 
337, 1, 0, 0, 0, 337, 49, 1, 0, 0, 0, 338, 336, 1, 0, 0, 0, 339, 344, 3, 56, 28, 0, 340, 341, 5, 64, 0, 0, 341, 343, 3, 56, 28, 0, 342, 340, 1, 0, 0, 0, 343, 346, 1, 0, 0, 0, 344, 342, 1, 0, 0, 0, 344, 345, 1, 0, 0, 0, 345, 51, 1, 0, 0, 0, 346, 344, 1, 0, 0, 0, 347, 352, 3, 50, 25, 0, 348, 349, 5, 62, 0, 0, 349, 351, 3, 50, 25, 0, 350, 348, 1, 0, 0, 0, 351, 354, 1, 0, 0, 0, 352, 350, 1, 0, 0, 0, 352, 353, 1, 0, 0, 0, 353, 53, 1, 0, 0, 0, 354, 352, 1, 0, 0, 0, 355, 356, 7, 1, 0, 0, 356, 55, 1, 0, 0, 0, 357, 361, 5, 128, 0, 0, 358, 361, 3, 58, 29, 0, 359, 361, 3, 60, 30, 0, 360, 357, 1, 0, 0, 0, 360, 358, 1, 0, 0, 0, 360, 359, 1, 0, 0, 0, 361, 57, 1, 0, 0, 0, 362, 365, 5, 76, 0, 0, 363, 365, 5, 95, 0, 0, 364, 362, 1, 0, 0, 0, 364, 363, 1, 0, 0, 0, 365, 59, 1, 0, 0, 0, 366, 369, 5, 94, 0, 0, 367, 369, 5, 96, 0, 0, 368, 366, 1, 0, 0, 0, 368, 367, 1, 0, 0, 0, 369, 61, 1, 0, 0, 0, 370, 374, 3, 54, 27, 0, 371, 374, 3, 58, 29, 0, 372, 374, 3, 60, 30, 0, 373, 370, 1, 0, 0, 0, 373, 371, 1, 0, 0, 0, 373, 372, 1, 0, 0, 0, 374, 63, 1, 0, 0, 0, 375, 376, 5, 11, 0, 0, 376, 377, 3, 152, 76, 0, 377, 65, 1, 0, 0, 0, 378, 379, 5, 14, 0, 0, 379, 384, 3, 68, 34, 0, 380, 381, 5, 62, 0, 0, 381, 383, 3, 68, 34, 0, 382, 380, 1, 0, 0, 0, 383, 386, 1, 0, 0, 0, 384, 382, 1, 0, 0, 0, 384, 385, 1, 0, 0, 0, 385, 67, 1, 0, 0, 0, 386, 384, 1, 0, 0, 0, 387, 389, 3, 132, 66, 0, 388, 390, 7, 2, 0, 0, 389, 388, 1, 0, 0, 0, 389, 390, 1, 0, 0, 0, 390, 393, 1, 0, 0, 0, 391, 392, 5, 73, 0, 0, 392, 394, 7, 3, 0, 0, 393, 391, 1, 0, 0, 0, 393, 394, 1, 0, 0, 0, 394, 69, 1, 0, 0, 0, 395, 396, 5, 30, 0, 0, 396, 397, 3, 52, 26, 0, 397, 71, 1, 0, 0, 0, 398, 399, 5, 29, 0, 0, 399, 400, 3, 52, 26, 0, 400, 73, 1, 0, 0, 0, 401, 402, 5, 32, 0, 0, 402, 407, 3, 76, 38, 0, 403, 404, 5, 62, 0, 0, 404, 406, 3, 76, 38, 0, 405, 403, 1, 0, 0, 0, 406, 409, 1, 0, 0, 0, 407, 405, 1, 0, 0, 0, 407, 408, 1, 0, 0, 0, 408, 75, 1, 0, 0, 0, 409, 407, 1, 0, 0, 0, 410, 411, 3, 50, 25, 0, 411, 412, 5, 132, 0, 0, 412, 413, 3, 50, 25, 0, 
413, 419, 1, 0, 0, 0, 414, 415, 3, 50, 25, 0, 415, 416, 5, 58, 0, 0, 416, 417, 3, 50, 25, 0, 417, 419, 1, 0, 0, 0, 418, 410, 1, 0, 0, 0, 418, 414, 1, 0, 0, 0, 419, 77, 1, 0, 0, 0, 420, 421, 5, 8, 0, 0, 421, 422, 3, 142, 71, 0, 422, 424, 3, 162, 81, 0, 423, 425, 3, 84, 42, 0, 424, 423, 1, 0, 0, 0, 424, 425, 1, 0, 0, 0, 425, 79, 1, 0, 0, 0, 426, 427, 5, 10, 0, 0, 427, 428, 3, 142, 71, 0, 428, 429, 3, 162, 81, 0, 429, 81, 1, 0, 0, 0, 430, 431, 5, 28, 0, 0, 431, 432, 3, 48, 24, 0, 432, 83, 1, 0, 0, 0, 433, 438, 3, 86, 43, 0, 434, 435, 5, 62, 0, 0, 435, 437, 3, 86, 43, 0, 436, 434, 1, 0, 0, 0, 437, 440, 1, 0, 0, 0, 438, 436, 1, 0, 0, 0, 438, 439, 1, 0, 0, 0, 439, 85, 1, 0, 0, 0, 440, 438, 1, 0, 0, 0, 441, 442, 3, 54, 27, 0, 442, 443, 5, 58, 0, 0, 443, 444, 3, 152, 76, 0, 444, 87, 1, 0, 0, 0, 445, 446, 5, 6, 0, 0, 446, 447, 3, 90, 45, 0, 447, 89, 1, 0, 0, 0, 448, 449, 5, 99, 0, 0, 449, 450, 3, 2, 1, 0, 450, 451, 5, 100, 0, 0, 451, 91, 1, 0, 0, 0, 452, 453, 5, 33, 0, 0, 453, 454, 5, 136, 0, 0, 454, 93, 1, 0, 0, 0, 455, 456, 5, 5, 0, 0, 456, 459, 3, 96, 48, 0, 457, 458, 5, 74, 0, 0, 458, 460, 3, 50, 25, 0, 459, 457, 1, 0, 0, 0, 459, 460, 1, 0, 0, 0, 460, 470, 1, 0, 0, 0, 461, 462, 5, 79, 0, 0, 462, 467, 3, 98, 49, 0, 463, 464, 5, 62, 0, 0, 464, 466, 3, 98, 49, 0, 465, 463, 1, 0, 0, 0, 466, 469, 1, 0, 0, 0, 467, 465, 1, 0, 0, 0, 467, 468, 1, 0, 0, 0, 468, 471, 1, 0, 0, 0, 469, 467, 1, 0, 0, 0, 470, 461, 1, 0, 0, 0, 470, 471, 1, 0, 0, 0, 471, 95, 1, 0, 0, 0, 472, 473, 7, 4, 0, 0, 473, 97, 1, 0, 0, 0, 474, 475, 3, 50, 25, 0, 475, 476, 5, 58, 0, 0, 476, 478, 1, 0, 0, 0, 477, 474, 1, 0, 0, 0, 477, 478, 1, 0, 0, 0, 478, 479, 1, 0, 0, 0, 479, 480, 3, 50, 25, 0, 480, 99, 1, 0, 0, 0, 481, 482, 5, 13, 0, 0, 482, 483, 3, 152, 76, 0, 483, 101, 1, 0, 0, 0, 484, 485, 5, 4, 0, 0, 485, 488, 3, 48, 24, 0, 486, 487, 5, 74, 0, 0, 487, 489, 3, 48, 24, 0, 488, 486, 1, 0, 0, 0, 488, 489, 1, 0, 0, 0, 489, 495, 1, 0, 0, 0, 490, 491, 5, 132, 0, 0, 491, 492, 3, 48, 24, 0, 492, 493, 5, 62, 0, 0, 
493, 494, 3, 48, 24, 0, 494, 496, 1, 0, 0, 0, 495, 490, 1, 0, 0, 0, 495, 496, 1, 0, 0, 0, 496, 103, 1, 0, 0, 0, 497, 498, 5, 21, 0, 0, 498, 499, 3, 106, 53, 0, 499, 105, 1, 0, 0, 0, 500, 502, 3, 108, 54, 0, 501, 500, 1, 0, 0, 0, 502, 503, 1, 0, 0, 0, 503, 501, 1, 0, 0, 0, 503, 504, 1, 0, 0, 0, 504, 107, 1, 0, 0, 0, 505, 506, 5, 99, 0, 0, 506, 507, 3, 110, 55, 0, 507, 508, 5, 100, 0, 0, 508, 109, 1, 0, 0, 0, 509, 510, 6, 55, -1, 0, 510, 511, 3, 112, 56, 0, 511, 517, 1, 0, 0, 0, 512, 513, 10, 1, 0, 0, 513, 514, 5, 52, 0, 0, 514, 516, 3, 112, 56, 0, 515, 512, 1, 0, 0, 0, 516, 519, 1, 0, 0, 0, 517, 515, 1, 0, 0, 0, 517, 518, 1, 0, 0, 0, 518, 111, 1, 0, 0, 0, 519, 517, 1, 0, 0, 0, 520, 521, 3, 6, 3, 0, 521, 113, 1, 0, 0, 0, 522, 526, 5, 7, 0, 0, 523, 524, 3, 48, 24, 0, 524, 525, 5, 58, 0, 0, 525, 527, 1, 0, 0, 0, 526, 523, 1, 0, 0, 0, 526, 527, 1, 0, 0, 0, 527, 528, 1, 0, 0, 0, 528, 529, 3, 142, 71, 0, 529, 530, 5, 79, 0, 0, 530, 531, 3, 62, 31, 0, 531, 115, 1, 0, 0, 0, 532, 533, 5, 27, 0, 0, 533, 534, 3, 28, 14, 0, 534, 535, 5, 74, 0, 0, 535, 536, 3, 52, 26, 0, 536, 117, 1, 0, 0, 0, 537, 538, 5, 17, 0, 0, 538, 541, 3, 44, 22, 0, 539, 540, 5, 59, 0, 0, 540, 542, 3, 14, 7, 0, 541, 539, 1, 0, 0, 0, 541, 542, 1, 0, 0, 0, 542, 119, 1, 0, 0, 0, 543, 544, 5, 31, 0, 0, 544, 545, 3, 52, 26, 0, 545, 121, 1, 0, 0, 0, 546, 547, 5, 22, 0, 0, 547, 123, 1, 0, 0, 0, 548, 553, 3, 126, 63, 0, 549, 550, 5, 62, 0, 0, 550, 552, 3, 126, 63, 0, 551, 549, 1, 0, 0, 0, 552, 555, 1, 0, 0, 0, 553, 551, 1, 0, 0, 0, 553, 554, 1, 0, 0, 0, 554, 125, 1, 0, 0, 0, 555, 553, 1, 0, 0, 0, 556, 557, 3, 54, 27, 0, 557, 558, 5, 58, 0, 0, 558, 559, 3, 128, 64, 0, 559, 127, 1, 0, 0, 0, 560, 563, 3, 152, 76, 0, 561, 563, 3, 54, 27, 0, 562, 560, 1, 0, 0, 0, 562, 561, 1, 0, 0, 0, 563, 129, 1, 0, 0, 0, 564, 565, 5, 18, 0, 0, 565, 566, 3, 152, 76, 0, 566, 567, 5, 74, 0, 0, 567, 570, 3, 18, 9, 0, 568, 569, 5, 79, 0, 0, 569, 571, 3, 124, 62, 0, 570, 568, 1, 0, 0, 0, 570, 571, 1, 0, 0, 0, 571, 131, 1, 0, 0, 0, 572, 
573, 6, 66, -1, 0, 573, 574, 5, 71, 0, 0, 574, 602, 3, 132, 66, 8, 575, 602, 3, 138, 69, 0, 576, 602, 3, 134, 67, 0, 577, 579, 3, 138, 69, 0, 578, 580, 5, 71, 0, 0, 579, 578, 1, 0, 0, 0, 579, 580, 1, 0, 0, 0, 580, 581, 1, 0, 0, 0, 581, 582, 5, 67, 0, 0, 582, 583, 5, 99, 0, 0, 583, 588, 3, 138, 69, 0, 584, 585, 5, 62, 0, 0, 585, 587, 3, 138, 69, 0, 586, 584, 1, 0, 0, 0, 587, 590, 1, 0, 0, 0, 588, 586, 1, 0, 0, 0, 588, 589, 1, 0, 0, 0, 589, 591, 1, 0, 0, 0, 590, 588, 1, 0, 0, 0, 591, 592, 5, 100, 0, 0, 592, 602, 1, 0, 0, 0, 593, 594, 3, 138, 69, 0, 594, 596, 5, 68, 0, 0, 595, 597, 5, 71, 0, 0, 596, 595, 1, 0, 0, 0, 596, 597, 1, 0, 0, 0, 597, 598, 1, 0, 0, 0, 598, 599, 5, 72, 0, 0, 599, 602, 1, 0, 0, 0, 600, 602, 3, 136, 68, 0, 601, 572, 1, 0, 0, 0, 601, 575, 1, 0, 0, 0, 601, 576, 1, 0, 0, 0, 601, 577, 1, 0, 0, 0, 601, 593, 1, 0, 0, 0, 601, 600, 1, 0, 0, 0, 602, 611, 1, 0, 0, 0, 603, 604, 10, 5, 0, 0, 604, 605, 5, 56, 0, 0, 605, 610, 3, 132, 66, 6, 606, 607, 10, 4, 0, 0, 607, 608, 5, 75, 0, 0, 608, 610, 3, 132, 66, 5, 609, 603, 1, 0, 0, 0, 609, 606, 1, 0, 0, 0, 610, 613, 1, 0, 0, 0, 611, 609, 1, 0, 0, 0, 611, 612, 1, 0, 0, 0, 612, 133, 1, 0, 0, 0, 613, 611, 1, 0, 0, 0, 614, 616, 3, 138, 69, 0, 615, 617, 5, 71, 0, 0, 616, 615, 1, 0, 0, 0, 616, 617, 1, 0, 0, 0, 617, 618, 1, 0, 0, 0, 618, 619, 5, 70, 0, 0, 619, 620, 3, 162, 81, 0, 620, 645, 1, 0, 0, 0, 621, 623, 3, 138, 69, 0, 622, 624, 5, 71, 0, 0, 623, 622, 1, 0, 0, 0, 623, 624, 1, 0, 0, 0, 624, 625, 1, 0, 0, 0, 625, 626, 5, 77, 0, 0, 626, 627, 3, 162, 81, 0, 627, 645, 1, 0, 0, 0, 628, 630, 3, 138, 69, 0, 629, 631, 5, 71, 0, 0, 630, 629, 1, 0, 0, 0, 630, 631, 1, 0, 0, 0, 631, 632, 1, 0, 0, 0, 632, 633, 5, 70, 0, 0, 633, 634, 5, 99, 0, 0, 634, 639, 3, 162, 81, 0, 635, 636, 5, 62, 0, 0, 636, 638, 3, 162, 81, 0, 637, 635, 1, 0, 0, 0, 638, 641, 1, 0, 0, 0, 639, 637, 1, 0, 0, 0, 639, 640, 1, 0, 0, 0, 640, 642, 1, 0, 0, 0, 641, 639, 1, 0, 0, 0, 642, 643, 5, 100, 0, 0, 643, 645, 1, 0, 0, 0, 644, 614, 1, 0, 0, 0, 644, 621, 1, 
0, 0, 0, 644, 628, 1, 0, 0, 0, 645, 135, 1, 0, 0, 0, 646, 649, 3, 48, 24, 0, 647, 648, 5, 60, 0, 0, 648, 650, 3, 10, 5, 0, 649, 647, 1, 0, 0, 0, 649, 650, 1, 0, 0, 0, 650, 651, 1, 0, 0, 0, 651, 652, 5, 61, 0, 0, 652, 653, 3, 152, 76, 0, 653, 137, 1, 0, 0, 0, 654, 660, 3, 140, 70, 0, 655, 656, 3, 140, 70, 0, 656, 657, 3, 164, 82, 0, 657, 658, 3, 140, 70, 0, 658, 660, 1, 0, 0, 0, 659, 654, 1, 0, 0, 0, 659, 655, 1, 0, 0, 0, 660, 139, 1, 0, 0, 0, 661, 662, 6, 70, -1, 0, 662, 666, 3, 142, 71, 0, 663, 664, 7, 5, 0, 0, 664, 666, 3, 140, 70, 3, 665, 661, 1, 0, 0, 0, 665, 663, 1, 0, 0, 0, 666, 675, 1, 0, 0, 0, 667, 668, 10, 2, 0, 0, 668, 669, 7, 6, 0, 0, 669, 674, 3, 140, 70, 3, 670, 671, 10, 1, 0, 0, 671, 672, 7, 5, 0, 0, 672, 674, 3, 140, 70, 2, 673, 667, 1, 0, 0, 0, 673, 670, 1, 0, 0, 0, 674, 677, 1, 0, 0, 0, 675, 673, 1, 0, 0, 0, 675, 676, 1, 0, 0, 0, 676, 141, 1, 0, 0, 0, 677, 675, 1, 0, 0, 0, 678, 679, 6, 71, -1, 0, 679, 687, 3, 152, 76, 0, 680, 687, 3, 48, 24, 0, 681, 687, 3, 144, 72, 0, 682, 683, 5, 99, 0, 0, 683, 684, 3, 132, 66, 0, 684, 685, 5, 100, 0, 0, 685, 687, 1, 0, 0, 0, 686, 678, 1, 0, 0, 0, 686, 680, 1, 0, 0, 0, 686, 681, 1, 0, 0, 0, 686, 682, 1, 0, 0, 0, 687, 693, 1, 0, 0, 0, 688, 689, 10, 1, 0, 0, 689, 690, 5, 60, 0, 0, 690, 692, 3, 10, 5, 0, 691, 688, 1, 0, 0, 0, 692, 695, 1, 0, 0, 0, 693, 691, 1, 0, 0, 0, 693, 694, 1, 0, 0, 0, 694, 143, 1, 0, 0, 0, 695, 693, 1, 0, 0, 0, 696, 697, 3, 146, 73, 0, 697, 711, 5, 99, 0, 0, 698, 712, 5, 89, 0, 0, 699, 704, 3, 132, 66, 0, 700, 701, 5, 62, 0, 0, 701, 703, 3, 132, 66, 0, 702, 700, 1, 0, 0, 0, 703, 706, 1, 0, 0, 0, 704, 702, 1, 0, 0, 0, 704, 705, 1, 0, 0, 0, 705, 709, 1, 0, 0, 0, 706, 704, 1, 0, 0, 0, 707, 708, 5, 62, 0, 0, 708, 710, 3, 148, 74, 0, 709, 707, 1, 0, 0, 0, 709, 710, 1, 0, 0, 0, 710, 712, 1, 0, 0, 0, 711, 698, 1, 0, 0, 0, 711, 699, 1, 0, 0, 0, 711, 712, 1, 0, 0, 0, 712, 713, 1, 0, 0, 0, 713, 714, 5, 100, 0, 0, 714, 145, 1, 0, 0, 0, 715, 716, 3, 62, 31, 0, 716, 147, 1, 0, 0, 0, 717, 718, 5, 92, 0, 0, 
718, 723, 3, 150, 75, 0, 719, 720, 5, 62, 0, 0, 720, 722, 3, 150, 75, 0, 721, 719, 1, 0, 0, 0, 722, 725, 1, 0, 0, 0, 723, 721, 1, 0, 0, 0, 723, 724, 1, 0, 0, 0, 724, 726, 1, 0, 0, 0, 725, 723, 1, 0, 0, 0, 726, 727, 5, 93, 0, 0, 727, 149, 1, 0, 0, 0, 728, 729, 3, 162, 81, 0, 729, 730, 5, 61, 0, 0, 730, 731, 3, 152, 76, 0, 731, 151, 1, 0, 0, 0, 732, 775, 5, 72, 0, 0, 733, 734, 3, 160, 80, 0, 734, 735, 5, 101, 0, 0, 735, 775, 1, 0, 0, 0, 736, 775, 3, 158, 79, 0, 737, 775, 3, 160, 80, 0, 738, 775, 3, 154, 77, 0, 739, 775, 3, 58, 29, 0, 740, 775, 3, 162, 81, 0, 741, 742, 5, 97, 0, 0, 742, 747, 3, 156, 78, 0, 743, 744, 5, 62, 0, 0, 744, 746, 3, 156, 78, 0, 745, 743, 1, 0, 0, 0, 746, 749, 1, 0, 0, 0, 747, 745, 1, 0, 0, 0, 747, 748, 1, 0, 0, 0, 748, 750, 1, 0, 0, 0, 749, 747, 1, 0, 0, 0, 750, 751, 5, 98, 0, 0, 751, 775, 1, 0, 0, 0, 752, 753, 5, 97, 0, 0, 753, 758, 3, 154, 77, 0, 754, 755, 5, 62, 0, 0, 755, 757, 3, 154, 77, 0, 756, 754, 1, 0, 0, 0, 757, 760, 1, 0, 0, 0, 758, 756, 1, 0, 0, 0, 758, 759, 1, 0, 0, 0, 759, 761, 1, 0, 0, 0, 760, 758, 1, 0, 0, 0, 761, 762, 5, 98, 0, 0, 762, 775, 1, 0, 0, 0, 763, 764, 5, 97, 0, 0, 764, 769, 3, 162, 81, 0, 765, 766, 5, 62, 0, 0, 766, 768, 3, 162, 81, 0, 767, 765, 1, 0, 0, 0, 768, 771, 1, 0, 0, 0, 769, 767, 1, 0, 0, 0, 769, 770, 1, 0, 0, 0, 770, 772, 1, 0, 0, 0, 771, 769, 1, 0, 0, 0, 772, 773, 5, 98, 0, 0, 773, 775, 1, 0, 0, 0, 774, 732, 1, 0, 0, 0, 774, 733, 1, 0, 0, 0, 774, 736, 1, 0, 0, 0, 774, 737, 1, 0, 0, 0, 774, 738, 1, 0, 0, 0, 774, 739, 1, 0, 0, 0, 774, 740, 1, 0, 0, 0, 774, 741, 1, 0, 0, 0, 774, 752, 1, 0, 0, 0, 774, 763, 1, 0, 0, 0, 775, 153, 1, 0, 0, 0, 776, 777, 7, 7, 0, 0, 777, 155, 1, 0, 0, 0, 778, 781, 3, 158, 79, 0, 779, 781, 3, 160, 80, 0, 780, 778, 1, 0, 0, 0, 780, 779, 1, 0, 0, 0, 781, 157, 1, 0, 0, 0, 782, 784, 7, 5, 0, 0, 783, 782, 1, 0, 0, 0, 783, 784, 1, 0, 0, 0, 784, 785, 1, 0, 0, 0, 785, 786, 5, 55, 0, 0, 786, 159, 1, 0, 0, 0, 787, 789, 7, 5, 0, 0, 788, 787, 1, 0, 0, 0, 788, 789, 1, 0, 0, 0, 789, 790, 1, 0, 
0, 0, 790, 791, 5, 54, 0, 0, 791, 161, 1, 0, 0, 0, 792, 793, 5, 53, 0, 0, 793, 163, 1, 0, 0, 0, 794, 795, 7, 8, 0, 0, 795, 165, 1, 0, 0, 0, 796, 797, 7, 9, 0, 0, 797, 798, 5, 114, 0, 0, 798, 799, 3, 168, 84, 0, 799, 800, 3, 170, 85, 0, 800, 167, 1, 0, 0, 0, 801, 802, 3, 28, 14, 0, 802, 169, 1, 0, 0, 0, 803, 804, 5, 74, 0, 0, 804, 809, 3, 172, 86, 0, 805, 806, 5, 62, 0, 0, 806, 808, 3, 172, 86, 0, 807, 805, 1, 0, 0, 0, 808, 811, 1, 0, 0, 0, 809, 807, 1, 0, 0, 0, 809, 810, 1, 0, 0, 0, 810, 171, 1, 0, 0, 0, 811, 809, 1, 0, 0, 0, 812, 813, 3, 138, 69, 0, 813, 173, 1, 0, 0, 0, 72, 185, 195, 224, 239, 245, 254, 260, 273, 277, 288, 304, 312, 316, 323, 329, 336, 344, 352, 360, 364, 368, 373, 384, 389, 393, 407, 418, 424, 438, 459, 467, 470, 477, 488, 495, 503, 517, 526, 541, 553, 562, 570, 579, 588, 596, 601, 609, 611, 616, 623, 630, 639, 644, 649, 659, 665, 673, 675, 686, 693, 704, 709, 711, 723, 747, 758, 769, 774, 780, 783, 788, 809] \ No newline at end of file +[4, 1, 139, 831, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 2, 20, 7, 20, 2, 21, 7, 21, 2, 22, 7, 22, 2, 23, 7, 23, 2, 24, 7, 24, 2, 25, 7, 25, 2, 26, 7, 26, 2, 27, 7, 27, 2, 28, 7, 28, 2, 29, 7, 29, 2, 30, 7, 30, 2, 31, 7, 31, 2, 32, 7, 32, 2, 33, 7, 33, 2, 34, 7, 34, 2, 35, 7, 35, 2, 36, 7, 36, 2, 37, 7, 37, 2, 38, 7, 38, 2, 39, 7, 39, 2, 40, 7, 40, 2, 41, 7, 41, 2, 42, 7, 42, 2, 43, 7, 43, 2, 44, 7, 44, 2, 45, 7, 45, 2, 46, 7, 46, 2, 47, 7, 47, 2, 48, 7, 48, 2, 49, 7, 49, 2, 50, 7, 50, 2, 51, 7, 51, 2, 52, 7, 52, 2, 53, 7, 53, 2, 54, 7, 54, 2, 55, 7, 55, 2, 56, 7, 56, 2, 57, 7, 57, 2, 58, 7, 58, 2, 59, 7, 59, 2, 60, 7, 60, 2, 61, 7, 61, 2, 62, 7, 62, 2, 63, 7, 63, 2, 64, 7, 64, 2, 65, 7, 65, 2, 66, 7, 66, 2, 67, 7, 67, 2, 68, 7, 68, 2, 69, 7, 69, 2, 70, 7, 70, 2, 71, 7, 71, 2, 72, 7, 72, 2, 73, 
7, 73, 2, 74, 7, 74, 2, 75, 7, 75, 2, 76, 7, 76, 2, 77, 7, 77, 2, 78, 7, 78, 2, 79, 7, 79, 2, 80, 7, 80, 2, 81, 7, 81, 2, 82, 7, 82, 2, 83, 7, 83, 2, 84, 7, 84, 2, 85, 7, 85, 2, 86, 7, 86, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 184, 8, 1, 10, 1, 12, 1, 187, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 196, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 3, 3, 225, 8, 3, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 6, 1, 7, 1, 7, 1, 7, 5, 7, 238, 8, 7, 10, 7, 12, 7, 241, 9, 7, 1, 8, 1, 8, 1, 8, 3, 8, 246, 8, 8, 1, 8, 1, 8, 1, 9, 1, 9, 1, 9, 5, 9, 253, 8, 9, 10, 9, 12, 9, 256, 9, 9, 1, 10, 1, 10, 1, 10, 3, 10, 261, 8, 10, 1, 11, 1, 11, 1, 11, 1, 12, 1, 12, 1, 12, 1, 13, 1, 13, 1, 13, 5, 13, 272, 8, 13, 10, 13, 12, 13, 275, 9, 13, 1, 13, 3, 13, 278, 8, 13, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 1, 14, 3, 14, 289, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 1, 17, 1, 17, 1, 18, 1, 18, 1, 19, 1, 19, 1, 19, 1, 19, 5, 19, 303, 8, 19, 10, 19, 12, 19, 306, 9, 19, 1, 20, 1, 20, 1, 20, 1, 21, 1, 21, 3, 21, 313, 8, 21, 1, 21, 1, 21, 3, 21, 317, 8, 21, 1, 22, 1, 22, 1, 22, 5, 22, 322, 8, 22, 10, 22, 12, 22, 325, 9, 22, 1, 23, 1, 23, 1, 23, 3, 23, 330, 8, 23, 1, 24, 1, 24, 1, 24, 5, 24, 335, 8, 24, 10, 24, 12, 24, 338, 9, 24, 1, 25, 1, 25, 1, 25, 5, 25, 343, 8, 25, 10, 25, 12, 25, 346, 9, 25, 1, 26, 1, 26, 1, 26, 5, 26, 351, 8, 26, 10, 26, 12, 26, 354, 9, 26, 1, 27, 1, 27, 1, 28, 1, 28, 1, 28, 3, 28, 361, 8, 28, 1, 29, 1, 29, 3, 29, 365, 8, 29, 1, 30, 1, 30, 3, 30, 369, 8, 30, 1, 31, 1, 31, 1, 31, 3, 31, 374, 8, 31, 1, 32, 1, 32, 1, 32, 1, 33, 1, 33, 1, 33, 1, 33, 5, 33, 383, 8, 33, 10, 33, 12, 33, 386, 9, 33, 1, 34, 1, 34, 3, 34, 390, 8, 34, 1, 34, 1, 34, 3, 34, 394, 8, 34, 1, 35, 1, 35, 1, 35, 1, 36, 1, 36, 1, 36, 1, 37, 1, 37, 1, 37, 1, 37, 5, 37, 406, 8, 37, 10, 37, 12, 37, 409, 9, 37, 1, 38, 1, 38, 1, 38, 1, 38, 1, 38, 
1, 38, 1, 38, 1, 38, 3, 38, 419, 8, 38, 1, 39, 1, 39, 1, 39, 1, 39, 3, 39, 425, 8, 39, 1, 40, 1, 40, 1, 40, 1, 40, 1, 41, 1, 41, 1, 41, 1, 42, 1, 42, 1, 42, 5, 42, 437, 8, 42, 10, 42, 12, 42, 440, 9, 42, 1, 43, 1, 43, 1, 43, 1, 43, 1, 44, 1, 44, 1, 44, 1, 45, 1, 45, 1, 45, 1, 45, 1, 46, 1, 46, 1, 46, 1, 47, 1, 47, 1, 47, 1, 47, 3, 47, 460, 8, 47, 1, 47, 1, 47, 1, 47, 1, 47, 5, 47, 466, 8, 47, 10, 47, 12, 47, 469, 9, 47, 3, 47, 471, 8, 47, 1, 48, 1, 48, 1, 49, 1, 49, 1, 49, 3, 49, 478, 8, 49, 1, 49, 1, 49, 1, 50, 1, 50, 1, 50, 1, 51, 1, 51, 1, 51, 1, 51, 3, 51, 489, 8, 51, 1, 51, 1, 51, 1, 51, 1, 51, 1, 51, 3, 51, 496, 8, 51, 1, 52, 1, 52, 1, 52, 1, 53, 4, 53, 502, 8, 53, 11, 53, 12, 53, 503, 1, 54, 1, 54, 1, 54, 1, 54, 1, 55, 1, 55, 1, 55, 1, 55, 1, 55, 1, 55, 5, 55, 516, 8, 55, 10, 55, 12, 55, 519, 9, 55, 1, 56, 1, 56, 1, 57, 1, 57, 1, 57, 1, 57, 3, 57, 527, 8, 57, 1, 57, 1, 57, 1, 57, 1, 57, 1, 58, 1, 58, 1, 58, 1, 58, 1, 58, 1, 59, 1, 59, 1, 59, 1, 59, 3, 59, 542, 8, 59, 1, 60, 1, 60, 1, 60, 1, 61, 1, 61, 1, 62, 1, 62, 1, 62, 5, 62, 552, 8, 62, 10, 62, 12, 62, 555, 9, 62, 1, 63, 1, 63, 1, 63, 1, 63, 1, 64, 1, 64, 3, 64, 563, 8, 64, 1, 65, 1, 65, 1, 65, 1, 65, 1, 65, 1, 65, 3, 65, 571, 8, 65, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 3, 66, 580, 8, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 5, 66, 587, 8, 66, 10, 66, 12, 66, 590, 9, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 3, 66, 597, 8, 66, 1, 66, 1, 66, 1, 66, 3, 66, 602, 8, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 1, 66, 5, 66, 610, 8, 66, 10, 66, 12, 66, 613, 9, 66, 1, 67, 1, 67, 3, 67, 617, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 3, 67, 624, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 3, 67, 631, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 5, 67, 638, 8, 67, 10, 67, 12, 67, 641, 9, 67, 1, 67, 1, 67, 1, 67, 1, 67, 3, 67, 647, 8, 67, 1, 67, 1, 67, 1, 67, 1, 67, 1, 67, 5, 67, 654, 8, 67, 10, 67, 12, 67, 657, 9, 67, 1, 67, 1, 67, 3, 67, 661, 8, 67, 1, 68, 1, 68, 1, 68, 3, 68, 666, 8, 68, 1, 68, 1, 68, 1, 68, 1, 69, 
1, 69, 1, 69, 1, 69, 1, 69, 3, 69, 676, 8, 69, 1, 70, 1, 70, 1, 70, 1, 70, 3, 70, 682, 8, 70, 1, 70, 1, 70, 1, 70, 1, 70, 1, 70, 1, 70, 5, 70, 690, 8, 70, 10, 70, 12, 70, 693, 9, 70, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 1, 71, 3, 71, 703, 8, 71, 1, 71, 1, 71, 1, 71, 5, 71, 708, 8, 71, 10, 71, 12, 71, 711, 9, 71, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 1, 72, 5, 72, 719, 8, 72, 10, 72, 12, 72, 722, 9, 72, 1, 72, 1, 72, 3, 72, 726, 8, 72, 3, 72, 728, 8, 72, 1, 72, 1, 72, 1, 73, 1, 73, 1, 74, 1, 74, 1, 74, 1, 74, 5, 74, 738, 8, 74, 10, 74, 12, 74, 741, 9, 74, 1, 74, 1, 74, 1, 75, 1, 75, 1, 75, 1, 75, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 762, 8, 76, 10, 76, 12, 76, 765, 9, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 773, 8, 76, 10, 76, 12, 76, 776, 9, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 1, 76, 5, 76, 784, 8, 76, 10, 76, 12, 76, 787, 9, 76, 1, 76, 1, 76, 3, 76, 791, 8, 76, 1, 77, 1, 77, 1, 78, 1, 78, 3, 78, 797, 8, 78, 1, 79, 3, 79, 800, 8, 79, 1, 79, 1, 79, 1, 80, 3, 80, 805, 8, 80, 1, 80, 1, 80, 1, 81, 1, 81, 1, 82, 1, 82, 1, 83, 1, 83, 1, 83, 1, 83, 1, 83, 1, 84, 1, 84, 1, 85, 1, 85, 1, 85, 1, 85, 5, 85, 824, 8, 85, 10, 85, 12, 85, 827, 9, 85, 1, 86, 1, 86, 1, 86, 0, 5, 2, 110, 132, 140, 142, 87, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138, 140, 142, 144, 146, 148, 150, 152, 154, 156, 158, 160, 162, 164, 166, 168, 170, 172, 0, 10, 2, 0, 53, 53, 107, 107, 1, 0, 101, 102, 2, 0, 57, 57, 63, 63, 2, 0, 66, 66, 69, 69, 2, 0, 38, 38, 53, 53, 1, 0, 87, 88, 1, 0, 89, 91, 2, 0, 65, 65, 78, 78, 2, 0, 80, 80, 82, 86, 2, 0, 23, 23, 25, 26, 860, 0, 174, 1, 0, 0, 0, 2, 177, 1, 0, 0, 0, 4, 195, 1, 0, 0, 0, 6, 224, 1, 0, 0, 0, 8, 226, 1, 0, 0, 0, 10, 229, 1, 
0, 0, 0, 12, 231, 1, 0, 0, 0, 14, 234, 1, 0, 0, 0, 16, 245, 1, 0, 0, 0, 18, 249, 1, 0, 0, 0, 20, 257, 1, 0, 0, 0, 22, 262, 1, 0, 0, 0, 24, 265, 1, 0, 0, 0, 26, 268, 1, 0, 0, 0, 28, 288, 1, 0, 0, 0, 30, 290, 1, 0, 0, 0, 32, 292, 1, 0, 0, 0, 34, 294, 1, 0, 0, 0, 36, 296, 1, 0, 0, 0, 38, 298, 1, 0, 0, 0, 40, 307, 1, 0, 0, 0, 42, 310, 1, 0, 0, 0, 44, 318, 1, 0, 0, 0, 46, 326, 1, 0, 0, 0, 48, 331, 1, 0, 0, 0, 50, 339, 1, 0, 0, 0, 52, 347, 1, 0, 0, 0, 54, 355, 1, 0, 0, 0, 56, 360, 1, 0, 0, 0, 58, 364, 1, 0, 0, 0, 60, 368, 1, 0, 0, 0, 62, 373, 1, 0, 0, 0, 64, 375, 1, 0, 0, 0, 66, 378, 1, 0, 0, 0, 68, 387, 1, 0, 0, 0, 70, 395, 1, 0, 0, 0, 72, 398, 1, 0, 0, 0, 74, 401, 1, 0, 0, 0, 76, 418, 1, 0, 0, 0, 78, 420, 1, 0, 0, 0, 80, 426, 1, 0, 0, 0, 82, 430, 1, 0, 0, 0, 84, 433, 1, 0, 0, 0, 86, 441, 1, 0, 0, 0, 88, 445, 1, 0, 0, 0, 90, 448, 1, 0, 0, 0, 92, 452, 1, 0, 0, 0, 94, 455, 1, 0, 0, 0, 96, 472, 1, 0, 0, 0, 98, 477, 1, 0, 0, 0, 100, 481, 1, 0, 0, 0, 102, 484, 1, 0, 0, 0, 104, 497, 1, 0, 0, 0, 106, 501, 1, 0, 0, 0, 108, 505, 1, 0, 0, 0, 110, 509, 1, 0, 0, 0, 112, 520, 1, 0, 0, 0, 114, 522, 1, 0, 0, 0, 116, 532, 1, 0, 0, 0, 118, 537, 1, 0, 0, 0, 120, 543, 1, 0, 0, 0, 122, 546, 1, 0, 0, 0, 124, 548, 1, 0, 0, 0, 126, 556, 1, 0, 0, 0, 128, 562, 1, 0, 0, 0, 130, 564, 1, 0, 0, 0, 132, 601, 1, 0, 0, 0, 134, 660, 1, 0, 0, 0, 136, 662, 1, 0, 0, 0, 138, 675, 1, 0, 0, 0, 140, 681, 1, 0, 0, 0, 142, 702, 1, 0, 0, 0, 144, 712, 1, 0, 0, 0, 146, 731, 1, 0, 0, 0, 148, 733, 1, 0, 0, 0, 150, 744, 1, 0, 0, 0, 152, 790, 1, 0, 0, 0, 154, 792, 1, 0, 0, 0, 156, 796, 1, 0, 0, 0, 158, 799, 1, 0, 0, 0, 160, 804, 1, 0, 0, 0, 162, 808, 1, 0, 0, 0, 164, 810, 1, 0, 0, 0, 166, 812, 1, 0, 0, 0, 168, 817, 1, 0, 0, 0, 170, 819, 1, 0, 0, 0, 172, 828, 1, 0, 0, 0, 174, 175, 3, 2, 1, 0, 175, 176, 5, 0, 0, 1, 176, 1, 1, 0, 0, 0, 177, 178, 6, 1, -1, 0, 178, 179, 3, 4, 2, 0, 179, 185, 1, 0, 0, 0, 180, 181, 10, 1, 0, 0, 181, 182, 5, 52, 0, 0, 182, 184, 3, 6, 3, 0, 183, 180, 1, 0, 0, 0, 184, 187, 1, 0, 0, 0, 185, 183, 
1, 0, 0, 0, 185, 186, 1, 0, 0, 0, 186, 3, 1, 0, 0, 0, 187, 185, 1, 0, 0, 0, 188, 196, 3, 22, 11, 0, 189, 196, 3, 12, 6, 0, 190, 196, 3, 92, 46, 0, 191, 192, 4, 2, 1, 0, 192, 196, 3, 24, 12, 0, 193, 194, 4, 2, 2, 0, 194, 196, 3, 88, 44, 0, 195, 188, 1, 0, 0, 0, 195, 189, 1, 0, 0, 0, 195, 190, 1, 0, 0, 0, 195, 191, 1, 0, 0, 0, 195, 193, 1, 0, 0, 0, 196, 5, 1, 0, 0, 0, 197, 225, 3, 40, 20, 0, 198, 225, 3, 8, 4, 0, 199, 225, 3, 70, 35, 0, 200, 225, 3, 64, 32, 0, 201, 225, 3, 42, 21, 0, 202, 225, 3, 66, 33, 0, 203, 225, 3, 72, 36, 0, 204, 225, 3, 74, 37, 0, 205, 225, 3, 78, 39, 0, 206, 225, 3, 80, 40, 0, 207, 225, 3, 94, 47, 0, 208, 225, 3, 82, 41, 0, 209, 225, 3, 166, 83, 0, 210, 225, 3, 102, 51, 0, 211, 225, 3, 114, 57, 0, 212, 225, 3, 100, 50, 0, 213, 225, 3, 104, 52, 0, 214, 215, 4, 3, 3, 0, 215, 225, 3, 118, 59, 0, 216, 217, 4, 3, 4, 0, 217, 225, 3, 116, 58, 0, 218, 219, 4, 3, 5, 0, 219, 225, 3, 120, 60, 0, 220, 221, 4, 3, 6, 0, 221, 225, 3, 130, 65, 0, 222, 223, 4, 3, 7, 0, 223, 225, 3, 122, 61, 0, 224, 197, 1, 0, 0, 0, 224, 198, 1, 0, 0, 0, 224, 199, 1, 0, 0, 0, 224, 200, 1, 0, 0, 0, 224, 201, 1, 0, 0, 0, 224, 202, 1, 0, 0, 0, 224, 203, 1, 0, 0, 0, 224, 204, 1, 0, 0, 0, 224, 205, 1, 0, 0, 0, 224, 206, 1, 0, 0, 0, 224, 207, 1, 0, 0, 0, 224, 208, 1, 0, 0, 0, 224, 209, 1, 0, 0, 0, 224, 210, 1, 0, 0, 0, 224, 211, 1, 0, 0, 0, 224, 212, 1, 0, 0, 0, 224, 213, 1, 0, 0, 0, 224, 214, 1, 0, 0, 0, 224, 216, 1, 0, 0, 0, 224, 218, 1, 0, 0, 0, 224, 220, 1, 0, 0, 0, 224, 222, 1, 0, 0, 0, 225, 7, 1, 0, 0, 0, 226, 227, 5, 16, 0, 0, 227, 228, 3, 132, 66, 0, 228, 9, 1, 0, 0, 0, 229, 230, 3, 54, 27, 0, 230, 11, 1, 0, 0, 0, 231, 232, 5, 12, 0, 0, 232, 233, 3, 14, 7, 0, 233, 13, 1, 0, 0, 0, 234, 239, 3, 16, 8, 0, 235, 236, 5, 62, 0, 0, 236, 238, 3, 16, 8, 0, 237, 235, 1, 0, 0, 0, 238, 241, 1, 0, 0, 0, 239, 237, 1, 0, 0, 0, 239, 240, 1, 0, 0, 0, 240, 15, 1, 0, 0, 0, 241, 239, 1, 0, 0, 0, 242, 243, 3, 48, 24, 0, 243, 244, 5, 58, 0, 0, 244, 246, 1, 0, 0, 0, 245, 242, 1, 0, 0, 0, 245, 246, 
1, 0, 0, 0, 246, 247, 1, 0, 0, 0, 247, 248, 3, 132, 66, 0, 248, 17, 1, 0, 0, 0, 249, 254, 3, 20, 10, 0, 250, 251, 5, 62, 0, 0, 251, 253, 3, 20, 10, 0, 252, 250, 1, 0, 0, 0, 253, 256, 1, 0, 0, 0, 254, 252, 1, 0, 0, 0, 254, 255, 1, 0, 0, 0, 255, 19, 1, 0, 0, 0, 256, 254, 1, 0, 0, 0, 257, 260, 3, 48, 24, 0, 258, 259, 5, 58, 0, 0, 259, 261, 3, 132, 66, 0, 260, 258, 1, 0, 0, 0, 260, 261, 1, 0, 0, 0, 261, 21, 1, 0, 0, 0, 262, 263, 5, 19, 0, 0, 263, 264, 3, 26, 13, 0, 264, 23, 1, 0, 0, 0, 265, 266, 5, 20, 0, 0, 266, 267, 3, 26, 13, 0, 267, 25, 1, 0, 0, 0, 268, 273, 3, 28, 14, 0, 269, 270, 5, 62, 0, 0, 270, 272, 3, 28, 14, 0, 271, 269, 1, 0, 0, 0, 272, 275, 1, 0, 0, 0, 273, 271, 1, 0, 0, 0, 273, 274, 1, 0, 0, 0, 274, 277, 1, 0, 0, 0, 275, 273, 1, 0, 0, 0, 276, 278, 3, 38, 19, 0, 277, 276, 1, 0, 0, 0, 277, 278, 1, 0, 0, 0, 278, 27, 1, 0, 0, 0, 279, 280, 3, 30, 15, 0, 280, 281, 5, 61, 0, 0, 281, 282, 3, 34, 17, 0, 282, 289, 1, 0, 0, 0, 283, 284, 3, 34, 17, 0, 284, 285, 5, 60, 0, 0, 285, 286, 3, 32, 16, 0, 286, 289, 1, 0, 0, 0, 287, 289, 3, 36, 18, 0, 288, 279, 1, 0, 0, 0, 288, 283, 1, 0, 0, 0, 288, 287, 1, 0, 0, 0, 289, 29, 1, 0, 0, 0, 290, 291, 5, 107, 0, 0, 291, 31, 1, 0, 0, 0, 292, 293, 5, 107, 0, 0, 293, 33, 1, 0, 0, 0, 294, 295, 5, 107, 0, 0, 295, 35, 1, 0, 0, 0, 296, 297, 7, 0, 0, 0, 297, 37, 1, 0, 0, 0, 298, 299, 5, 106, 0, 0, 299, 304, 5, 107, 0, 0, 300, 301, 5, 62, 0, 0, 301, 303, 5, 107, 0, 0, 302, 300, 1, 0, 0, 0, 303, 306, 1, 0, 0, 0, 304, 302, 1, 0, 0, 0, 304, 305, 1, 0, 0, 0, 305, 39, 1, 0, 0, 0, 306, 304, 1, 0, 0, 0, 307, 308, 5, 9, 0, 0, 308, 309, 3, 14, 7, 0, 309, 41, 1, 0, 0, 0, 310, 312, 5, 15, 0, 0, 311, 313, 3, 44, 22, 0, 312, 311, 1, 0, 0, 0, 312, 313, 1, 0, 0, 0, 313, 316, 1, 0, 0, 0, 314, 315, 5, 59, 0, 0, 315, 317, 3, 14, 7, 0, 316, 314, 1, 0, 0, 0, 316, 317, 1, 0, 0, 0, 317, 43, 1, 0, 0, 0, 318, 323, 3, 46, 23, 0, 319, 320, 5, 62, 0, 0, 320, 322, 3, 46, 23, 0, 321, 319, 1, 0, 0, 0, 322, 325, 1, 0, 0, 0, 323, 321, 1, 0, 0, 0, 323, 324, 1, 0, 0, 0, 
324, 45, 1, 0, 0, 0, 325, 323, 1, 0, 0, 0, 326, 329, 3, 16, 8, 0, 327, 328, 5, 16, 0, 0, 328, 330, 3, 132, 66, 0, 329, 327, 1, 0, 0, 0, 329, 330, 1, 0, 0, 0, 330, 47, 1, 0, 0, 0, 331, 336, 3, 62, 31, 0, 332, 333, 5, 64, 0, 0, 333, 335, 3, 62, 31, 0, 334, 332, 1, 0, 0, 0, 335, 338, 1, 0, 0, 0, 336, 334, 1, 0, 0, 0, 336, 337, 1, 0, 0, 0, 337, 49, 1, 0, 0, 0, 338, 336, 1, 0, 0, 0, 339, 344, 3, 56, 28, 0, 340, 341, 5, 64, 0, 0, 341, 343, 3, 56, 28, 0, 342, 340, 1, 0, 0, 0, 343, 346, 1, 0, 0, 0, 344, 342, 1, 0, 0, 0, 344, 345, 1, 0, 0, 0, 345, 51, 1, 0, 0, 0, 346, 344, 1, 0, 0, 0, 347, 352, 3, 50, 25, 0, 348, 349, 5, 62, 0, 0, 349, 351, 3, 50, 25, 0, 350, 348, 1, 0, 0, 0, 351, 354, 1, 0, 0, 0, 352, 350, 1, 0, 0, 0, 352, 353, 1, 0, 0, 0, 353, 53, 1, 0, 0, 0, 354, 352, 1, 0, 0, 0, 355, 356, 7, 1, 0, 0, 356, 55, 1, 0, 0, 0, 357, 361, 5, 128, 0, 0, 358, 361, 3, 58, 29, 0, 359, 361, 3, 60, 30, 0, 360, 357, 1, 0, 0, 0, 360, 358, 1, 0, 0, 0, 360, 359, 1, 0, 0, 0, 361, 57, 1, 0, 0, 0, 362, 365, 5, 76, 0, 0, 363, 365, 5, 95, 0, 0, 364, 362, 1, 0, 0, 0, 364, 363, 1, 0, 0, 0, 365, 59, 1, 0, 0, 0, 366, 369, 5, 94, 0, 0, 367, 369, 5, 96, 0, 0, 368, 366, 1, 0, 0, 0, 368, 367, 1, 0, 0, 0, 369, 61, 1, 0, 0, 0, 370, 374, 3, 54, 27, 0, 371, 374, 3, 58, 29, 0, 372, 374, 3, 60, 30, 0, 373, 370, 1, 0, 0, 0, 373, 371, 1, 0, 0, 0, 373, 372, 1, 0, 0, 0, 374, 63, 1, 0, 0, 0, 375, 376, 5, 11, 0, 0, 376, 377, 3, 152, 76, 0, 377, 65, 1, 0, 0, 0, 378, 379, 5, 14, 0, 0, 379, 384, 3, 68, 34, 0, 380, 381, 5, 62, 0, 0, 381, 383, 3, 68, 34, 0, 382, 380, 1, 0, 0, 0, 383, 386, 1, 0, 0, 0, 384, 382, 1, 0, 0, 0, 384, 385, 1, 0, 0, 0, 385, 67, 1, 0, 0, 0, 386, 384, 1, 0, 0, 0, 387, 389, 3, 132, 66, 0, 388, 390, 7, 2, 0, 0, 389, 388, 1, 0, 0, 0, 389, 390, 1, 0, 0, 0, 390, 393, 1, 0, 0, 0, 391, 392, 5, 73, 0, 0, 392, 394, 7, 3, 0, 0, 393, 391, 1, 0, 0, 0, 393, 394, 1, 0, 0, 0, 394, 69, 1, 0, 0, 0, 395, 396, 5, 30, 0, 0, 396, 397, 3, 52, 26, 0, 397, 71, 1, 0, 0, 0, 398, 399, 5, 29, 0, 0, 399, 400, 3, 52, 26, 0, 
400, 73, 1, 0, 0, 0, 401, 402, 5, 32, 0, 0, 402, 407, 3, 76, 38, 0, 403, 404, 5, 62, 0, 0, 404, 406, 3, 76, 38, 0, 405, 403, 1, 0, 0, 0, 406, 409, 1, 0, 0, 0, 407, 405, 1, 0, 0, 0, 407, 408, 1, 0, 0, 0, 408, 75, 1, 0, 0, 0, 409, 407, 1, 0, 0, 0, 410, 411, 3, 50, 25, 0, 411, 412, 5, 132, 0, 0, 412, 413, 3, 50, 25, 0, 413, 419, 1, 0, 0, 0, 414, 415, 3, 50, 25, 0, 415, 416, 5, 58, 0, 0, 416, 417, 3, 50, 25, 0, 417, 419, 1, 0, 0, 0, 418, 410, 1, 0, 0, 0, 418, 414, 1, 0, 0, 0, 419, 77, 1, 0, 0, 0, 420, 421, 5, 8, 0, 0, 421, 422, 3, 142, 71, 0, 422, 424, 3, 162, 81, 0, 423, 425, 3, 84, 42, 0, 424, 423, 1, 0, 0, 0, 424, 425, 1, 0, 0, 0, 425, 79, 1, 0, 0, 0, 426, 427, 5, 10, 0, 0, 427, 428, 3, 142, 71, 0, 428, 429, 3, 162, 81, 0, 429, 81, 1, 0, 0, 0, 430, 431, 5, 28, 0, 0, 431, 432, 3, 48, 24, 0, 432, 83, 1, 0, 0, 0, 433, 438, 3, 86, 43, 0, 434, 435, 5, 62, 0, 0, 435, 437, 3, 86, 43, 0, 436, 434, 1, 0, 0, 0, 437, 440, 1, 0, 0, 0, 438, 436, 1, 0, 0, 0, 438, 439, 1, 0, 0, 0, 439, 85, 1, 0, 0, 0, 440, 438, 1, 0, 0, 0, 441, 442, 3, 54, 27, 0, 442, 443, 5, 58, 0, 0, 443, 444, 3, 152, 76, 0, 444, 87, 1, 0, 0, 0, 445, 446, 5, 6, 0, 0, 446, 447, 3, 90, 45, 0, 447, 89, 1, 0, 0, 0, 448, 449, 5, 99, 0, 0, 449, 450, 3, 2, 1, 0, 450, 451, 5, 100, 0, 0, 451, 91, 1, 0, 0, 0, 452, 453, 5, 33, 0, 0, 453, 454, 5, 136, 0, 0, 454, 93, 1, 0, 0, 0, 455, 456, 5, 5, 0, 0, 456, 459, 3, 96, 48, 0, 457, 458, 5, 74, 0, 0, 458, 460, 3, 50, 25, 0, 459, 457, 1, 0, 0, 0, 459, 460, 1, 0, 0, 0, 460, 470, 1, 0, 0, 0, 461, 462, 5, 79, 0, 0, 462, 467, 3, 98, 49, 0, 463, 464, 5, 62, 0, 0, 464, 466, 3, 98, 49, 0, 465, 463, 1, 0, 0, 0, 466, 469, 1, 0, 0, 0, 467, 465, 1, 0, 0, 0, 467, 468, 1, 0, 0, 0, 468, 471, 1, 0, 0, 0, 469, 467, 1, 0, 0, 0, 470, 461, 1, 0, 0, 0, 470, 471, 1, 0, 0, 0, 471, 95, 1, 0, 0, 0, 472, 473, 7, 4, 0, 0, 473, 97, 1, 0, 0, 0, 474, 475, 3, 50, 25, 0, 475, 476, 5, 58, 0, 0, 476, 478, 1, 0, 0, 0, 477, 474, 1, 0, 0, 0, 477, 478, 1, 0, 0, 0, 478, 479, 1, 0, 0, 0, 479, 480, 3, 50, 25, 0, 480, 
99, 1, 0, 0, 0, 481, 482, 5, 13, 0, 0, 482, 483, 3, 152, 76, 0, 483, 101, 1, 0, 0, 0, 484, 485, 5, 4, 0, 0, 485, 488, 3, 48, 24, 0, 486, 487, 5, 74, 0, 0, 487, 489, 3, 48, 24, 0, 488, 486, 1, 0, 0, 0, 488, 489, 1, 0, 0, 0, 489, 495, 1, 0, 0, 0, 490, 491, 5, 132, 0, 0, 491, 492, 3, 48, 24, 0, 492, 493, 5, 62, 0, 0, 493, 494, 3, 48, 24, 0, 494, 496, 1, 0, 0, 0, 495, 490, 1, 0, 0, 0, 495, 496, 1, 0, 0, 0, 496, 103, 1, 0, 0, 0, 497, 498, 5, 21, 0, 0, 498, 499, 3, 106, 53, 0, 499, 105, 1, 0, 0, 0, 500, 502, 3, 108, 54, 0, 501, 500, 1, 0, 0, 0, 502, 503, 1, 0, 0, 0, 503, 501, 1, 0, 0, 0, 503, 504, 1, 0, 0, 0, 504, 107, 1, 0, 0, 0, 505, 506, 5, 99, 0, 0, 506, 507, 3, 110, 55, 0, 507, 508, 5, 100, 0, 0, 508, 109, 1, 0, 0, 0, 509, 510, 6, 55, -1, 0, 510, 511, 3, 112, 56, 0, 511, 517, 1, 0, 0, 0, 512, 513, 10, 1, 0, 0, 513, 514, 5, 52, 0, 0, 514, 516, 3, 112, 56, 0, 515, 512, 1, 0, 0, 0, 516, 519, 1, 0, 0, 0, 517, 515, 1, 0, 0, 0, 517, 518, 1, 0, 0, 0, 518, 111, 1, 0, 0, 0, 519, 517, 1, 0, 0, 0, 520, 521, 3, 6, 3, 0, 521, 113, 1, 0, 0, 0, 522, 526, 5, 7, 0, 0, 523, 524, 3, 48, 24, 0, 524, 525, 5, 58, 0, 0, 525, 527, 1, 0, 0, 0, 526, 523, 1, 0, 0, 0, 526, 527, 1, 0, 0, 0, 527, 528, 1, 0, 0, 0, 528, 529, 3, 142, 71, 0, 529, 530, 5, 79, 0, 0, 530, 531, 3, 62, 31, 0, 531, 115, 1, 0, 0, 0, 532, 533, 5, 27, 0, 0, 533, 534, 3, 28, 14, 0, 534, 535, 5, 74, 0, 0, 535, 536, 3, 52, 26, 0, 536, 117, 1, 0, 0, 0, 537, 538, 5, 17, 0, 0, 538, 541, 3, 44, 22, 0, 539, 540, 5, 59, 0, 0, 540, 542, 3, 14, 7, 0, 541, 539, 1, 0, 0, 0, 541, 542, 1, 0, 0, 0, 542, 119, 1, 0, 0, 0, 543, 544, 5, 31, 0, 0, 544, 545, 3, 52, 26, 0, 545, 121, 1, 0, 0, 0, 546, 547, 5, 22, 0, 0, 547, 123, 1, 0, 0, 0, 548, 553, 3, 126, 63, 0, 549, 550, 5, 62, 0, 0, 550, 552, 3, 126, 63, 0, 551, 549, 1, 0, 0, 0, 552, 555, 1, 0, 0, 0, 553, 551, 1, 0, 0, 0, 553, 554, 1, 0, 0, 0, 554, 125, 1, 0, 0, 0, 555, 553, 1, 0, 0, 0, 556, 557, 3, 54, 27, 0, 557, 558, 5, 58, 0, 0, 558, 559, 3, 128, 64, 0, 559, 127, 1, 0, 0, 0, 560, 563, 3, 
152, 76, 0, 561, 563, 3, 54, 27, 0, 562, 560, 1, 0, 0, 0, 562, 561, 1, 0, 0, 0, 563, 129, 1, 0, 0, 0, 564, 565, 5, 18, 0, 0, 565, 566, 3, 152, 76, 0, 566, 567, 5, 74, 0, 0, 567, 570, 3, 18, 9, 0, 568, 569, 5, 79, 0, 0, 569, 571, 3, 124, 62, 0, 570, 568, 1, 0, 0, 0, 570, 571, 1, 0, 0, 0, 571, 131, 1, 0, 0, 0, 572, 573, 6, 66, -1, 0, 573, 574, 5, 71, 0, 0, 574, 602, 3, 132, 66, 8, 575, 602, 3, 138, 69, 0, 576, 602, 3, 134, 67, 0, 577, 579, 3, 138, 69, 0, 578, 580, 5, 71, 0, 0, 579, 578, 1, 0, 0, 0, 579, 580, 1, 0, 0, 0, 580, 581, 1, 0, 0, 0, 581, 582, 5, 67, 0, 0, 582, 583, 5, 99, 0, 0, 583, 588, 3, 138, 69, 0, 584, 585, 5, 62, 0, 0, 585, 587, 3, 138, 69, 0, 586, 584, 1, 0, 0, 0, 587, 590, 1, 0, 0, 0, 588, 586, 1, 0, 0, 0, 588, 589, 1, 0, 0, 0, 589, 591, 1, 0, 0, 0, 590, 588, 1, 0, 0, 0, 591, 592, 5, 100, 0, 0, 592, 602, 1, 0, 0, 0, 593, 594, 3, 138, 69, 0, 594, 596, 5, 68, 0, 0, 595, 597, 5, 71, 0, 0, 596, 595, 1, 0, 0, 0, 596, 597, 1, 0, 0, 0, 597, 598, 1, 0, 0, 0, 598, 599, 5, 72, 0, 0, 599, 602, 1, 0, 0, 0, 600, 602, 3, 136, 68, 0, 601, 572, 1, 0, 0, 0, 601, 575, 1, 0, 0, 0, 601, 576, 1, 0, 0, 0, 601, 577, 1, 0, 0, 0, 601, 593, 1, 0, 0, 0, 601, 600, 1, 0, 0, 0, 602, 611, 1, 0, 0, 0, 603, 604, 10, 5, 0, 0, 604, 605, 5, 56, 0, 0, 605, 610, 3, 132, 66, 6, 606, 607, 10, 4, 0, 0, 607, 608, 5, 75, 0, 0, 608, 610, 3, 132, 66, 5, 609, 603, 1, 0, 0, 0, 609, 606, 1, 0, 0, 0, 610, 613, 1, 0, 0, 0, 611, 609, 1, 0, 0, 0, 611, 612, 1, 0, 0, 0, 612, 133, 1, 0, 0, 0, 613, 611, 1, 0, 0, 0, 614, 616, 3, 138, 69, 0, 615, 617, 5, 71, 0, 0, 616, 615, 1, 0, 0, 0, 616, 617, 1, 0, 0, 0, 617, 618, 1, 0, 0, 0, 618, 619, 5, 70, 0, 0, 619, 620, 3, 162, 81, 0, 620, 661, 1, 0, 0, 0, 621, 623, 3, 138, 69, 0, 622, 624, 5, 71, 0, 0, 623, 622, 1, 0, 0, 0, 623, 624, 1, 0, 0, 0, 624, 625, 1, 0, 0, 0, 625, 626, 5, 77, 0, 0, 626, 627, 3, 162, 81, 0, 627, 661, 1, 0, 0, 0, 628, 630, 3, 138, 69, 0, 629, 631, 5, 71, 0, 0, 630, 629, 1, 0, 0, 0, 630, 631, 1, 0, 0, 0, 631, 632, 1, 0, 0, 0, 632, 633, 5, 70, 
0, 0, 633, 634, 5, 99, 0, 0, 634, 639, 3, 162, 81, 0, 635, 636, 5, 62, 0, 0, 636, 638, 3, 162, 81, 0, 637, 635, 1, 0, 0, 0, 638, 641, 1, 0, 0, 0, 639, 637, 1, 0, 0, 0, 639, 640, 1, 0, 0, 0, 640, 642, 1, 0, 0, 0, 641, 639, 1, 0, 0, 0, 642, 643, 5, 100, 0, 0, 643, 661, 1, 0, 0, 0, 644, 646, 3, 138, 69, 0, 645, 647, 5, 71, 0, 0, 646, 645, 1, 0, 0, 0, 646, 647, 1, 0, 0, 0, 647, 648, 1, 0, 0, 0, 648, 649, 5, 77, 0, 0, 649, 650, 5, 99, 0, 0, 650, 655, 3, 162, 81, 0, 651, 652, 5, 62, 0, 0, 652, 654, 3, 162, 81, 0, 653, 651, 1, 0, 0, 0, 654, 657, 1, 0, 0, 0, 655, 653, 1, 0, 0, 0, 655, 656, 1, 0, 0, 0, 656, 658, 1, 0, 0, 0, 657, 655, 1, 0, 0, 0, 658, 659, 5, 100, 0, 0, 659, 661, 1, 0, 0, 0, 660, 614, 1, 0, 0, 0, 660, 621, 1, 0, 0, 0, 660, 628, 1, 0, 0, 0, 660, 644, 1, 0, 0, 0, 661, 135, 1, 0, 0, 0, 662, 665, 3, 48, 24, 0, 663, 664, 5, 60, 0, 0, 664, 666, 3, 10, 5, 0, 665, 663, 1, 0, 0, 0, 665, 666, 1, 0, 0, 0, 666, 667, 1, 0, 0, 0, 667, 668, 5, 61, 0, 0, 668, 669, 3, 152, 76, 0, 669, 137, 1, 0, 0, 0, 670, 676, 3, 140, 70, 0, 671, 672, 3, 140, 70, 0, 672, 673, 3, 164, 82, 0, 673, 674, 3, 140, 70, 0, 674, 676, 1, 0, 0, 0, 675, 670, 1, 0, 0, 0, 675, 671, 1, 0, 0, 0, 676, 139, 1, 0, 0, 0, 677, 678, 6, 70, -1, 0, 678, 682, 3, 142, 71, 0, 679, 680, 7, 5, 0, 0, 680, 682, 3, 140, 70, 3, 681, 677, 1, 0, 0, 0, 681, 679, 1, 0, 0, 0, 682, 691, 1, 0, 0, 0, 683, 684, 10, 2, 0, 0, 684, 685, 7, 6, 0, 0, 685, 690, 3, 140, 70, 3, 686, 687, 10, 1, 0, 0, 687, 688, 7, 5, 0, 0, 688, 690, 3, 140, 70, 2, 689, 683, 1, 0, 0, 0, 689, 686, 1, 0, 0, 0, 690, 693, 1, 0, 0, 0, 691, 689, 1, 0, 0, 0, 691, 692, 1, 0, 0, 0, 692, 141, 1, 0, 0, 0, 693, 691, 1, 0, 0, 0, 694, 695, 6, 71, -1, 0, 695, 703, 3, 152, 76, 0, 696, 703, 3, 48, 24, 0, 697, 703, 3, 144, 72, 0, 698, 699, 5, 99, 0, 0, 699, 700, 3, 132, 66, 0, 700, 701, 5, 100, 0, 0, 701, 703, 1, 0, 0, 0, 702, 694, 1, 0, 0, 0, 702, 696, 1, 0, 0, 0, 702, 697, 1, 0, 0, 0, 702, 698, 1, 0, 0, 0, 703, 709, 1, 0, 0, 0, 704, 705, 10, 1, 0, 0, 705, 706, 5, 60, 0, 0, 
706, 708, 3, 10, 5, 0, 707, 704, 1, 0, 0, 0, 708, 711, 1, 0, 0, 0, 709, 707, 1, 0, 0, 0, 709, 710, 1, 0, 0, 0, 710, 143, 1, 0, 0, 0, 711, 709, 1, 0, 0, 0, 712, 713, 3, 146, 73, 0, 713, 727, 5, 99, 0, 0, 714, 728, 5, 89, 0, 0, 715, 720, 3, 132, 66, 0, 716, 717, 5, 62, 0, 0, 717, 719, 3, 132, 66, 0, 718, 716, 1, 0, 0, 0, 719, 722, 1, 0, 0, 0, 720, 718, 1, 0, 0, 0, 720, 721, 1, 0, 0, 0, 721, 725, 1, 0, 0, 0, 722, 720, 1, 0, 0, 0, 723, 724, 5, 62, 0, 0, 724, 726, 3, 148, 74, 0, 725, 723, 1, 0, 0, 0, 725, 726, 1, 0, 0, 0, 726, 728, 1, 0, 0, 0, 727, 714, 1, 0, 0, 0, 727, 715, 1, 0, 0, 0, 727, 728, 1, 0, 0, 0, 728, 729, 1, 0, 0, 0, 729, 730, 5, 100, 0, 0, 730, 145, 1, 0, 0, 0, 731, 732, 3, 62, 31, 0, 732, 147, 1, 0, 0, 0, 733, 734, 5, 92, 0, 0, 734, 739, 3, 150, 75, 0, 735, 736, 5, 62, 0, 0, 736, 738, 3, 150, 75, 0, 737, 735, 1, 0, 0, 0, 738, 741, 1, 0, 0, 0, 739, 737, 1, 0, 0, 0, 739, 740, 1, 0, 0, 0, 740, 742, 1, 0, 0, 0, 741, 739, 1, 0, 0, 0, 742, 743, 5, 93, 0, 0, 743, 149, 1, 0, 0, 0, 744, 745, 3, 162, 81, 0, 745, 746, 5, 61, 0, 0, 746, 747, 3, 152, 76, 0, 747, 151, 1, 0, 0, 0, 748, 791, 5, 72, 0, 0, 749, 750, 3, 160, 80, 0, 750, 751, 5, 101, 0, 0, 751, 791, 1, 0, 0, 0, 752, 791, 3, 158, 79, 0, 753, 791, 3, 160, 80, 0, 754, 791, 3, 154, 77, 0, 755, 791, 3, 58, 29, 0, 756, 791, 3, 162, 81, 0, 757, 758, 5, 97, 0, 0, 758, 763, 3, 156, 78, 0, 759, 760, 5, 62, 0, 0, 760, 762, 3, 156, 78, 0, 761, 759, 1, 0, 0, 0, 762, 765, 1, 0, 0, 0, 763, 761, 1, 0, 0, 0, 763, 764, 1, 0, 0, 0, 764, 766, 1, 0, 0, 0, 765, 763, 1, 0, 0, 0, 766, 767, 5, 98, 0, 0, 767, 791, 1, 0, 0, 0, 768, 769, 5, 97, 0, 0, 769, 774, 3, 154, 77, 0, 770, 771, 5, 62, 0, 0, 771, 773, 3, 154, 77, 0, 772, 770, 1, 0, 0, 0, 773, 776, 1, 0, 0, 0, 774, 772, 1, 0, 0, 0, 774, 775, 1, 0, 0, 0, 775, 777, 1, 0, 0, 0, 776, 774, 1, 0, 0, 0, 777, 778, 5, 98, 0, 0, 778, 791, 1, 0, 0, 0, 779, 780, 5, 97, 0, 0, 780, 785, 3, 162, 81, 0, 781, 782, 5, 62, 0, 0, 782, 784, 3, 162, 81, 0, 783, 781, 1, 0, 0, 0, 784, 787, 1, 0, 0, 0, 
785, 783, 1, 0, 0, 0, 785, 786, 1, 0, 0, 0, 786, 788, 1, 0, 0, 0, 787, 785, 1, 0, 0, 0, 788, 789, 5, 98, 0, 0, 789, 791, 1, 0, 0, 0, 790, 748, 1, 0, 0, 0, 790, 749, 1, 0, 0, 0, 790, 752, 1, 0, 0, 0, 790, 753, 1, 0, 0, 0, 790, 754, 1, 0, 0, 0, 790, 755, 1, 0, 0, 0, 790, 756, 1, 0, 0, 0, 790, 757, 1, 0, 0, 0, 790, 768, 1, 0, 0, 0, 790, 779, 1, 0, 0, 0, 791, 153, 1, 0, 0, 0, 792, 793, 7, 7, 0, 0, 793, 155, 1, 0, 0, 0, 794, 797, 3, 158, 79, 0, 795, 797, 3, 160, 80, 0, 796, 794, 1, 0, 0, 0, 796, 795, 1, 0, 0, 0, 797, 157, 1, 0, 0, 0, 798, 800, 7, 5, 0, 0, 799, 798, 1, 0, 0, 0, 799, 800, 1, 0, 0, 0, 800, 801, 1, 0, 0, 0, 801, 802, 5, 55, 0, 0, 802, 159, 1, 0, 0, 0, 803, 805, 7, 5, 0, 0, 804, 803, 1, 0, 0, 0, 804, 805, 1, 0, 0, 0, 805, 806, 1, 0, 0, 0, 806, 807, 5, 54, 0, 0, 807, 161, 1, 0, 0, 0, 808, 809, 5, 53, 0, 0, 809, 163, 1, 0, 0, 0, 810, 811, 7, 8, 0, 0, 811, 165, 1, 0, 0, 0, 812, 813, 7, 9, 0, 0, 813, 814, 5, 114, 0, 0, 814, 815, 3, 168, 84, 0, 815, 816, 3, 170, 85, 0, 816, 167, 1, 0, 0, 0, 817, 818, 3, 28, 14, 0, 818, 169, 1, 0, 0, 0, 819, 820, 5, 74, 0, 0, 820, 825, 3, 172, 86, 0, 821, 822, 5, 62, 0, 0, 822, 824, 3, 172, 86, 0, 823, 821, 1, 0, 0, 0, 824, 827, 1, 0, 0, 0, 825, 823, 1, 0, 0, 0, 825, 826, 1, 0, 0, 0, 826, 171, 1, 0, 0, 0, 827, 825, 1, 0, 0, 0, 828, 829, 3, 138, 69, 0, 829, 173, 1, 0, 0, 0, 74, 185, 195, 224, 239, 245, 254, 260, 273, 277, 288, 304, 312, 316, 323, 329, 336, 344, 352, 360, 364, 368, 373, 384, 389, 393, 407, 418, 424, 438, 459, 467, 470, 477, 488, 495, 503, 517, 526, 541, 553, 562, 570, 579, 588, 596, 601, 609, 611, 616, 623, 630, 639, 646, 655, 660, 665, 675, 681, 689, 691, 702, 709, 720, 725, 727, 739, 763, 774, 785, 790, 796, 799, 804, 825] \ No newline at end of file diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java index e8e7a309b94ef..8a87254e54a0b 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java @@ -5225,15 +5225,50 @@ public T accept(ParseTreeVisitor visitor) { else return visitor.visitChildren(this); } } + @SuppressWarnings("CheckReturnValue") + public static class RlikeListExpressionContext extends RegexBooleanExpressionContext { + public ValueExpressionContext valueExpression() { + return getRuleContext(ValueExpressionContext.class,0); + } + public TerminalNode RLIKE() { return getToken(EsqlBaseParser.RLIKE, 0); } + public TerminalNode LP() { return getToken(EsqlBaseParser.LP, 0); } + public List string() { + return getRuleContexts(StringContext.class); + } + public StringContext string(int i) { + return getRuleContext(StringContext.class,i); + } + public TerminalNode RP() { return getToken(EsqlBaseParser.RP, 0); } + public TerminalNode NOT() { return getToken(EsqlBaseParser.NOT, 0); } + public List COMMA() { return getTokens(EsqlBaseParser.COMMA); } + public TerminalNode COMMA(int i) { + return getToken(EsqlBaseParser.COMMA, i); + } + @SuppressWarnings("this-escape") + public RlikeListExpressionContext(RegexBooleanExpressionContext ctx) { copyFrom(ctx); } + @Override + public void enterRule(ParseTreeListener listener) { + if ( listener instanceof EsqlBaseParserListener ) ((EsqlBaseParserListener)listener).enterRlikeListExpression(this); + } + @Override + public void exitRule(ParseTreeListener listener) { + if ( listener instanceof EsqlBaseParserListener ) ((EsqlBaseParserListener)listener).exitRlikeListExpression(this); + } + @Override + public T accept(ParseTreeVisitor visitor) { + if ( visitor instanceof EsqlBaseParserVisitor ) return ((EsqlBaseParserVisitor)visitor).visitRlikeListExpression(this); + else return visitor.visitChildren(this); + } + } public final RegexBooleanExpressionContext regexBooleanExpression() throws RecognitionException { 
RegexBooleanExpressionContext _localctx = new RegexBooleanExpressionContext(_ctx, getState()); enterRule(_localctx, 134, RULE_regexBooleanExpression); int _la; try { - setState(644); + setState(660); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,52,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,54,_ctx) ) { case 1: _localctx = new LikeExpressionContext(_localctx); enterOuterAlt(_localctx, 1); @@ -5320,6 +5355,48 @@ public final RegexBooleanExpressionContext regexBooleanExpression() throws Recog match(RP); } break; + case 4: + _localctx = new RlikeListExpressionContext(_localctx); + enterOuterAlt(_localctx, 4); + { + setState(644); + valueExpression(); + setState(646); + _errHandler.sync(this); + _la = _input.LA(1); + if (_la==NOT) { + { + setState(645); + match(NOT); + } + } + + setState(648); + match(RLIKE); + setState(649); + match(LP); + setState(650); + string(); + setState(655); + _errHandler.sync(this); + _la = _input.LA(1); + while (_la==COMMA) { + { + { + setState(651); + match(COMMA); + setState(652); + string(); + } + } + setState(657); + _errHandler.sync(this); + _la = _input.LA(1); + } + setState(658); + match(RP); + } + break; } } catch (RecognitionException re) { @@ -5376,23 +5453,23 @@ public final MatchBooleanExpressionContext matchBooleanExpression() throws Recog try { enterOuterAlt(_localctx, 1); { - setState(646); + setState(662); ((MatchBooleanExpressionContext)_localctx).fieldExp = qualifiedName(); - setState(649); + setState(665); _errHandler.sync(this); _la = _input.LA(1); if (_la==CAST_OP) { { - setState(647); + setState(663); match(CAST_OP); - setState(648); + setState(664); ((MatchBooleanExpressionContext)_localctx).fieldType = dataType(); } } - setState(651); + setState(667); match(COLON); - setState(652); + setState(668); ((MatchBooleanExpressionContext)_localctx).matchQuery = constant(); } } @@ -5476,14 +5553,14 @@ public final ValueExpressionContext valueExpression() throws RecognitionExceptio 
ValueExpressionContext _localctx = new ValueExpressionContext(_ctx, getState()); enterRule(_localctx, 138, RULE_valueExpression); try { - setState(659); + setState(675); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,54,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,56,_ctx) ) { case 1: _localctx = new ValueExpressionDefaultContext(_localctx); enterOuterAlt(_localctx, 1); { - setState(654); + setState(670); operatorExpression(0); } break; @@ -5491,11 +5568,11 @@ public final ValueExpressionContext valueExpression() throws RecognitionExceptio _localctx = new ComparisonContext(_localctx); enterOuterAlt(_localctx, 2); { - setState(655); + setState(671); ((ComparisonContext)_localctx).left = operatorExpression(0); - setState(656); + setState(672); comparisonOperator(); - setState(657); + setState(673); ((ComparisonContext)_localctx).right = operatorExpression(0); } break; @@ -5620,16 +5697,16 @@ private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE int _alt; enterOuterAlt(_localctx, 1); { - setState(665); + setState(681); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,55,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,57,_ctx) ) { case 1: { _localctx = new OperatorExpressionDefaultContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(662); + setState(678); primaryExpression(0); } break; @@ -5638,7 +5715,7 @@ private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE _localctx = new ArithmeticUnaryContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(663); + setState(679); ((ArithmeticUnaryContext)_localctx).operator = _input.LT(1); _la = _input.LA(1); if ( !(_la==PLUS || _la==MINUS) ) { @@ -5649,31 +5726,31 @@ private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE _errHandler.reportMatch(this); consume(); } - setState(664); + setState(680); operatorExpression(3); } break; } 
_ctx.stop = _input.LT(-1); - setState(675); + setState(691); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,57,_ctx); + _alt = getInterpreter().adaptivePredict(_input,59,_ctx); while ( _alt!=2 && _alt!=org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER ) { if ( _alt==1 ) { if ( _parseListeners!=null ) triggerExitRuleEvent(); _prevctx = _localctx; { - setState(673); + setState(689); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,56,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,58,_ctx) ) { case 1: { _localctx = new ArithmeticBinaryContext(new OperatorExpressionContext(_parentctx, _parentState)); ((ArithmeticBinaryContext)_localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_operatorExpression); - setState(667); + setState(683); if (!(precpred(_ctx, 2))) throw new FailedPredicateException(this, "precpred(_ctx, 2)"); - setState(668); + setState(684); ((ArithmeticBinaryContext)_localctx).operator = _input.LT(1); _la = _input.LA(1); if ( !(((((_la - 89)) & ~0x3f) == 0 && ((1L << (_la - 89)) & 7L) != 0)) ) { @@ -5684,7 +5761,7 @@ private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE _errHandler.reportMatch(this); consume(); } - setState(669); + setState(685); ((ArithmeticBinaryContext)_localctx).right = operatorExpression(3); } break; @@ -5693,9 +5770,9 @@ private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE _localctx = new ArithmeticBinaryContext(new OperatorExpressionContext(_parentctx, _parentState)); ((ArithmeticBinaryContext)_localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_operatorExpression); - setState(670); + setState(686); if (!(precpred(_ctx, 1))) throw new FailedPredicateException(this, "precpred(_ctx, 1)"); - setState(671); + setState(687); ((ArithmeticBinaryContext)_localctx).operator = _input.LT(1); _la = _input.LA(1); if ( !(_la==PLUS || _la==MINUS) ) { @@ -5706,16 +5783,16 @@ 
private OperatorExpressionContext operatorExpression(int _p) throws RecognitionE _errHandler.reportMatch(this); consume(); } - setState(672); + setState(688); ((ArithmeticBinaryContext)_localctx).right = operatorExpression(2); } break; } } } - setState(677); + setState(693); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,57,_ctx); + _alt = getInterpreter().adaptivePredict(_input,59,_ctx); } } } @@ -5871,16 +5948,16 @@ private PrimaryExpressionContext primaryExpression(int _p) throws RecognitionExc int _alt; enterOuterAlt(_localctx, 1); { - setState(686); + setState(702); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,58,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,60,_ctx) ) { case 1: { _localctx = new ConstantDefaultContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(679); + setState(695); constant(); } break; @@ -5889,7 +5966,7 @@ private PrimaryExpressionContext primaryExpression(int _p) throws RecognitionExc _localctx = new DereferenceContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(680); + setState(696); qualifiedName(); } break; @@ -5898,7 +5975,7 @@ private PrimaryExpressionContext primaryExpression(int _p) throws RecognitionExc _localctx = new FunctionContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(681); + setState(697); functionExpression(); } break; @@ -5907,19 +5984,19 @@ private PrimaryExpressionContext primaryExpression(int _p) throws RecognitionExc _localctx = new ParenthesizedExpressionContext(_localctx); _ctx = _localctx; _prevctx = _localctx; - setState(682); + setState(698); match(LP); - setState(683); + setState(699); booleanExpression(0); - setState(684); + setState(700); match(RP); } break; } _ctx.stop = _input.LT(-1); - setState(693); + setState(709); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,59,_ctx); + _alt = getInterpreter().adaptivePredict(_input,61,_ctx); while ( _alt!=2 && 
_alt!=org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER ) { if ( _alt==1 ) { if ( _parseListeners!=null ) triggerExitRuleEvent(); @@ -5928,18 +6005,18 @@ private PrimaryExpressionContext primaryExpression(int _p) throws RecognitionExc { _localctx = new InlineCastContext(new PrimaryExpressionContext(_parentctx, _parentState)); pushNewRecursionContext(_localctx, _startState, RULE_primaryExpression); - setState(688); + setState(704); if (!(precpred(_ctx, 1))) throw new FailedPredicateException(this, "precpred(_ctx, 1)"); - setState(689); + setState(705); match(CAST_OP); - setState(690); + setState(706); dataType(); } } } - setState(695); + setState(711); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,59,_ctx); + _alt = getInterpreter().adaptivePredict(_input,61,_ctx); } } } @@ -6003,16 +6080,16 @@ public final FunctionExpressionContext functionExpression() throws RecognitionEx int _alt; enterOuterAlt(_localctx, 1); { - setState(696); + setState(712); functionName(); - setState(697); + setState(713); match(LP); - setState(711); + setState(727); _errHandler.sync(this); switch (_input.LA(1)) { case ASTERISK: { - setState(698); + setState(714); match(ASTERISK); } break; @@ -6035,34 +6112,34 @@ public final FunctionExpressionContext functionExpression() throws RecognitionEx case QUOTED_IDENTIFIER: { { - setState(699); + setState(715); booleanExpression(0); - setState(704); + setState(720); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,60,_ctx); + _alt = getInterpreter().adaptivePredict(_input,62,_ctx); while ( _alt!=2 && _alt!=org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER ) { if ( _alt==1 ) { { { - setState(700); + setState(716); match(COMMA); - setState(701); + setState(717); booleanExpression(0); } } } - setState(706); + setState(722); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,60,_ctx); + _alt = getInterpreter().adaptivePredict(_input,62,_ctx); } - setState(709); + setState(725); 
_errHandler.sync(this); _la = _input.LA(1); if (_la==COMMA) { { - setState(707); + setState(723); match(COMMA); - setState(708); + setState(724); mapExpression(); } } @@ -6075,7 +6152,7 @@ public final FunctionExpressionContext functionExpression() throws RecognitionEx default: break; } - setState(713); + setState(729); match(RP); } } @@ -6121,7 +6198,7 @@ public final FunctionNameContext functionName() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(715); + setState(731); identifierOrParameter(); } } @@ -6177,27 +6254,27 @@ public final MapExpressionContext mapExpression() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(717); + setState(733); match(LEFT_BRACES); - setState(718); + setState(734); entryExpression(); - setState(723); + setState(739); _errHandler.sync(this); _la = _input.LA(1); while (_la==COMMA) { { { - setState(719); + setState(735); match(COMMA); - setState(720); + setState(736); entryExpression(); } } - setState(725); + setState(741); _errHandler.sync(this); _la = _input.LA(1); } - setState(726); + setState(742); match(RIGHT_BRACES); } } @@ -6249,11 +6326,11 @@ public final EntryExpressionContext entryExpression() throws RecognitionExceptio try { enterOuterAlt(_localctx, 1); { - setState(728); + setState(744); ((EntryExpressionContext)_localctx).key = string(); - setState(729); + setState(745); match(COLON); - setState(730); + setState(746); ((EntryExpressionContext)_localctx).value = constant(); } } @@ -6524,14 +6601,14 @@ public final ConstantContext constant() throws RecognitionException { enterRule(_localctx, 152, RULE_constant); int _la; try { - setState(774); + setState(790); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,67,_ctx) ) { + switch ( getInterpreter().adaptivePredict(_input,69,_ctx) ) { case 1: _localctx = new NullLiteralContext(_localctx); enterOuterAlt(_localctx, 1); { - setState(732); + setState(748); match(NULL); } break; @@ -6539,9 +6616,9 
@@ public final ConstantContext constant() throws RecognitionException { _localctx = new QualifiedIntegerLiteralContext(_localctx); enterOuterAlt(_localctx, 2); { - setState(733); + setState(749); integerValue(); - setState(734); + setState(750); match(UNQUOTED_IDENTIFIER); } break; @@ -6549,7 +6626,7 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new DecimalLiteralContext(_localctx); enterOuterAlt(_localctx, 3); { - setState(736); + setState(752); decimalValue(); } break; @@ -6557,7 +6634,7 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new IntegerLiteralContext(_localctx); enterOuterAlt(_localctx, 4); { - setState(737); + setState(753); integerValue(); } break; @@ -6565,7 +6642,7 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new BooleanLiteralContext(_localctx); enterOuterAlt(_localctx, 5); { - setState(738); + setState(754); booleanValue(); } break; @@ -6573,7 +6650,7 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new InputParameterContext(_localctx); enterOuterAlt(_localctx, 6); { - setState(739); + setState(755); parameter(); } break; @@ -6581,7 +6658,7 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new StringLiteralContext(_localctx); enterOuterAlt(_localctx, 7); { - setState(740); + setState(756); string(); } break; @@ -6589,27 +6666,27 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new NumericArrayLiteralContext(_localctx); enterOuterAlt(_localctx, 8); { - setState(741); + setState(757); match(OPENING_BRACKET); - setState(742); + setState(758); numericValue(); - setState(747); + setState(763); _errHandler.sync(this); _la = _input.LA(1); while (_la==COMMA) { { { - setState(743); + setState(759); match(COMMA); - setState(744); + setState(760); numericValue(); } } - setState(749); + setState(765); _errHandler.sync(this); _la = 
_input.LA(1); } - setState(750); + setState(766); match(CLOSING_BRACKET); } break; @@ -6617,27 +6694,27 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new BooleanArrayLiteralContext(_localctx); enterOuterAlt(_localctx, 9); { - setState(752); + setState(768); match(OPENING_BRACKET); - setState(753); + setState(769); booleanValue(); - setState(758); + setState(774); _errHandler.sync(this); _la = _input.LA(1); while (_la==COMMA) { { { - setState(754); + setState(770); match(COMMA); - setState(755); + setState(771); booleanValue(); } } - setState(760); + setState(776); _errHandler.sync(this); _la = _input.LA(1); } - setState(761); + setState(777); match(CLOSING_BRACKET); } break; @@ -6645,27 +6722,27 @@ public final ConstantContext constant() throws RecognitionException { _localctx = new StringArrayLiteralContext(_localctx); enterOuterAlt(_localctx, 10); { - setState(763); + setState(779); match(OPENING_BRACKET); - setState(764); + setState(780); string(); - setState(769); + setState(785); _errHandler.sync(this); _la = _input.LA(1); while (_la==COMMA) { { { - setState(765); + setState(781); match(COMMA); - setState(766); + setState(782); string(); } } - setState(771); + setState(787); _errHandler.sync(this); _la = _input.LA(1); } - setState(772); + setState(788); match(CLOSING_BRACKET); } break; @@ -6713,7 +6790,7 @@ public final BooleanValueContext booleanValue() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(776); + setState(792); _la = _input.LA(1); if ( !(_la==FALSE || _la==TRUE) ) { _errHandler.recoverInline(this); @@ -6768,20 +6845,20 @@ public final NumericValueContext numericValue() throws RecognitionException { NumericValueContext _localctx = new NumericValueContext(_ctx, getState()); enterRule(_localctx, 156, RULE_numericValue); try { - setState(780); + setState(796); _errHandler.sync(this); - switch ( getInterpreter().adaptivePredict(_input,68,_ctx) ) { + switch ( 
getInterpreter().adaptivePredict(_input,70,_ctx) ) { case 1: enterOuterAlt(_localctx, 1); { - setState(778); + setState(794); decimalValue(); } break; case 2: enterOuterAlt(_localctx, 2); { - setState(779); + setState(795); integerValue(); } break; @@ -6830,12 +6907,12 @@ public final DecimalValueContext decimalValue() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(783); + setState(799); _errHandler.sync(this); _la = _input.LA(1); if (_la==PLUS || _la==MINUS) { { - setState(782); + setState(798); _la = _input.LA(1); if ( !(_la==PLUS || _la==MINUS) ) { _errHandler.recoverInline(this); @@ -6848,7 +6925,7 @@ public final DecimalValueContext decimalValue() throws RecognitionException { } } - setState(785); + setState(801); match(DECIMAL_LITERAL); } } @@ -6895,12 +6972,12 @@ public final IntegerValueContext integerValue() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(788); + setState(804); _errHandler.sync(this); _la = _input.LA(1); if (_la==PLUS || _la==MINUS) { { - setState(787); + setState(803); _la = _input.LA(1); if ( !(_la==PLUS || _la==MINUS) ) { _errHandler.recoverInline(this); @@ -6913,7 +6990,7 @@ public final IntegerValueContext integerValue() throws RecognitionException { } } - setState(790); + setState(806); match(INTEGER_LITERAL); } } @@ -6957,7 +7034,7 @@ public final StringContext string() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(792); + setState(808); match(QUOTED_STRING); } } @@ -7007,7 +7084,7 @@ public final ComparisonOperatorContext comparisonOperator() throws RecognitionEx try { enterOuterAlt(_localctx, 1); { - setState(794); + setState(810); _la = _input.LA(1); if ( !(((((_la - 80)) & ~0x3f) == 0 && ((1L << (_la - 80)) & 125L) != 0)) ) { _errHandler.recoverInline(this); @@ -7070,7 +7147,7 @@ public final JoinCommandContext joinCommand() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(796); + setState(812); 
((JoinCommandContext)_localctx).type = _input.LT(1); _la = _input.LA(1); if ( !((((_la) & ~0x3f) == 0 && ((1L << _la) & 109051904L) != 0)) ) { @@ -7081,11 +7158,11 @@ public final JoinCommandContext joinCommand() throws RecognitionException { _errHandler.reportMatch(this); consume(); } - setState(797); + setState(813); match(JOIN); - setState(798); + setState(814); joinTarget(); - setState(799); + setState(815); joinCondition(); } } @@ -7132,7 +7209,7 @@ public final JoinTargetContext joinTarget() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(801); + setState(817); ((JoinTargetContext)_localctx).index = indexPattern(); } } @@ -7187,27 +7264,27 @@ public final JoinConditionContext joinCondition() throws RecognitionException { int _alt; enterOuterAlt(_localctx, 1); { - setState(803); + setState(819); match(ON); - setState(804); + setState(820); joinPredicate(); - setState(809); + setState(825); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,71,_ctx); + _alt = getInterpreter().adaptivePredict(_input,73,_ctx); while ( _alt!=2 && _alt!=org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER ) { if ( _alt==1 ) { { { - setState(805); + setState(821); match(COMMA); - setState(806); + setState(822); joinPredicate(); } } } - setState(811); + setState(827); _errHandler.sync(this); - _alt = getInterpreter().adaptivePredict(_input,71,_ctx); + _alt = getInterpreter().adaptivePredict(_input,73,_ctx); } } } @@ -7253,7 +7330,7 @@ public final JoinPredicateContext joinPredicate() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(812); + setState(828); valueExpression(); } } @@ -7352,7 +7429,7 @@ private boolean primaryExpression_sempred(PrimaryExpressionContext _localctx, in } public static final String _serializedATN = - "\u0004\u0001\u008b\u032f\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001"+ + "\u0004\u0001\u008b\u033f\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001"+ 
"\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004"+ "\u0002\u0005\u0007\u0005\u0002\u0006\u0007\u0006\u0002\u0007\u0007\u0007"+ "\u0002\b\u0007\b\u0002\t\u0007\t\u0002\n\u0007\n\u0002\u000b\u0007\u000b"+ @@ -7431,216 +7508,217 @@ private boolean primaryExpression_sempred(PrimaryExpressionContext _localctx, in "\bB\nB\fB\u0265\tB\u0001C\u0001C\u0003C\u0269\bC\u0001C\u0001C\u0001C"+ "\u0001C\u0001C\u0003C\u0270\bC\u0001C\u0001C\u0001C\u0001C\u0001C\u0003"+ "C\u0277\bC\u0001C\u0001C\u0001C\u0001C\u0001C\u0005C\u027e\bC\nC\fC\u0281"+ - "\tC\u0001C\u0001C\u0003C\u0285\bC\u0001D\u0001D\u0001D\u0003D\u028a\b"+ - "D\u0001D\u0001D\u0001D\u0001E\u0001E\u0001E\u0001E\u0001E\u0003E\u0294"+ - "\bE\u0001F\u0001F\u0001F\u0001F\u0003F\u029a\bF\u0001F\u0001F\u0001F\u0001"+ - "F\u0001F\u0001F\u0005F\u02a2\bF\nF\fF\u02a5\tF\u0001G\u0001G\u0001G\u0001"+ - "G\u0001G\u0001G\u0001G\u0001G\u0003G\u02af\bG\u0001G\u0001G\u0001G\u0005"+ - "G\u02b4\bG\nG\fG\u02b7\tG\u0001H\u0001H\u0001H\u0001H\u0001H\u0001H\u0005"+ - "H\u02bf\bH\nH\fH\u02c2\tH\u0001H\u0001H\u0003H\u02c6\bH\u0003H\u02c8\b"+ - "H\u0001H\u0001H\u0001I\u0001I\u0001J\u0001J\u0001J\u0001J\u0005J\u02d2"+ - "\bJ\nJ\fJ\u02d5\tJ\u0001J\u0001J\u0001K\u0001K\u0001K\u0001K\u0001L\u0001"+ - "L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001"+ - "L\u0001L\u0005L\u02ea\bL\nL\fL\u02ed\tL\u0001L\u0001L\u0001L\u0001L\u0001"+ - "L\u0001L\u0005L\u02f5\bL\nL\fL\u02f8\tL\u0001L\u0001L\u0001L\u0001L\u0001"+ - "L\u0001L\u0005L\u0300\bL\nL\fL\u0303\tL\u0001L\u0001L\u0003L\u0307\bL"+ - "\u0001M\u0001M\u0001N\u0001N\u0003N\u030d\bN\u0001O\u0003O\u0310\bO\u0001"+ - "O\u0001O\u0001P\u0003P\u0315\bP\u0001P\u0001P\u0001Q\u0001Q\u0001R\u0001"+ - "R\u0001S\u0001S\u0001S\u0001S\u0001S\u0001T\u0001T\u0001U\u0001U\u0001"+ - "U\u0001U\u0005U\u0328\bU\nU\fU\u032b\tU\u0001V\u0001V\u0001V\u0000\u0005"+ - "\u0002n\u0084\u008c\u008eW\u0000\u0002\u0004\u0006\b\n\f\u000e\u0010\u0012"+ - "\u0014\u0016\u0018\u001a\u001c\u001e 
\"$&(*,.02468:<>@BDFHJLNPRTVXZ\\"+ - "^`bdfhjlnprtvxz|~\u0080\u0082\u0084\u0086\u0088\u008a\u008c\u008e\u0090"+ - "\u0092\u0094\u0096\u0098\u009a\u009c\u009e\u00a0\u00a2\u00a4\u00a6\u00a8"+ - "\u00aa\u00ac\u0000\n\u0002\u000055kk\u0001\u0000ef\u0002\u000099??\u0002"+ - "\u0000BBEE\u0002\u0000&&55\u0001\u0000WX\u0001\u0000Y[\u0002\u0000AAN"+ - "N\u0002\u0000PPRV\u0002\u0000\u0017\u0017\u0019\u001a\u0349\u0000\u00ae"+ - "\u0001\u0000\u0000\u0000\u0002\u00b1\u0001\u0000\u0000\u0000\u0004\u00c3"+ - "\u0001\u0000\u0000\u0000\u0006\u00e0\u0001\u0000\u0000\u0000\b\u00e2\u0001"+ - "\u0000\u0000\u0000\n\u00e5\u0001\u0000\u0000\u0000\f\u00e7\u0001\u0000"+ - "\u0000\u0000\u000e\u00ea\u0001\u0000\u0000\u0000\u0010\u00f5\u0001\u0000"+ - "\u0000\u0000\u0012\u00f9\u0001\u0000\u0000\u0000\u0014\u0101\u0001\u0000"+ - "\u0000\u0000\u0016\u0106\u0001\u0000\u0000\u0000\u0018\u0109\u0001\u0000"+ - "\u0000\u0000\u001a\u010c\u0001\u0000\u0000\u0000\u001c\u0120\u0001\u0000"+ - "\u0000\u0000\u001e\u0122\u0001\u0000\u0000\u0000 \u0124\u0001\u0000\u0000"+ - "\u0000\"\u0126\u0001\u0000\u0000\u0000$\u0128\u0001\u0000\u0000\u0000"+ - "&\u012a\u0001\u0000\u0000\u0000(\u0133\u0001\u0000\u0000\u0000*\u0136"+ - "\u0001\u0000\u0000\u0000,\u013e\u0001\u0000\u0000\u0000.\u0146\u0001\u0000"+ - "\u0000\u00000\u014b\u0001\u0000\u0000\u00002\u0153\u0001\u0000\u0000\u0000"+ - "4\u015b\u0001\u0000\u0000\u00006\u0163\u0001\u0000\u0000\u00008\u0168"+ - "\u0001\u0000\u0000\u0000:\u016c\u0001\u0000\u0000\u0000<\u0170\u0001\u0000"+ - "\u0000\u0000>\u0175\u0001\u0000\u0000\u0000@\u0177\u0001\u0000\u0000\u0000"+ - "B\u017a\u0001\u0000\u0000\u0000D\u0183\u0001\u0000\u0000\u0000F\u018b"+ - "\u0001\u0000\u0000\u0000H\u018e\u0001\u0000\u0000\u0000J\u0191\u0001\u0000"+ - "\u0000\u0000L\u01a2\u0001\u0000\u0000\u0000N\u01a4\u0001\u0000\u0000\u0000"+ - "P\u01aa\u0001\u0000\u0000\u0000R\u01ae\u0001\u0000\u0000\u0000T\u01b1"+ - "\u0001\u0000\u0000\u0000V\u01b9\u0001\u0000\u0000\u0000X\u01bd\u0001\u0000"+ - 
"\u0000\u0000Z\u01c0\u0001\u0000\u0000\u0000\\\u01c4\u0001\u0000\u0000"+ - "\u0000^\u01c7\u0001\u0000\u0000\u0000`\u01d8\u0001\u0000\u0000\u0000b"+ - "\u01dd\u0001\u0000\u0000\u0000d\u01e1\u0001\u0000\u0000\u0000f\u01e4\u0001"+ - "\u0000\u0000\u0000h\u01f1\u0001\u0000\u0000\u0000j\u01f5\u0001\u0000\u0000"+ - "\u0000l\u01f9\u0001\u0000\u0000\u0000n\u01fd\u0001\u0000\u0000\u0000p"+ - "\u0208\u0001\u0000\u0000\u0000r\u020a\u0001\u0000\u0000\u0000t\u0214\u0001"+ - "\u0000\u0000\u0000v\u0219\u0001\u0000\u0000\u0000x\u021f\u0001\u0000\u0000"+ - "\u0000z\u0222\u0001\u0000\u0000\u0000|\u0224\u0001\u0000\u0000\u0000~"+ - "\u022c\u0001\u0000\u0000\u0000\u0080\u0232\u0001\u0000\u0000\u0000\u0082"+ - "\u0234\u0001\u0000\u0000\u0000\u0084\u0259\u0001\u0000\u0000\u0000\u0086"+ - "\u0284\u0001\u0000\u0000\u0000\u0088\u0286\u0001\u0000\u0000\u0000\u008a"+ - "\u0293\u0001\u0000\u0000\u0000\u008c\u0299\u0001\u0000\u0000\u0000\u008e"+ - "\u02ae\u0001\u0000\u0000\u0000\u0090\u02b8\u0001\u0000\u0000\u0000\u0092"+ - "\u02cb\u0001\u0000\u0000\u0000\u0094\u02cd\u0001\u0000\u0000\u0000\u0096"+ - "\u02d8\u0001\u0000\u0000\u0000\u0098\u0306\u0001\u0000\u0000\u0000\u009a"+ - "\u0308\u0001\u0000\u0000\u0000\u009c\u030c\u0001\u0000\u0000\u0000\u009e"+ - "\u030f\u0001\u0000\u0000\u0000\u00a0\u0314\u0001\u0000\u0000\u0000\u00a2"+ - "\u0318\u0001\u0000\u0000\u0000\u00a4\u031a\u0001\u0000\u0000\u0000\u00a6"+ - "\u031c\u0001\u0000\u0000\u0000\u00a8\u0321\u0001\u0000\u0000\u0000\u00aa"+ - "\u0323\u0001\u0000\u0000\u0000\u00ac\u032c\u0001\u0000\u0000\u0000\u00ae"+ - "\u00af\u0003\u0002\u0001\u0000\u00af\u00b0\u0005\u0000\u0000\u0001\u00b0"+ - "\u0001\u0001\u0000\u0000\u0000\u00b1\u00b2\u0006\u0001\uffff\uffff\u0000"+ - "\u00b2\u00b3\u0003\u0004\u0002\u0000\u00b3\u00b9\u0001\u0000\u0000\u0000"+ - "\u00b4\u00b5\n\u0001\u0000\u0000\u00b5\u00b6\u00054\u0000\u0000\u00b6"+ - "\u00b8\u0003\u0006\u0003\u0000\u00b7\u00b4\u0001\u0000\u0000\u0000\u00b8"+ - 
"\u00bb\u0001\u0000\u0000\u0000\u00b9\u00b7\u0001\u0000\u0000\u0000\u00b9"+ - "\u00ba\u0001\u0000\u0000\u0000\u00ba\u0003\u0001\u0000\u0000\u0000\u00bb"+ - "\u00b9\u0001\u0000\u0000\u0000\u00bc\u00c4\u0003\u0016\u000b\u0000\u00bd"+ - "\u00c4\u0003\f\u0006\u0000\u00be\u00c4\u0003\\.\u0000\u00bf\u00c0\u0004"+ - "\u0002\u0001\u0000\u00c0\u00c4\u0003\u0018\f\u0000\u00c1\u00c2\u0004\u0002"+ - "\u0002\u0000\u00c2\u00c4\u0003X,\u0000\u00c3\u00bc\u0001\u0000\u0000\u0000"+ - "\u00c3\u00bd\u0001\u0000\u0000\u0000\u00c3\u00be\u0001\u0000\u0000\u0000"+ - "\u00c3\u00bf\u0001\u0000\u0000\u0000\u00c3\u00c1\u0001\u0000\u0000\u0000"+ - "\u00c4\u0005\u0001\u0000\u0000\u0000\u00c5\u00e1\u0003(\u0014\u0000\u00c6"+ - "\u00e1\u0003\b\u0004\u0000\u00c7\u00e1\u0003F#\u0000\u00c8\u00e1\u0003"+ - "@ \u0000\u00c9\u00e1\u0003*\u0015\u0000\u00ca\u00e1\u0003B!\u0000\u00cb"+ - "\u00e1\u0003H$\u0000\u00cc\u00e1\u0003J%\u0000\u00cd\u00e1\u0003N\'\u0000"+ - "\u00ce\u00e1\u0003P(\u0000\u00cf\u00e1\u0003^/\u0000\u00d0\u00e1\u0003"+ - "R)\u0000\u00d1\u00e1\u0003\u00a6S\u0000\u00d2\u00e1\u0003f3\u0000\u00d3"+ - "\u00e1\u0003r9\u0000\u00d4\u00e1\u0003d2\u0000\u00d5\u00e1\u0003h4\u0000"+ - "\u00d6\u00d7\u0004\u0003\u0003\u0000\u00d7\u00e1\u0003v;\u0000\u00d8\u00d9"+ - "\u0004\u0003\u0004\u0000\u00d9\u00e1\u0003t:\u0000\u00da\u00db\u0004\u0003"+ - "\u0005\u0000\u00db\u00e1\u0003x<\u0000\u00dc\u00dd\u0004\u0003\u0006\u0000"+ - "\u00dd\u00e1\u0003\u0082A\u0000\u00de\u00df\u0004\u0003\u0007\u0000\u00df"+ - "\u00e1\u0003z=\u0000\u00e0\u00c5\u0001\u0000\u0000\u0000\u00e0\u00c6\u0001"+ - "\u0000\u0000\u0000\u00e0\u00c7\u0001\u0000\u0000\u0000\u00e0\u00c8\u0001"+ - "\u0000\u0000\u0000\u00e0\u00c9\u0001\u0000\u0000\u0000\u00e0\u00ca\u0001"+ - "\u0000\u0000\u0000\u00e0\u00cb\u0001\u0000\u0000\u0000\u00e0\u00cc\u0001"+ - "\u0000\u0000\u0000\u00e0\u00cd\u0001\u0000\u0000\u0000\u00e0\u00ce\u0001"+ - "\u0000\u0000\u0000\u00e0\u00cf\u0001\u0000\u0000\u0000\u00e0\u00d0\u0001"+ - 
"\u0000\u0000\u0000\u00e0\u00d1\u0001\u0000\u0000\u0000\u00e0\u00d2\u0001"+ - "\u0000\u0000\u0000\u00e0\u00d3\u0001\u0000\u0000\u0000\u00e0\u00d4\u0001"+ - "\u0000\u0000\u0000\u00e0\u00d5\u0001\u0000\u0000\u0000\u00e0\u00d6\u0001"+ - "\u0000\u0000\u0000\u00e0\u00d8\u0001\u0000\u0000\u0000\u00e0\u00da\u0001"+ - "\u0000\u0000\u0000\u00e0\u00dc\u0001\u0000\u0000\u0000\u00e0\u00de\u0001"+ - "\u0000\u0000\u0000\u00e1\u0007\u0001\u0000\u0000\u0000\u00e2\u00e3\u0005"+ - "\u0010\u0000\u0000\u00e3\u00e4\u0003\u0084B\u0000\u00e4\t\u0001\u0000"+ - "\u0000\u0000\u00e5\u00e6\u00036\u001b\u0000\u00e6\u000b\u0001\u0000\u0000"+ - "\u0000\u00e7\u00e8\u0005\f\u0000\u0000\u00e8\u00e9\u0003\u000e\u0007\u0000"+ - "\u00e9\r\u0001\u0000\u0000\u0000\u00ea\u00ef\u0003\u0010\b\u0000\u00eb"+ - "\u00ec\u0005>\u0000\u0000\u00ec\u00ee\u0003\u0010\b\u0000\u00ed\u00eb"+ - "\u0001\u0000\u0000\u0000\u00ee\u00f1\u0001\u0000\u0000\u0000\u00ef\u00ed"+ - "\u0001\u0000\u0000\u0000\u00ef\u00f0\u0001\u0000\u0000\u0000\u00f0\u000f"+ - "\u0001\u0000\u0000\u0000\u00f1\u00ef\u0001\u0000\u0000\u0000\u00f2\u00f3"+ - "\u00030\u0018\u0000\u00f3\u00f4\u0005:\u0000\u0000\u00f4\u00f6\u0001\u0000"+ - "\u0000\u0000\u00f5\u00f2\u0001\u0000\u0000\u0000\u00f5\u00f6\u0001\u0000"+ - "\u0000\u0000\u00f6\u00f7\u0001\u0000\u0000\u0000\u00f7\u00f8\u0003\u0084"+ - "B\u0000\u00f8\u0011\u0001\u0000\u0000\u0000\u00f9\u00fe\u0003\u0014\n"+ - "\u0000\u00fa\u00fb\u0005>\u0000\u0000\u00fb\u00fd\u0003\u0014\n\u0000"+ - "\u00fc\u00fa\u0001\u0000\u0000\u0000\u00fd\u0100\u0001\u0000\u0000\u0000"+ - "\u00fe\u00fc\u0001\u0000\u0000\u0000\u00fe\u00ff\u0001\u0000\u0000\u0000"+ - "\u00ff\u0013\u0001\u0000\u0000\u0000\u0100\u00fe\u0001\u0000\u0000\u0000"+ - "\u0101\u0104\u00030\u0018\u0000\u0102\u0103\u0005:\u0000\u0000\u0103\u0105"+ - "\u0003\u0084B\u0000\u0104\u0102\u0001\u0000\u0000\u0000\u0104\u0105\u0001"+ - "\u0000\u0000\u0000\u0105\u0015\u0001\u0000\u0000\u0000\u0106\u0107\u0005"+ - 
"\u0013\u0000\u0000\u0107\u0108\u0003\u001a\r\u0000\u0108\u0017\u0001\u0000"+ - "\u0000\u0000\u0109\u010a\u0005\u0014\u0000\u0000\u010a\u010b\u0003\u001a"+ - "\r\u0000\u010b\u0019\u0001\u0000\u0000\u0000\u010c\u0111\u0003\u001c\u000e"+ - "\u0000\u010d\u010e\u0005>\u0000\u0000\u010e\u0110\u0003\u001c\u000e\u0000"+ - "\u010f\u010d\u0001\u0000\u0000\u0000\u0110\u0113\u0001\u0000\u0000\u0000"+ - "\u0111\u010f\u0001\u0000\u0000\u0000\u0111\u0112\u0001\u0000\u0000\u0000"+ - "\u0112\u0115\u0001\u0000\u0000\u0000\u0113\u0111\u0001\u0000\u0000\u0000"+ - "\u0114\u0116\u0003&\u0013\u0000\u0115\u0114\u0001\u0000\u0000\u0000\u0115"+ - "\u0116\u0001\u0000\u0000\u0000\u0116\u001b\u0001\u0000\u0000\u0000\u0117"+ - "\u0118\u0003\u001e\u000f\u0000\u0118\u0119\u0005=\u0000\u0000\u0119\u011a"+ - "\u0003\"\u0011\u0000\u011a\u0121\u0001\u0000\u0000\u0000\u011b\u011c\u0003"+ - "\"\u0011\u0000\u011c\u011d\u0005<\u0000\u0000\u011d\u011e\u0003 \u0010"+ - "\u0000\u011e\u0121\u0001\u0000\u0000\u0000\u011f\u0121\u0003$\u0012\u0000"+ - "\u0120\u0117\u0001\u0000\u0000\u0000\u0120\u011b\u0001\u0000\u0000\u0000"+ - "\u0120\u011f\u0001\u0000\u0000\u0000\u0121\u001d\u0001\u0000\u0000\u0000"+ - "\u0122\u0123\u0005k\u0000\u0000\u0123\u001f\u0001\u0000\u0000\u0000\u0124"+ - "\u0125\u0005k\u0000\u0000\u0125!\u0001\u0000\u0000\u0000\u0126\u0127\u0005"+ - "k\u0000\u0000\u0127#\u0001\u0000\u0000\u0000\u0128\u0129\u0007\u0000\u0000"+ - "\u0000\u0129%\u0001\u0000\u0000\u0000\u012a\u012b\u0005j\u0000\u0000\u012b"+ - "\u0130\u0005k\u0000\u0000\u012c\u012d\u0005>\u0000\u0000\u012d\u012f\u0005"+ - "k\u0000\u0000\u012e\u012c\u0001\u0000\u0000\u0000\u012f\u0132\u0001\u0000"+ - "\u0000\u0000\u0130\u012e\u0001\u0000\u0000\u0000\u0130\u0131\u0001\u0000"+ - "\u0000\u0000\u0131\'\u0001\u0000\u0000\u0000\u0132\u0130\u0001\u0000\u0000"+ - "\u0000\u0133\u0134\u0005\t\u0000\u0000\u0134\u0135\u0003\u000e\u0007\u0000"+ - "\u0135)\u0001\u0000\u0000\u0000\u0136\u0138\u0005\u000f\u0000\u0000\u0137"+ - 
"\u0139\u0003,\u0016\u0000\u0138\u0137\u0001\u0000\u0000\u0000\u0138\u0139"+ - "\u0001\u0000\u0000\u0000\u0139\u013c\u0001\u0000\u0000\u0000\u013a\u013b"+ - "\u0005;\u0000\u0000\u013b\u013d\u0003\u000e\u0007\u0000\u013c\u013a\u0001"+ - "\u0000\u0000\u0000\u013c\u013d\u0001\u0000\u0000\u0000\u013d+\u0001\u0000"+ - "\u0000\u0000\u013e\u0143\u0003.\u0017\u0000\u013f\u0140\u0005>\u0000\u0000"+ - "\u0140\u0142\u0003.\u0017\u0000\u0141\u013f\u0001\u0000\u0000\u0000\u0142"+ - "\u0145\u0001\u0000\u0000\u0000\u0143\u0141\u0001\u0000\u0000\u0000\u0143"+ - "\u0144\u0001\u0000\u0000\u0000\u0144-\u0001\u0000\u0000\u0000\u0145\u0143"+ - "\u0001\u0000\u0000\u0000\u0146\u0149\u0003\u0010\b\u0000\u0147\u0148\u0005"+ - "\u0010\u0000\u0000\u0148\u014a\u0003\u0084B\u0000\u0149\u0147\u0001\u0000"+ - "\u0000\u0000\u0149\u014a\u0001\u0000\u0000\u0000\u014a/\u0001\u0000\u0000"+ - "\u0000\u014b\u0150\u0003>\u001f\u0000\u014c\u014d\u0005@\u0000\u0000\u014d"+ - "\u014f\u0003>\u001f\u0000\u014e\u014c\u0001\u0000\u0000\u0000\u014f\u0152"+ - "\u0001\u0000\u0000\u0000\u0150\u014e\u0001\u0000\u0000\u0000\u0150\u0151"+ - "\u0001\u0000\u0000\u0000\u01511\u0001\u0000\u0000\u0000\u0152\u0150\u0001"+ - "\u0000\u0000\u0000\u0153\u0158\u00038\u001c\u0000\u0154\u0155\u0005@\u0000"+ - "\u0000\u0155\u0157\u00038\u001c\u0000\u0156\u0154\u0001\u0000\u0000\u0000"+ - "\u0157\u015a\u0001\u0000\u0000\u0000\u0158\u0156\u0001\u0000\u0000\u0000"+ - "\u0158\u0159\u0001\u0000\u0000\u0000\u01593\u0001\u0000\u0000\u0000\u015a"+ - "\u0158\u0001\u0000\u0000\u0000\u015b\u0160\u00032\u0019\u0000\u015c\u015d"+ - "\u0005>\u0000\u0000\u015d\u015f\u00032\u0019\u0000\u015e\u015c\u0001\u0000"+ - "\u0000\u0000\u015f\u0162\u0001\u0000\u0000\u0000\u0160\u015e\u0001\u0000"+ - "\u0000\u0000\u0160\u0161\u0001\u0000\u0000\u0000\u01615\u0001\u0000\u0000"+ - "\u0000\u0162\u0160\u0001\u0000\u0000\u0000\u0163\u0164\u0007\u0001\u0000"+ - "\u0000\u01647\u0001\u0000\u0000\u0000\u0165\u0169\u0005\u0080\u0000\u0000"+ - 
"\u0166\u0169\u0003:\u001d\u0000\u0167\u0169\u0003<\u001e\u0000\u0168\u0165"+ - "\u0001\u0000\u0000\u0000\u0168\u0166\u0001\u0000\u0000\u0000\u0168\u0167"+ - "\u0001\u0000\u0000\u0000\u01699\u0001\u0000\u0000\u0000\u016a\u016d\u0005"+ - "L\u0000\u0000\u016b\u016d\u0005_\u0000\u0000\u016c\u016a\u0001\u0000\u0000"+ - "\u0000\u016c\u016b\u0001\u0000\u0000\u0000\u016d;\u0001\u0000\u0000\u0000"+ - "\u016e\u0171\u0005^\u0000\u0000\u016f\u0171\u0005`\u0000\u0000\u0170\u016e"+ - "\u0001\u0000\u0000\u0000\u0170\u016f\u0001\u0000\u0000\u0000\u0171=\u0001"+ - "\u0000\u0000\u0000\u0172\u0176\u00036\u001b\u0000\u0173\u0176\u0003:\u001d"+ - "\u0000\u0174\u0176\u0003<\u001e\u0000\u0175\u0172\u0001\u0000\u0000\u0000"+ - "\u0175\u0173\u0001\u0000\u0000\u0000\u0175\u0174\u0001\u0000\u0000\u0000"+ - "\u0176?\u0001\u0000\u0000\u0000\u0177\u0178\u0005\u000b\u0000\u0000\u0178"+ - "\u0179\u0003\u0098L\u0000\u0179A\u0001\u0000\u0000\u0000\u017a\u017b\u0005"+ - "\u000e\u0000\u0000\u017b\u0180\u0003D\"\u0000\u017c\u017d\u0005>\u0000"+ - "\u0000\u017d\u017f\u0003D\"\u0000\u017e\u017c\u0001\u0000\u0000\u0000"+ - "\u017f\u0182\u0001\u0000\u0000\u0000\u0180\u017e\u0001\u0000\u0000\u0000"+ - "\u0180\u0181\u0001\u0000\u0000\u0000\u0181C\u0001\u0000\u0000\u0000\u0182"+ - "\u0180\u0001\u0000\u0000\u0000\u0183\u0185\u0003\u0084B\u0000\u0184\u0186"+ - "\u0007\u0002\u0000\u0000\u0185\u0184\u0001\u0000\u0000\u0000\u0185\u0186"+ - "\u0001\u0000\u0000\u0000\u0186\u0189\u0001\u0000\u0000\u0000\u0187\u0188"+ - "\u0005I\u0000\u0000\u0188\u018a\u0007\u0003\u0000\u0000\u0189\u0187\u0001"+ - "\u0000\u0000\u0000\u0189\u018a\u0001\u0000\u0000\u0000\u018aE\u0001\u0000"+ - "\u0000\u0000\u018b\u018c\u0005\u001e\u0000\u0000\u018c\u018d\u00034\u001a"+ - "\u0000\u018dG\u0001\u0000\u0000\u0000\u018e\u018f\u0005\u001d\u0000\u0000"+ - "\u018f\u0190\u00034\u001a\u0000\u0190I\u0001\u0000\u0000\u0000\u0191\u0192"+ - "\u0005 \u0000\u0000\u0192\u0197\u0003L&\u0000\u0193\u0194\u0005>\u0000"+ - 
"\u0000\u0194\u0196\u0003L&\u0000\u0195\u0193\u0001\u0000\u0000\u0000\u0196"+ - "\u0199\u0001\u0000\u0000\u0000\u0197\u0195\u0001\u0000\u0000\u0000\u0197"+ - "\u0198\u0001\u0000\u0000\u0000\u0198K\u0001\u0000\u0000\u0000\u0199\u0197"+ - "\u0001\u0000\u0000\u0000\u019a\u019b\u00032\u0019\u0000\u019b\u019c\u0005"+ - "\u0084\u0000\u0000\u019c\u019d\u00032\u0019\u0000\u019d\u01a3\u0001\u0000"+ - "\u0000\u0000\u019e\u019f\u00032\u0019\u0000\u019f\u01a0\u0005:\u0000\u0000"+ - "\u01a0\u01a1\u00032\u0019\u0000\u01a1\u01a3\u0001\u0000\u0000\u0000\u01a2"+ - "\u019a\u0001\u0000\u0000\u0000\u01a2\u019e\u0001\u0000\u0000\u0000\u01a3"+ - "M\u0001\u0000\u0000\u0000\u01a4\u01a5\u0005\b\u0000\u0000\u01a5\u01a6"+ - "\u0003\u008eG\u0000\u01a6\u01a8\u0003\u00a2Q\u0000\u01a7\u01a9\u0003T"+ - "*\u0000\u01a8\u01a7\u0001\u0000\u0000\u0000\u01a8\u01a9\u0001\u0000\u0000"+ - "\u0000\u01a9O\u0001\u0000\u0000\u0000\u01aa\u01ab\u0005\n\u0000\u0000"+ - "\u01ab\u01ac\u0003\u008eG\u0000\u01ac\u01ad\u0003\u00a2Q\u0000\u01adQ"+ - "\u0001\u0000\u0000\u0000\u01ae\u01af\u0005\u001c\u0000\u0000\u01af\u01b0"+ - "\u00030\u0018\u0000\u01b0S\u0001\u0000\u0000\u0000\u01b1\u01b6\u0003V"+ - "+\u0000\u01b2\u01b3\u0005>\u0000\u0000\u01b3\u01b5\u0003V+\u0000\u01b4"+ + "\tC\u0001C\u0001C\u0001C\u0001C\u0003C\u0287\bC\u0001C\u0001C\u0001C\u0001"+ + "C\u0001C\u0005C\u028e\bC\nC\fC\u0291\tC\u0001C\u0001C\u0003C\u0295\bC"+ + "\u0001D\u0001D\u0001D\u0003D\u029a\bD\u0001D\u0001D\u0001D\u0001E\u0001"+ + "E\u0001E\u0001E\u0001E\u0003E\u02a4\bE\u0001F\u0001F\u0001F\u0001F\u0003"+ + "F\u02aa\bF\u0001F\u0001F\u0001F\u0001F\u0001F\u0001F\u0005F\u02b2\bF\n"+ + "F\fF\u02b5\tF\u0001G\u0001G\u0001G\u0001G\u0001G\u0001G\u0001G\u0001G"+ + "\u0003G\u02bf\bG\u0001G\u0001G\u0001G\u0005G\u02c4\bG\nG\fG\u02c7\tG\u0001"+ + "H\u0001H\u0001H\u0001H\u0001H\u0001H\u0005H\u02cf\bH\nH\fH\u02d2\tH\u0001"+ + "H\u0001H\u0003H\u02d6\bH\u0003H\u02d8\bH\u0001H\u0001H\u0001I\u0001I\u0001"+ + 
"J\u0001J\u0001J\u0001J\u0005J\u02e2\bJ\nJ\fJ\u02e5\tJ\u0001J\u0001J\u0001"+ + "K\u0001K\u0001K\u0001K\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001"+ + "L\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0005L\u02fa\bL\nL\fL\u02fd"+ + "\tL\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0005L\u0305\bL\nL\fL\u0308"+ + "\tL\u0001L\u0001L\u0001L\u0001L\u0001L\u0001L\u0005L\u0310\bL\nL\fL\u0313"+ + "\tL\u0001L\u0001L\u0003L\u0317\bL\u0001M\u0001M\u0001N\u0001N\u0003N\u031d"+ + "\bN\u0001O\u0003O\u0320\bO\u0001O\u0001O\u0001P\u0003P\u0325\bP\u0001"+ + "P\u0001P\u0001Q\u0001Q\u0001R\u0001R\u0001S\u0001S\u0001S\u0001S\u0001"+ + "S\u0001T\u0001T\u0001U\u0001U\u0001U\u0001U\u0005U\u0338\bU\nU\fU\u033b"+ + "\tU\u0001V\u0001V\u0001V\u0000\u0005\u0002n\u0084\u008c\u008eW\u0000\u0002"+ + "\u0004\u0006\b\n\f\u000e\u0010\u0012\u0014\u0016\u0018\u001a\u001c\u001e"+ + " \"$&(*,.02468:<>@BDFHJLNPRTVXZ\\^`bdfhjlnprtvxz|~\u0080\u0082\u0084\u0086"+ + "\u0088\u008a\u008c\u008e\u0090\u0092\u0094\u0096\u0098\u009a\u009c\u009e"+ + "\u00a0\u00a2\u00a4\u00a6\u00a8\u00aa\u00ac\u0000\n\u0002\u000055kk\u0001"+ + "\u0000ef\u0002\u000099??\u0002\u0000BBEE\u0002\u0000&&55\u0001\u0000W"+ + "X\u0001\u0000Y[\u0002\u0000AANN\u0002\u0000PPRV\u0002\u0000\u0017\u0017"+ + "\u0019\u001a\u035c\u0000\u00ae\u0001\u0000\u0000\u0000\u0002\u00b1\u0001"+ + "\u0000\u0000\u0000\u0004\u00c3\u0001\u0000\u0000\u0000\u0006\u00e0\u0001"+ + "\u0000\u0000\u0000\b\u00e2\u0001\u0000\u0000\u0000\n\u00e5\u0001\u0000"+ + "\u0000\u0000\f\u00e7\u0001\u0000\u0000\u0000\u000e\u00ea\u0001\u0000\u0000"+ + "\u0000\u0010\u00f5\u0001\u0000\u0000\u0000\u0012\u00f9\u0001\u0000\u0000"+ + "\u0000\u0014\u0101\u0001\u0000\u0000\u0000\u0016\u0106\u0001\u0000\u0000"+ + "\u0000\u0018\u0109\u0001\u0000\u0000\u0000\u001a\u010c\u0001\u0000\u0000"+ + "\u0000\u001c\u0120\u0001\u0000\u0000\u0000\u001e\u0122\u0001\u0000\u0000"+ + "\u0000 \u0124\u0001\u0000\u0000\u0000\"\u0126\u0001\u0000\u0000\u0000"+ + 
"$\u0128\u0001\u0000\u0000\u0000&\u012a\u0001\u0000\u0000\u0000(\u0133"+ + "\u0001\u0000\u0000\u0000*\u0136\u0001\u0000\u0000\u0000,\u013e\u0001\u0000"+ + "\u0000\u0000.\u0146\u0001\u0000\u0000\u00000\u014b\u0001\u0000\u0000\u0000"+ + "2\u0153\u0001\u0000\u0000\u00004\u015b\u0001\u0000\u0000\u00006\u0163"+ + "\u0001\u0000\u0000\u00008\u0168\u0001\u0000\u0000\u0000:\u016c\u0001\u0000"+ + "\u0000\u0000<\u0170\u0001\u0000\u0000\u0000>\u0175\u0001\u0000\u0000\u0000"+ + "@\u0177\u0001\u0000\u0000\u0000B\u017a\u0001\u0000\u0000\u0000D\u0183"+ + "\u0001\u0000\u0000\u0000F\u018b\u0001\u0000\u0000\u0000H\u018e\u0001\u0000"+ + "\u0000\u0000J\u0191\u0001\u0000\u0000\u0000L\u01a2\u0001\u0000\u0000\u0000"+ + "N\u01a4\u0001\u0000\u0000\u0000P\u01aa\u0001\u0000\u0000\u0000R\u01ae"+ + "\u0001\u0000\u0000\u0000T\u01b1\u0001\u0000\u0000\u0000V\u01b9\u0001\u0000"+ + "\u0000\u0000X\u01bd\u0001\u0000\u0000\u0000Z\u01c0\u0001\u0000\u0000\u0000"+ + "\\\u01c4\u0001\u0000\u0000\u0000^\u01c7\u0001\u0000\u0000\u0000`\u01d8"+ + "\u0001\u0000\u0000\u0000b\u01dd\u0001\u0000\u0000\u0000d\u01e1\u0001\u0000"+ + "\u0000\u0000f\u01e4\u0001\u0000\u0000\u0000h\u01f1\u0001\u0000\u0000\u0000"+ + "j\u01f5\u0001\u0000\u0000\u0000l\u01f9\u0001\u0000\u0000\u0000n\u01fd"+ + "\u0001\u0000\u0000\u0000p\u0208\u0001\u0000\u0000\u0000r\u020a\u0001\u0000"+ + "\u0000\u0000t\u0214\u0001\u0000\u0000\u0000v\u0219\u0001\u0000\u0000\u0000"+ + "x\u021f\u0001\u0000\u0000\u0000z\u0222\u0001\u0000\u0000\u0000|\u0224"+ + "\u0001\u0000\u0000\u0000~\u022c\u0001\u0000\u0000\u0000\u0080\u0232\u0001"+ + "\u0000\u0000\u0000\u0082\u0234\u0001\u0000\u0000\u0000\u0084\u0259\u0001"+ + "\u0000\u0000\u0000\u0086\u0294\u0001\u0000\u0000\u0000\u0088\u0296\u0001"+ + "\u0000\u0000\u0000\u008a\u02a3\u0001\u0000\u0000\u0000\u008c\u02a9\u0001"+ + "\u0000\u0000\u0000\u008e\u02be\u0001\u0000\u0000\u0000\u0090\u02c8\u0001"+ + "\u0000\u0000\u0000\u0092\u02db\u0001\u0000\u0000\u0000\u0094\u02dd\u0001"+ + 
"\u0000\u0000\u0000\u0096\u02e8\u0001\u0000\u0000\u0000\u0098\u0316\u0001"+ + "\u0000\u0000\u0000\u009a\u0318\u0001\u0000\u0000\u0000\u009c\u031c\u0001"+ + "\u0000\u0000\u0000\u009e\u031f\u0001\u0000\u0000\u0000\u00a0\u0324\u0001"+ + "\u0000\u0000\u0000\u00a2\u0328\u0001\u0000\u0000\u0000\u00a4\u032a\u0001"+ + "\u0000\u0000\u0000\u00a6\u032c\u0001\u0000\u0000\u0000\u00a8\u0331\u0001"+ + "\u0000\u0000\u0000\u00aa\u0333\u0001\u0000\u0000\u0000\u00ac\u033c\u0001"+ + "\u0000\u0000\u0000\u00ae\u00af\u0003\u0002\u0001\u0000\u00af\u00b0\u0005"+ + "\u0000\u0000\u0001\u00b0\u0001\u0001\u0000\u0000\u0000\u00b1\u00b2\u0006"+ + "\u0001\uffff\uffff\u0000\u00b2\u00b3\u0003\u0004\u0002\u0000\u00b3\u00b9"+ + "\u0001\u0000\u0000\u0000\u00b4\u00b5\n\u0001\u0000\u0000\u00b5\u00b6\u0005"+ + "4\u0000\u0000\u00b6\u00b8\u0003\u0006\u0003\u0000\u00b7\u00b4\u0001\u0000"+ + "\u0000\u0000\u00b8\u00bb\u0001\u0000\u0000\u0000\u00b9\u00b7\u0001\u0000"+ + "\u0000\u0000\u00b9\u00ba\u0001\u0000\u0000\u0000\u00ba\u0003\u0001\u0000"+ + "\u0000\u0000\u00bb\u00b9\u0001\u0000\u0000\u0000\u00bc\u00c4\u0003\u0016"+ + "\u000b\u0000\u00bd\u00c4\u0003\f\u0006\u0000\u00be\u00c4\u0003\\.\u0000"+ + "\u00bf\u00c0\u0004\u0002\u0001\u0000\u00c0\u00c4\u0003\u0018\f\u0000\u00c1"+ + "\u00c2\u0004\u0002\u0002\u0000\u00c2\u00c4\u0003X,\u0000\u00c3\u00bc\u0001"+ + "\u0000\u0000\u0000\u00c3\u00bd\u0001\u0000\u0000\u0000\u00c3\u00be\u0001"+ + "\u0000\u0000\u0000\u00c3\u00bf\u0001\u0000\u0000\u0000\u00c3\u00c1\u0001"+ + "\u0000\u0000\u0000\u00c4\u0005\u0001\u0000\u0000\u0000\u00c5\u00e1\u0003"+ + "(\u0014\u0000\u00c6\u00e1\u0003\b\u0004\u0000\u00c7\u00e1\u0003F#\u0000"+ + "\u00c8\u00e1\u0003@ \u0000\u00c9\u00e1\u0003*\u0015\u0000\u00ca\u00e1"+ + "\u0003B!\u0000\u00cb\u00e1\u0003H$\u0000\u00cc\u00e1\u0003J%\u0000\u00cd"+ + "\u00e1\u0003N\'\u0000\u00ce\u00e1\u0003P(\u0000\u00cf\u00e1\u0003^/\u0000"+ + "\u00d0\u00e1\u0003R)\u0000\u00d1\u00e1\u0003\u00a6S\u0000\u00d2\u00e1"+ + 
"\u0003f3\u0000\u00d3\u00e1\u0003r9\u0000\u00d4\u00e1\u0003d2\u0000\u00d5"+ + "\u00e1\u0003h4\u0000\u00d6\u00d7\u0004\u0003\u0003\u0000\u00d7\u00e1\u0003"+ + "v;\u0000\u00d8\u00d9\u0004\u0003\u0004\u0000\u00d9\u00e1\u0003t:\u0000"+ + "\u00da\u00db\u0004\u0003\u0005\u0000\u00db\u00e1\u0003x<\u0000\u00dc\u00dd"+ + "\u0004\u0003\u0006\u0000\u00dd\u00e1\u0003\u0082A\u0000\u00de\u00df\u0004"+ + "\u0003\u0007\u0000\u00df\u00e1\u0003z=\u0000\u00e0\u00c5\u0001\u0000\u0000"+ + "\u0000\u00e0\u00c6\u0001\u0000\u0000\u0000\u00e0\u00c7\u0001\u0000\u0000"+ + "\u0000\u00e0\u00c8\u0001\u0000\u0000\u0000\u00e0\u00c9\u0001\u0000\u0000"+ + "\u0000\u00e0\u00ca\u0001\u0000\u0000\u0000\u00e0\u00cb\u0001\u0000\u0000"+ + "\u0000\u00e0\u00cc\u0001\u0000\u0000\u0000\u00e0\u00cd\u0001\u0000\u0000"+ + "\u0000\u00e0\u00ce\u0001\u0000\u0000\u0000\u00e0\u00cf\u0001\u0000\u0000"+ + "\u0000\u00e0\u00d0\u0001\u0000\u0000\u0000\u00e0\u00d1\u0001\u0000\u0000"+ + "\u0000\u00e0\u00d2\u0001\u0000\u0000\u0000\u00e0\u00d3\u0001\u0000\u0000"+ + "\u0000\u00e0\u00d4\u0001\u0000\u0000\u0000\u00e0\u00d5\u0001\u0000\u0000"+ + "\u0000\u00e0\u00d6\u0001\u0000\u0000\u0000\u00e0\u00d8\u0001\u0000\u0000"+ + "\u0000\u00e0\u00da\u0001\u0000\u0000\u0000\u00e0\u00dc\u0001\u0000\u0000"+ + "\u0000\u00e0\u00de\u0001\u0000\u0000\u0000\u00e1\u0007\u0001\u0000\u0000"+ + "\u0000\u00e2\u00e3\u0005\u0010\u0000\u0000\u00e3\u00e4\u0003\u0084B\u0000"+ + "\u00e4\t\u0001\u0000\u0000\u0000\u00e5\u00e6\u00036\u001b\u0000\u00e6"+ + "\u000b\u0001\u0000\u0000\u0000\u00e7\u00e8\u0005\f\u0000\u0000\u00e8\u00e9"+ + "\u0003\u000e\u0007\u0000\u00e9\r\u0001\u0000\u0000\u0000\u00ea\u00ef\u0003"+ + "\u0010\b\u0000\u00eb\u00ec\u0005>\u0000\u0000\u00ec\u00ee\u0003\u0010"+ + "\b\u0000\u00ed\u00eb\u0001\u0000\u0000\u0000\u00ee\u00f1\u0001\u0000\u0000"+ + "\u0000\u00ef\u00ed\u0001\u0000\u0000\u0000\u00ef\u00f0\u0001\u0000\u0000"+ + "\u0000\u00f0\u000f\u0001\u0000\u0000\u0000\u00f1\u00ef\u0001\u0000\u0000"+ + 
"\u0000\u00f2\u00f3\u00030\u0018\u0000\u00f3\u00f4\u0005:\u0000\u0000\u00f4"+ + "\u00f6\u0001\u0000\u0000\u0000\u00f5\u00f2\u0001\u0000\u0000\u0000\u00f5"+ + "\u00f6\u0001\u0000\u0000\u0000\u00f6\u00f7\u0001\u0000\u0000\u0000\u00f7"+ + "\u00f8\u0003\u0084B\u0000\u00f8\u0011\u0001\u0000\u0000\u0000\u00f9\u00fe"+ + "\u0003\u0014\n\u0000\u00fa\u00fb\u0005>\u0000\u0000\u00fb\u00fd\u0003"+ + "\u0014\n\u0000\u00fc\u00fa\u0001\u0000\u0000\u0000\u00fd\u0100\u0001\u0000"+ + "\u0000\u0000\u00fe\u00fc\u0001\u0000\u0000\u0000\u00fe\u00ff\u0001\u0000"+ + "\u0000\u0000\u00ff\u0013\u0001\u0000\u0000\u0000\u0100\u00fe\u0001\u0000"+ + "\u0000\u0000\u0101\u0104\u00030\u0018\u0000\u0102\u0103\u0005:\u0000\u0000"+ + "\u0103\u0105\u0003\u0084B\u0000\u0104\u0102\u0001\u0000\u0000\u0000\u0104"+ + "\u0105\u0001\u0000\u0000\u0000\u0105\u0015\u0001\u0000\u0000\u0000\u0106"+ + "\u0107\u0005\u0013\u0000\u0000\u0107\u0108\u0003\u001a\r\u0000\u0108\u0017"+ + "\u0001\u0000\u0000\u0000\u0109\u010a\u0005\u0014\u0000\u0000\u010a\u010b"+ + "\u0003\u001a\r\u0000\u010b\u0019\u0001\u0000\u0000\u0000\u010c\u0111\u0003"+ + "\u001c\u000e\u0000\u010d\u010e\u0005>\u0000\u0000\u010e\u0110\u0003\u001c"+ + "\u000e\u0000\u010f\u010d\u0001\u0000\u0000\u0000\u0110\u0113\u0001\u0000"+ + "\u0000\u0000\u0111\u010f\u0001\u0000\u0000\u0000\u0111\u0112\u0001\u0000"+ + "\u0000\u0000\u0112\u0115\u0001\u0000\u0000\u0000\u0113\u0111\u0001\u0000"+ + "\u0000\u0000\u0114\u0116\u0003&\u0013\u0000\u0115\u0114\u0001\u0000\u0000"+ + "\u0000\u0115\u0116\u0001\u0000\u0000\u0000\u0116\u001b\u0001\u0000\u0000"+ + "\u0000\u0117\u0118\u0003\u001e\u000f\u0000\u0118\u0119\u0005=\u0000\u0000"+ + "\u0119\u011a\u0003\"\u0011\u0000\u011a\u0121\u0001\u0000\u0000\u0000\u011b"+ + "\u011c\u0003\"\u0011\u0000\u011c\u011d\u0005<\u0000\u0000\u011d\u011e"+ + "\u0003 \u0010\u0000\u011e\u0121\u0001\u0000\u0000\u0000\u011f\u0121\u0003"+ + "$\u0012\u0000\u0120\u0117\u0001\u0000\u0000\u0000\u0120\u011b\u0001\u0000"+ + 
"\u0000\u0000\u0120\u011f\u0001\u0000\u0000\u0000\u0121\u001d\u0001\u0000"+ + "\u0000\u0000\u0122\u0123\u0005k\u0000\u0000\u0123\u001f\u0001\u0000\u0000"+ + "\u0000\u0124\u0125\u0005k\u0000\u0000\u0125!\u0001\u0000\u0000\u0000\u0126"+ + "\u0127\u0005k\u0000\u0000\u0127#\u0001\u0000\u0000\u0000\u0128\u0129\u0007"+ + "\u0000\u0000\u0000\u0129%\u0001\u0000\u0000\u0000\u012a\u012b\u0005j\u0000"+ + "\u0000\u012b\u0130\u0005k\u0000\u0000\u012c\u012d\u0005>\u0000\u0000\u012d"+ + "\u012f\u0005k\u0000\u0000\u012e\u012c\u0001\u0000\u0000\u0000\u012f\u0132"+ + "\u0001\u0000\u0000\u0000\u0130\u012e\u0001\u0000\u0000\u0000\u0130\u0131"+ + "\u0001\u0000\u0000\u0000\u0131\'\u0001\u0000\u0000\u0000\u0132\u0130\u0001"+ + "\u0000\u0000\u0000\u0133\u0134\u0005\t\u0000\u0000\u0134\u0135\u0003\u000e"+ + "\u0007\u0000\u0135)\u0001\u0000\u0000\u0000\u0136\u0138\u0005\u000f\u0000"+ + "\u0000\u0137\u0139\u0003,\u0016\u0000\u0138\u0137\u0001\u0000\u0000\u0000"+ + "\u0138\u0139\u0001\u0000\u0000\u0000\u0139\u013c\u0001\u0000\u0000\u0000"+ + "\u013a\u013b\u0005;\u0000\u0000\u013b\u013d\u0003\u000e\u0007\u0000\u013c"+ + "\u013a\u0001\u0000\u0000\u0000\u013c\u013d\u0001\u0000\u0000\u0000\u013d"+ + "+\u0001\u0000\u0000\u0000\u013e\u0143\u0003.\u0017\u0000\u013f\u0140\u0005"+ + ">\u0000\u0000\u0140\u0142\u0003.\u0017\u0000\u0141\u013f\u0001\u0000\u0000"+ + "\u0000\u0142\u0145\u0001\u0000\u0000\u0000\u0143\u0141\u0001\u0000\u0000"+ + "\u0000\u0143\u0144\u0001\u0000\u0000\u0000\u0144-\u0001\u0000\u0000\u0000"+ + "\u0145\u0143\u0001\u0000\u0000\u0000\u0146\u0149\u0003\u0010\b\u0000\u0147"+ + "\u0148\u0005\u0010\u0000\u0000\u0148\u014a\u0003\u0084B\u0000\u0149\u0147"+ + "\u0001\u0000\u0000\u0000\u0149\u014a\u0001\u0000\u0000\u0000\u014a/\u0001"+ + "\u0000\u0000\u0000\u014b\u0150\u0003>\u001f\u0000\u014c\u014d\u0005@\u0000"+ + "\u0000\u014d\u014f\u0003>\u001f\u0000\u014e\u014c\u0001\u0000\u0000\u0000"+ + "\u014f\u0152\u0001\u0000\u0000\u0000\u0150\u014e\u0001\u0000\u0000\u0000"+ + 
"\u0150\u0151\u0001\u0000\u0000\u0000\u01511\u0001\u0000\u0000\u0000\u0152"+ + "\u0150\u0001\u0000\u0000\u0000\u0153\u0158\u00038\u001c\u0000\u0154\u0155"+ + "\u0005@\u0000\u0000\u0155\u0157\u00038\u001c\u0000\u0156\u0154\u0001\u0000"+ + "\u0000\u0000\u0157\u015a\u0001\u0000\u0000\u0000\u0158\u0156\u0001\u0000"+ + "\u0000\u0000\u0158\u0159\u0001\u0000\u0000\u0000\u01593\u0001\u0000\u0000"+ + "\u0000\u015a\u0158\u0001\u0000\u0000\u0000\u015b\u0160\u00032\u0019\u0000"+ + "\u015c\u015d\u0005>\u0000\u0000\u015d\u015f\u00032\u0019\u0000\u015e\u015c"+ + "\u0001\u0000\u0000\u0000\u015f\u0162\u0001\u0000\u0000\u0000\u0160\u015e"+ + "\u0001\u0000\u0000\u0000\u0160\u0161\u0001\u0000\u0000\u0000\u01615\u0001"+ + "\u0000\u0000\u0000\u0162\u0160\u0001\u0000\u0000\u0000\u0163\u0164\u0007"+ + "\u0001\u0000\u0000\u01647\u0001\u0000\u0000\u0000\u0165\u0169\u0005\u0080"+ + "\u0000\u0000\u0166\u0169\u0003:\u001d\u0000\u0167\u0169\u0003<\u001e\u0000"+ + "\u0168\u0165\u0001\u0000\u0000\u0000\u0168\u0166\u0001\u0000\u0000\u0000"+ + "\u0168\u0167\u0001\u0000\u0000\u0000\u01699\u0001\u0000\u0000\u0000\u016a"+ + "\u016d\u0005L\u0000\u0000\u016b\u016d\u0005_\u0000\u0000\u016c\u016a\u0001"+ + "\u0000\u0000\u0000\u016c\u016b\u0001\u0000\u0000\u0000\u016d;\u0001\u0000"+ + "\u0000\u0000\u016e\u0171\u0005^\u0000\u0000\u016f\u0171\u0005`\u0000\u0000"+ + "\u0170\u016e\u0001\u0000\u0000\u0000\u0170\u016f\u0001\u0000\u0000\u0000"+ + "\u0171=\u0001\u0000\u0000\u0000\u0172\u0176\u00036\u001b\u0000\u0173\u0176"+ + "\u0003:\u001d\u0000\u0174\u0176\u0003<\u001e\u0000\u0175\u0172\u0001\u0000"+ + "\u0000\u0000\u0175\u0173\u0001\u0000\u0000\u0000\u0175\u0174\u0001\u0000"+ + "\u0000\u0000\u0176?\u0001\u0000\u0000\u0000\u0177\u0178\u0005\u000b\u0000"+ + "\u0000\u0178\u0179\u0003\u0098L\u0000\u0179A\u0001\u0000\u0000\u0000\u017a"+ + "\u017b\u0005\u000e\u0000\u0000\u017b\u0180\u0003D\"\u0000\u017c\u017d"+ + "\u0005>\u0000\u0000\u017d\u017f\u0003D\"\u0000\u017e\u017c\u0001\u0000"+ + 
"\u0000\u0000\u017f\u0182\u0001\u0000\u0000\u0000\u0180\u017e\u0001\u0000"+ + "\u0000\u0000\u0180\u0181\u0001\u0000\u0000\u0000\u0181C\u0001\u0000\u0000"+ + "\u0000\u0182\u0180\u0001\u0000\u0000\u0000\u0183\u0185\u0003\u0084B\u0000"+ + "\u0184\u0186\u0007\u0002\u0000\u0000\u0185\u0184\u0001\u0000\u0000\u0000"+ + "\u0185\u0186\u0001\u0000\u0000\u0000\u0186\u0189\u0001\u0000\u0000\u0000"+ + "\u0187\u0188\u0005I\u0000\u0000\u0188\u018a\u0007\u0003\u0000\u0000\u0189"+ + "\u0187\u0001\u0000\u0000\u0000\u0189\u018a\u0001\u0000\u0000\u0000\u018a"+ + "E\u0001\u0000\u0000\u0000\u018b\u018c\u0005\u001e\u0000\u0000\u018c\u018d"+ + "\u00034\u001a\u0000\u018dG\u0001\u0000\u0000\u0000\u018e\u018f\u0005\u001d"+ + "\u0000\u0000\u018f\u0190\u00034\u001a\u0000\u0190I\u0001\u0000\u0000\u0000"+ + "\u0191\u0192\u0005 \u0000\u0000\u0192\u0197\u0003L&\u0000\u0193\u0194"+ + "\u0005>\u0000\u0000\u0194\u0196\u0003L&\u0000\u0195\u0193\u0001\u0000"+ + "\u0000\u0000\u0196\u0199\u0001\u0000\u0000\u0000\u0197\u0195\u0001\u0000"+ + "\u0000\u0000\u0197\u0198\u0001\u0000\u0000\u0000\u0198K\u0001\u0000\u0000"+ + "\u0000\u0199\u0197\u0001\u0000\u0000\u0000\u019a\u019b\u00032\u0019\u0000"+ + "\u019b\u019c\u0005\u0084\u0000\u0000\u019c\u019d\u00032\u0019\u0000\u019d"+ + "\u01a3\u0001\u0000\u0000\u0000\u019e\u019f\u00032\u0019\u0000\u019f\u01a0"+ + "\u0005:\u0000\u0000\u01a0\u01a1\u00032\u0019\u0000\u01a1\u01a3\u0001\u0000"+ + "\u0000\u0000\u01a2\u019a\u0001\u0000\u0000\u0000\u01a2\u019e\u0001\u0000"+ + "\u0000\u0000\u01a3M\u0001\u0000\u0000\u0000\u01a4\u01a5\u0005\b\u0000"+ + "\u0000\u01a5\u01a6\u0003\u008eG\u0000\u01a6\u01a8\u0003\u00a2Q\u0000\u01a7"+ + "\u01a9\u0003T*\u0000\u01a8\u01a7\u0001\u0000\u0000\u0000\u01a8\u01a9\u0001"+ + "\u0000\u0000\u0000\u01a9O\u0001\u0000\u0000\u0000\u01aa\u01ab\u0005\n"+ + "\u0000\u0000\u01ab\u01ac\u0003\u008eG\u0000\u01ac\u01ad\u0003\u00a2Q\u0000"+ + "\u01adQ\u0001\u0000\u0000\u0000\u01ae\u01af\u0005\u001c\u0000\u0000\u01af"+ + 
"\u01b0\u00030\u0018\u0000\u01b0S\u0001\u0000\u0000\u0000\u01b1\u01b6\u0003"+ + "V+\u0000\u01b2\u01b3\u0005>\u0000\u0000\u01b3\u01b5\u0003V+\u0000\u01b4"+ "\u01b2\u0001\u0000\u0000\u0000\u01b5\u01b8\u0001\u0000\u0000\u0000\u01b6"+ "\u01b4\u0001\u0000\u0000\u0000\u01b6\u01b7\u0001\u0000\u0000\u0000\u01b7"+ "U\u0001\u0000\u0000\u0000\u01b8\u01b6\u0001\u0000\u0000\u0000\u01b9\u01ba"+ @@ -7736,11 +7814,11 @@ private boolean primaryExpression_sempred(PrimaryExpressionContext _localctx, in "\u0000\u0000\u0265\u0263\u0001\u0000\u0000\u0000\u0266\u0268\u0003\u008a"+ "E\u0000\u0267\u0269\u0005G\u0000\u0000\u0268\u0267\u0001\u0000\u0000\u0000"+ "\u0268\u0269\u0001\u0000\u0000\u0000\u0269\u026a\u0001\u0000\u0000\u0000"+ - "\u026a\u026b\u0005F\u0000\u0000\u026b\u026c\u0003\u00a2Q\u0000\u026c\u0285"+ + "\u026a\u026b\u0005F\u0000\u0000\u026b\u026c\u0003\u00a2Q\u0000\u026c\u0295"+ "\u0001\u0000\u0000\u0000\u026d\u026f\u0003\u008aE\u0000\u026e\u0270\u0005"+ "G\u0000\u0000\u026f\u026e\u0001\u0000\u0000\u0000\u026f\u0270\u0001\u0000"+ "\u0000\u0000\u0270\u0271\u0001\u0000\u0000\u0000\u0271\u0272\u0005M\u0000"+ - "\u0000\u0272\u0273\u0003\u00a2Q\u0000\u0273\u0285\u0001\u0000\u0000\u0000"+ + "\u0000\u0272\u0273\u0003\u00a2Q\u0000\u0273\u0295\u0001\u0000\u0000\u0000"+ "\u0274\u0276\u0003\u008aE\u0000\u0275\u0277\u0005G\u0000\u0000\u0276\u0275"+ "\u0001\u0000\u0000\u0000\u0276\u0277\u0001\u0000\u0000\u0000\u0277\u0278"+ "\u0001\u0000\u0000\u0000\u0278\u0279\u0005F\u0000\u0000\u0279\u027a\u0005"+ @@ -7749,108 +7827,117 @@ private boolean primaryExpression_sempred(PrimaryExpressionContext _localctx, in "\u027e\u0281\u0001\u0000\u0000\u0000\u027f\u027d\u0001\u0000\u0000\u0000"+ "\u027f\u0280\u0001\u0000\u0000\u0000\u0280\u0282\u0001\u0000\u0000\u0000"+ "\u0281\u027f\u0001\u0000\u0000\u0000\u0282\u0283\u0005d\u0000\u0000\u0283"+ - "\u0285\u0001\u0000\u0000\u0000\u0284\u0266\u0001\u0000\u0000\u0000\u0284"+ - "\u026d\u0001\u0000\u0000\u0000\u0284\u0274\u0001\u0000\u0000\u0000\u0285"+ - 
"\u0087\u0001\u0000\u0000\u0000\u0286\u0289\u00030\u0018\u0000\u0287\u0288"+ - "\u0005<\u0000\u0000\u0288\u028a\u0003\n\u0005\u0000\u0289\u0287\u0001"+ - "\u0000\u0000\u0000\u0289\u028a\u0001\u0000\u0000\u0000\u028a\u028b\u0001"+ - "\u0000\u0000\u0000\u028b\u028c\u0005=\u0000\u0000\u028c\u028d\u0003\u0098"+ - "L\u0000\u028d\u0089\u0001\u0000\u0000\u0000\u028e\u0294\u0003\u008cF\u0000"+ - "\u028f\u0290\u0003\u008cF\u0000\u0290\u0291\u0003\u00a4R\u0000\u0291\u0292"+ - "\u0003\u008cF\u0000\u0292\u0294\u0001\u0000\u0000\u0000\u0293\u028e\u0001"+ - "\u0000\u0000\u0000\u0293\u028f\u0001\u0000\u0000\u0000\u0294\u008b\u0001"+ - "\u0000\u0000\u0000\u0295\u0296\u0006F\uffff\uffff\u0000\u0296\u029a\u0003"+ - "\u008eG\u0000\u0297\u0298\u0007\u0005\u0000\u0000\u0298\u029a\u0003\u008c"+ - "F\u0003\u0299\u0295\u0001\u0000\u0000\u0000\u0299\u0297\u0001\u0000\u0000"+ - "\u0000\u029a\u02a3\u0001\u0000\u0000\u0000\u029b\u029c\n\u0002\u0000\u0000"+ - "\u029c\u029d\u0007\u0006\u0000\u0000\u029d\u02a2\u0003\u008cF\u0003\u029e"+ - "\u029f\n\u0001\u0000\u0000\u029f\u02a0\u0007\u0005\u0000\u0000\u02a0\u02a2"+ - "\u0003\u008cF\u0002\u02a1\u029b\u0001\u0000\u0000\u0000\u02a1\u029e\u0001"+ - "\u0000\u0000\u0000\u02a2\u02a5\u0001\u0000\u0000\u0000\u02a3\u02a1\u0001"+ - "\u0000\u0000\u0000\u02a3\u02a4\u0001\u0000\u0000\u0000\u02a4\u008d\u0001"+ - "\u0000\u0000\u0000\u02a5\u02a3\u0001\u0000\u0000\u0000\u02a6\u02a7\u0006"+ - "G\uffff\uffff\u0000\u02a7\u02af\u0003\u0098L\u0000\u02a8\u02af\u00030"+ - "\u0018\u0000\u02a9\u02af\u0003\u0090H\u0000\u02aa\u02ab\u0005c\u0000\u0000"+ - "\u02ab\u02ac\u0003\u0084B\u0000\u02ac\u02ad\u0005d\u0000\u0000\u02ad\u02af"+ - "\u0001\u0000\u0000\u0000\u02ae\u02a6\u0001\u0000\u0000\u0000\u02ae\u02a8"+ - "\u0001\u0000\u0000\u0000\u02ae\u02a9\u0001\u0000\u0000\u0000\u02ae\u02aa"+ - "\u0001\u0000\u0000\u0000\u02af\u02b5\u0001\u0000\u0000\u0000\u02b0\u02b1"+ - "\n\u0001\u0000\u0000\u02b1\u02b2\u0005<\u0000\u0000\u02b2\u02b4\u0003"+ - 
"\n\u0005\u0000\u02b3\u02b0\u0001\u0000\u0000\u0000\u02b4\u02b7\u0001\u0000"+ - "\u0000\u0000\u02b5\u02b3\u0001\u0000\u0000\u0000\u02b5\u02b6\u0001\u0000"+ - "\u0000\u0000\u02b6\u008f\u0001\u0000\u0000\u0000\u02b7\u02b5\u0001\u0000"+ - "\u0000\u0000\u02b8\u02b9\u0003\u0092I\u0000\u02b9\u02c7\u0005c\u0000\u0000"+ - "\u02ba\u02c8\u0005Y\u0000\u0000\u02bb\u02c0\u0003\u0084B\u0000\u02bc\u02bd"+ - "\u0005>\u0000\u0000\u02bd\u02bf\u0003\u0084B\u0000\u02be\u02bc\u0001\u0000"+ - "\u0000\u0000\u02bf\u02c2\u0001\u0000\u0000\u0000\u02c0\u02be\u0001\u0000"+ - "\u0000\u0000\u02c0\u02c1\u0001\u0000\u0000\u0000\u02c1\u02c5\u0001\u0000"+ - "\u0000\u0000\u02c2\u02c0\u0001\u0000\u0000\u0000\u02c3\u02c4\u0005>\u0000"+ - "\u0000\u02c4\u02c6\u0003\u0094J\u0000\u02c5\u02c3\u0001\u0000\u0000\u0000"+ - "\u02c5\u02c6\u0001\u0000\u0000\u0000\u02c6\u02c8\u0001\u0000\u0000\u0000"+ - "\u02c7\u02ba\u0001\u0000\u0000\u0000\u02c7\u02bb\u0001\u0000\u0000\u0000"+ - "\u02c7\u02c8\u0001\u0000\u0000\u0000\u02c8\u02c9\u0001\u0000\u0000\u0000"+ - "\u02c9\u02ca\u0005d\u0000\u0000\u02ca\u0091\u0001\u0000\u0000\u0000\u02cb"+ - "\u02cc\u0003>\u001f\u0000\u02cc\u0093\u0001\u0000\u0000\u0000\u02cd\u02ce"+ - "\u0005\\\u0000\u0000\u02ce\u02d3\u0003\u0096K\u0000\u02cf\u02d0\u0005"+ - ">\u0000\u0000\u02d0\u02d2\u0003\u0096K\u0000\u02d1\u02cf\u0001\u0000\u0000"+ - "\u0000\u02d2\u02d5\u0001\u0000\u0000\u0000\u02d3\u02d1\u0001\u0000\u0000"+ - "\u0000\u02d3\u02d4\u0001\u0000\u0000\u0000\u02d4\u02d6\u0001\u0000\u0000"+ - "\u0000\u02d5\u02d3\u0001\u0000\u0000\u0000\u02d6\u02d7\u0005]\u0000\u0000"+ - "\u02d7\u0095\u0001\u0000\u0000\u0000\u02d8\u02d9\u0003\u00a2Q\u0000\u02d9"+ - "\u02da\u0005=\u0000\u0000\u02da\u02db\u0003\u0098L\u0000\u02db\u0097\u0001"+ - "\u0000\u0000\u0000\u02dc\u0307\u0005H\u0000\u0000\u02dd\u02de\u0003\u00a0"+ - "P\u0000\u02de\u02df\u0005e\u0000\u0000\u02df\u0307\u0001\u0000\u0000\u0000"+ - "\u02e0\u0307\u0003\u009eO\u0000\u02e1\u0307\u0003\u00a0P\u0000\u02e2\u0307"+ - 
"\u0003\u009aM\u0000\u02e3\u0307\u0003:\u001d\u0000\u02e4\u0307\u0003\u00a2"+ - "Q\u0000\u02e5\u02e6\u0005a\u0000\u0000\u02e6\u02eb\u0003\u009cN\u0000"+ - "\u02e7\u02e8\u0005>\u0000\u0000\u02e8\u02ea\u0003\u009cN\u0000\u02e9\u02e7"+ - "\u0001\u0000\u0000\u0000\u02ea\u02ed\u0001\u0000\u0000\u0000\u02eb\u02e9"+ - "\u0001\u0000\u0000\u0000\u02eb\u02ec\u0001\u0000\u0000\u0000\u02ec\u02ee"+ - "\u0001\u0000\u0000\u0000\u02ed\u02eb\u0001\u0000\u0000\u0000\u02ee\u02ef"+ - "\u0005b\u0000\u0000\u02ef\u0307\u0001\u0000\u0000\u0000\u02f0\u02f1\u0005"+ - "a\u0000\u0000\u02f1\u02f6\u0003\u009aM\u0000\u02f2\u02f3\u0005>\u0000"+ - "\u0000\u02f3\u02f5\u0003\u009aM\u0000\u02f4\u02f2\u0001\u0000\u0000\u0000"+ - "\u02f5\u02f8\u0001\u0000\u0000\u0000\u02f6\u02f4\u0001\u0000\u0000\u0000"+ - "\u02f6\u02f7\u0001\u0000\u0000\u0000\u02f7\u02f9\u0001\u0000\u0000\u0000"+ - "\u02f8\u02f6\u0001\u0000\u0000\u0000\u02f9\u02fa\u0005b\u0000\u0000\u02fa"+ - "\u0307\u0001\u0000\u0000\u0000\u02fb\u02fc\u0005a\u0000\u0000\u02fc\u0301"+ - "\u0003\u00a2Q\u0000\u02fd\u02fe\u0005>\u0000\u0000\u02fe\u0300\u0003\u00a2"+ - "Q\u0000\u02ff\u02fd\u0001\u0000\u0000\u0000\u0300\u0303\u0001\u0000\u0000"+ - "\u0000\u0301\u02ff\u0001\u0000\u0000\u0000\u0301\u0302\u0001\u0000\u0000"+ - "\u0000\u0302\u0304\u0001\u0000\u0000\u0000\u0303\u0301\u0001\u0000\u0000"+ - "\u0000\u0304\u0305\u0005b\u0000\u0000\u0305\u0307\u0001\u0000\u0000\u0000"+ - "\u0306\u02dc\u0001\u0000\u0000\u0000\u0306\u02dd\u0001\u0000\u0000\u0000"+ - "\u0306\u02e0\u0001\u0000\u0000\u0000\u0306\u02e1\u0001\u0000\u0000\u0000"+ - "\u0306\u02e2\u0001\u0000\u0000\u0000\u0306\u02e3\u0001\u0000\u0000\u0000"+ - "\u0306\u02e4\u0001\u0000\u0000\u0000\u0306\u02e5\u0001\u0000\u0000\u0000"+ - "\u0306\u02f0\u0001\u0000\u0000\u0000\u0306\u02fb\u0001\u0000\u0000\u0000"+ - "\u0307\u0099\u0001\u0000\u0000\u0000\u0308\u0309\u0007\u0007\u0000\u0000"+ - "\u0309\u009b\u0001\u0000\u0000\u0000\u030a\u030d\u0003\u009eO\u0000\u030b"+ - 
"\u030d\u0003\u00a0P\u0000\u030c\u030a\u0001\u0000\u0000\u0000\u030c\u030b"+ - "\u0001\u0000\u0000\u0000\u030d\u009d\u0001\u0000\u0000\u0000\u030e\u0310"+ - "\u0007\u0005\u0000\u0000\u030f\u030e\u0001\u0000\u0000\u0000\u030f\u0310"+ - "\u0001\u0000\u0000\u0000\u0310\u0311\u0001\u0000\u0000\u0000\u0311\u0312"+ - "\u00057\u0000\u0000\u0312\u009f\u0001\u0000\u0000\u0000\u0313\u0315\u0007"+ - "\u0005\u0000\u0000\u0314\u0313\u0001\u0000\u0000\u0000\u0314\u0315\u0001"+ - "\u0000\u0000\u0000\u0315\u0316\u0001\u0000\u0000\u0000\u0316\u0317\u0005"+ - "6\u0000\u0000\u0317\u00a1\u0001\u0000\u0000\u0000\u0318\u0319\u00055\u0000"+ - "\u0000\u0319\u00a3\u0001\u0000\u0000\u0000\u031a\u031b\u0007\b\u0000\u0000"+ - "\u031b\u00a5\u0001\u0000\u0000\u0000\u031c\u031d\u0007\t\u0000\u0000\u031d"+ - "\u031e\u0005r\u0000\u0000\u031e\u031f\u0003\u00a8T\u0000\u031f\u0320\u0003"+ - "\u00aaU\u0000\u0320\u00a7\u0001\u0000\u0000\u0000\u0321\u0322\u0003\u001c"+ - "\u000e\u0000\u0322\u00a9\u0001\u0000\u0000\u0000\u0323\u0324\u0005J\u0000"+ - "\u0000\u0324\u0329\u0003\u00acV\u0000\u0325\u0326\u0005>\u0000\u0000\u0326"+ - "\u0328\u0003\u00acV\u0000\u0327\u0325\u0001\u0000\u0000\u0000\u0328\u032b"+ - "\u0001\u0000\u0000\u0000\u0329\u0327\u0001\u0000\u0000\u0000\u0329\u032a"+ - "\u0001\u0000\u0000\u0000\u032a\u00ab\u0001\u0000\u0000\u0000\u032b\u0329"+ - "\u0001\u0000\u0000\u0000\u032c\u032d\u0003\u008aE\u0000\u032d\u00ad\u0001"+ - "\u0000\u0000\u0000H\u00b9\u00c3\u00e0\u00ef\u00f5\u00fe\u0104\u0111\u0115"+ - "\u0120\u0130\u0138\u013c\u0143\u0149\u0150\u0158\u0160\u0168\u016c\u0170"+ - "\u0175\u0180\u0185\u0189\u0197\u01a2\u01a8\u01b6\u01cb\u01d3\u01d6\u01dd"+ - "\u01e8\u01ef\u01f7\u0205\u020e\u021d\u0229\u0232\u023a\u0243\u024c\u0254"+ - "\u0259\u0261\u0263\u0268\u026f\u0276\u027f\u0284\u0289\u0293\u0299\u02a1"+ - "\u02a3\u02ae\u02b5\u02c0\u02c5\u02c7\u02d3\u02eb\u02f6\u0301\u0306\u030c"+ - "\u030f\u0314\u0329"; + "\u0295\u0001\u0000\u0000\u0000\u0284\u0286\u0003\u008aE\u0000\u0285\u0287"+ + 
"\u0005G\u0000\u0000\u0286\u0285\u0001\u0000\u0000\u0000\u0286\u0287\u0001"+ + "\u0000\u0000\u0000\u0287\u0288\u0001\u0000\u0000\u0000\u0288\u0289\u0005"+ + "M\u0000\u0000\u0289\u028a\u0005c\u0000\u0000\u028a\u028f\u0003\u00a2Q"+ + "\u0000\u028b\u028c\u0005>\u0000\u0000\u028c\u028e\u0003\u00a2Q\u0000\u028d"+ + "\u028b\u0001\u0000\u0000\u0000\u028e\u0291\u0001\u0000\u0000\u0000\u028f"+ + "\u028d\u0001\u0000\u0000\u0000\u028f\u0290\u0001\u0000\u0000\u0000\u0290"+ + "\u0292\u0001\u0000\u0000\u0000\u0291\u028f\u0001\u0000\u0000\u0000\u0292"+ + "\u0293\u0005d\u0000\u0000\u0293\u0295\u0001\u0000\u0000\u0000\u0294\u0266"+ + "\u0001\u0000\u0000\u0000\u0294\u026d\u0001\u0000\u0000\u0000\u0294\u0274"+ + "\u0001\u0000\u0000\u0000\u0294\u0284\u0001\u0000\u0000\u0000\u0295\u0087"+ + "\u0001\u0000\u0000\u0000\u0296\u0299\u00030\u0018\u0000\u0297\u0298\u0005"+ + "<\u0000\u0000\u0298\u029a\u0003\n\u0005\u0000\u0299\u0297\u0001\u0000"+ + "\u0000\u0000\u0299\u029a\u0001\u0000\u0000\u0000\u029a\u029b\u0001\u0000"+ + "\u0000\u0000\u029b\u029c\u0005=\u0000\u0000\u029c\u029d\u0003\u0098L\u0000"+ + "\u029d\u0089\u0001\u0000\u0000\u0000\u029e\u02a4\u0003\u008cF\u0000\u029f"+ + "\u02a0\u0003\u008cF\u0000\u02a0\u02a1\u0003\u00a4R\u0000\u02a1\u02a2\u0003"+ + "\u008cF\u0000\u02a2\u02a4\u0001\u0000\u0000\u0000\u02a3\u029e\u0001\u0000"+ + "\u0000\u0000\u02a3\u029f\u0001\u0000\u0000\u0000\u02a4\u008b\u0001\u0000"+ + "\u0000\u0000\u02a5\u02a6\u0006F\uffff\uffff\u0000\u02a6\u02aa\u0003\u008e"+ + "G\u0000\u02a7\u02a8\u0007\u0005\u0000\u0000\u02a8\u02aa\u0003\u008cF\u0003"+ + "\u02a9\u02a5\u0001\u0000\u0000\u0000\u02a9\u02a7\u0001\u0000\u0000\u0000"+ + "\u02aa\u02b3\u0001\u0000\u0000\u0000\u02ab\u02ac\n\u0002\u0000\u0000\u02ac"+ + "\u02ad\u0007\u0006\u0000\u0000\u02ad\u02b2\u0003\u008cF\u0003\u02ae\u02af"+ + "\n\u0001\u0000\u0000\u02af\u02b0\u0007\u0005\u0000\u0000\u02b0\u02b2\u0003"+ + "\u008cF\u0002\u02b1\u02ab\u0001\u0000\u0000\u0000\u02b1\u02ae\u0001\u0000"+ + 
"\u0000\u0000\u02b2\u02b5\u0001\u0000\u0000\u0000\u02b3\u02b1\u0001\u0000"+ + "\u0000\u0000\u02b3\u02b4\u0001\u0000\u0000\u0000\u02b4\u008d\u0001\u0000"+ + "\u0000\u0000\u02b5\u02b3\u0001\u0000\u0000\u0000\u02b6\u02b7\u0006G\uffff"+ + "\uffff\u0000\u02b7\u02bf\u0003\u0098L\u0000\u02b8\u02bf\u00030\u0018\u0000"+ + "\u02b9\u02bf\u0003\u0090H\u0000\u02ba\u02bb\u0005c\u0000\u0000\u02bb\u02bc"+ + "\u0003\u0084B\u0000\u02bc\u02bd\u0005d\u0000\u0000\u02bd\u02bf\u0001\u0000"+ + "\u0000\u0000\u02be\u02b6\u0001\u0000\u0000\u0000\u02be\u02b8\u0001\u0000"+ + "\u0000\u0000\u02be\u02b9\u0001\u0000\u0000\u0000\u02be\u02ba\u0001\u0000"+ + "\u0000\u0000\u02bf\u02c5\u0001\u0000\u0000\u0000\u02c0\u02c1\n\u0001\u0000"+ + "\u0000\u02c1\u02c2\u0005<\u0000\u0000\u02c2\u02c4\u0003\n\u0005\u0000"+ + "\u02c3\u02c0\u0001\u0000\u0000\u0000\u02c4\u02c7\u0001\u0000\u0000\u0000"+ + "\u02c5\u02c3\u0001\u0000\u0000\u0000\u02c5\u02c6\u0001\u0000\u0000\u0000"+ + "\u02c6\u008f\u0001\u0000\u0000\u0000\u02c7\u02c5\u0001\u0000\u0000\u0000"+ + "\u02c8\u02c9\u0003\u0092I\u0000\u02c9\u02d7\u0005c\u0000\u0000\u02ca\u02d8"+ + "\u0005Y\u0000\u0000\u02cb\u02d0\u0003\u0084B\u0000\u02cc\u02cd\u0005>"+ + "\u0000\u0000\u02cd\u02cf\u0003\u0084B\u0000\u02ce\u02cc\u0001\u0000\u0000"+ + "\u0000\u02cf\u02d2\u0001\u0000\u0000\u0000\u02d0\u02ce\u0001\u0000\u0000"+ + "\u0000\u02d0\u02d1\u0001\u0000\u0000\u0000\u02d1\u02d5\u0001\u0000\u0000"+ + "\u0000\u02d2\u02d0\u0001\u0000\u0000\u0000\u02d3\u02d4\u0005>\u0000\u0000"+ + "\u02d4\u02d6\u0003\u0094J\u0000\u02d5\u02d3\u0001\u0000\u0000\u0000\u02d5"+ + "\u02d6\u0001\u0000\u0000\u0000\u02d6\u02d8\u0001\u0000\u0000\u0000\u02d7"+ + "\u02ca\u0001\u0000\u0000\u0000\u02d7\u02cb\u0001\u0000\u0000\u0000\u02d7"+ + "\u02d8\u0001\u0000\u0000\u0000\u02d8\u02d9\u0001\u0000\u0000\u0000\u02d9"+ + "\u02da\u0005d\u0000\u0000\u02da\u0091\u0001\u0000\u0000\u0000\u02db\u02dc"+ + "\u0003>\u001f\u0000\u02dc\u0093\u0001\u0000\u0000\u0000\u02dd\u02de\u0005"+ + 
"\\\u0000\u0000\u02de\u02e3\u0003\u0096K\u0000\u02df\u02e0\u0005>\u0000"+ + "\u0000\u02e0\u02e2\u0003\u0096K\u0000\u02e1\u02df\u0001\u0000\u0000\u0000"+ + "\u02e2\u02e5\u0001\u0000\u0000\u0000\u02e3\u02e1\u0001\u0000\u0000\u0000"+ + "\u02e3\u02e4\u0001\u0000\u0000\u0000\u02e4\u02e6\u0001\u0000\u0000\u0000"+ + "\u02e5\u02e3\u0001\u0000\u0000\u0000\u02e6\u02e7\u0005]\u0000\u0000\u02e7"+ + "\u0095\u0001\u0000\u0000\u0000\u02e8\u02e9\u0003\u00a2Q\u0000\u02e9\u02ea"+ + "\u0005=\u0000\u0000\u02ea\u02eb\u0003\u0098L\u0000\u02eb\u0097\u0001\u0000"+ + "\u0000\u0000\u02ec\u0317\u0005H\u0000\u0000\u02ed\u02ee\u0003\u00a0P\u0000"+ + "\u02ee\u02ef\u0005e\u0000\u0000\u02ef\u0317\u0001\u0000\u0000\u0000\u02f0"+ + "\u0317\u0003\u009eO\u0000\u02f1\u0317\u0003\u00a0P\u0000\u02f2\u0317\u0003"+ + "\u009aM\u0000\u02f3\u0317\u0003:\u001d\u0000\u02f4\u0317\u0003\u00a2Q"+ + "\u0000\u02f5\u02f6\u0005a\u0000\u0000\u02f6\u02fb\u0003\u009cN\u0000\u02f7"+ + "\u02f8\u0005>\u0000\u0000\u02f8\u02fa\u0003\u009cN\u0000\u02f9\u02f7\u0001"+ + "\u0000\u0000\u0000\u02fa\u02fd\u0001\u0000\u0000\u0000\u02fb\u02f9\u0001"+ + "\u0000\u0000\u0000\u02fb\u02fc\u0001\u0000\u0000\u0000\u02fc\u02fe\u0001"+ + "\u0000\u0000\u0000\u02fd\u02fb\u0001\u0000\u0000\u0000\u02fe\u02ff\u0005"+ + "b\u0000\u0000\u02ff\u0317\u0001\u0000\u0000\u0000\u0300\u0301\u0005a\u0000"+ + "\u0000\u0301\u0306\u0003\u009aM\u0000\u0302\u0303\u0005>\u0000\u0000\u0303"+ + "\u0305\u0003\u009aM\u0000\u0304\u0302\u0001\u0000\u0000\u0000\u0305\u0308"+ + "\u0001\u0000\u0000\u0000\u0306\u0304\u0001\u0000\u0000\u0000\u0306\u0307"+ + "\u0001\u0000\u0000\u0000\u0307\u0309\u0001\u0000\u0000\u0000\u0308\u0306"+ + "\u0001\u0000\u0000\u0000\u0309\u030a\u0005b\u0000\u0000\u030a\u0317\u0001"+ + "\u0000\u0000\u0000\u030b\u030c\u0005a\u0000\u0000\u030c\u0311\u0003\u00a2"+ + "Q\u0000\u030d\u030e\u0005>\u0000\u0000\u030e\u0310\u0003\u00a2Q\u0000"+ + "\u030f\u030d\u0001\u0000\u0000\u0000\u0310\u0313\u0001\u0000\u0000\u0000"+ + 
"\u0311\u030f\u0001\u0000\u0000\u0000\u0311\u0312\u0001\u0000\u0000\u0000"+ + "\u0312\u0314\u0001\u0000\u0000\u0000\u0313\u0311\u0001\u0000\u0000\u0000"+ + "\u0314\u0315\u0005b\u0000\u0000\u0315\u0317\u0001\u0000\u0000\u0000\u0316"+ + "\u02ec\u0001\u0000\u0000\u0000\u0316\u02ed\u0001\u0000\u0000\u0000\u0316"+ + "\u02f0\u0001\u0000\u0000\u0000\u0316\u02f1\u0001\u0000\u0000\u0000\u0316"+ + "\u02f2\u0001\u0000\u0000\u0000\u0316\u02f3\u0001\u0000\u0000\u0000\u0316"+ + "\u02f4\u0001\u0000\u0000\u0000\u0316\u02f5\u0001\u0000\u0000\u0000\u0316"+ + "\u0300\u0001\u0000\u0000\u0000\u0316\u030b\u0001\u0000\u0000\u0000\u0317"+ + "\u0099\u0001\u0000\u0000\u0000\u0318\u0319\u0007\u0007\u0000\u0000\u0319"+ + "\u009b\u0001\u0000\u0000\u0000\u031a\u031d\u0003\u009eO\u0000\u031b\u031d"+ + "\u0003\u00a0P\u0000\u031c\u031a\u0001\u0000\u0000\u0000\u031c\u031b\u0001"+ + "\u0000\u0000\u0000\u031d\u009d\u0001\u0000\u0000\u0000\u031e\u0320\u0007"+ + "\u0005\u0000\u0000\u031f\u031e\u0001\u0000\u0000\u0000\u031f\u0320\u0001"+ + "\u0000\u0000\u0000\u0320\u0321\u0001\u0000\u0000\u0000\u0321\u0322\u0005"+ + "7\u0000\u0000\u0322\u009f\u0001\u0000\u0000\u0000\u0323\u0325\u0007\u0005"+ + "\u0000\u0000\u0324\u0323\u0001\u0000\u0000\u0000\u0324\u0325\u0001\u0000"+ + "\u0000\u0000\u0325\u0326\u0001\u0000\u0000\u0000\u0326\u0327\u00056\u0000"+ + "\u0000\u0327\u00a1\u0001\u0000\u0000\u0000\u0328\u0329\u00055\u0000\u0000"+ + "\u0329\u00a3\u0001\u0000\u0000\u0000\u032a\u032b\u0007\b\u0000\u0000\u032b"+ + "\u00a5\u0001\u0000\u0000\u0000\u032c\u032d\u0007\t\u0000\u0000\u032d\u032e"+ + "\u0005r\u0000\u0000\u032e\u032f\u0003\u00a8T\u0000\u032f\u0330\u0003\u00aa"+ + "U\u0000\u0330\u00a7\u0001\u0000\u0000\u0000\u0331\u0332\u0003\u001c\u000e"+ + "\u0000\u0332\u00a9\u0001\u0000\u0000\u0000\u0333\u0334\u0005J\u0000\u0000"+ + "\u0334\u0339\u0003\u00acV\u0000\u0335\u0336\u0005>\u0000\u0000\u0336\u0338"+ + "\u0003\u00acV\u0000\u0337\u0335\u0001\u0000\u0000\u0000\u0338\u033b\u0001"+ + 
"\u0000\u0000\u0000\u0339\u0337\u0001\u0000\u0000\u0000\u0339\u033a\u0001"+ + "\u0000\u0000\u0000\u033a\u00ab\u0001\u0000\u0000\u0000\u033b\u0339\u0001"+ + "\u0000\u0000\u0000\u033c\u033d\u0003\u008aE\u0000\u033d\u00ad\u0001\u0000"+ + "\u0000\u0000J\u00b9\u00c3\u00e0\u00ef\u00f5\u00fe\u0104\u0111\u0115\u0120"+ + "\u0130\u0138\u013c\u0143\u0149\u0150\u0158\u0160\u0168\u016c\u0170\u0175"+ + "\u0180\u0185\u0189\u0197\u01a2\u01a8\u01b6\u01cb\u01d3\u01d6\u01dd\u01e8"+ + "\u01ef\u01f7\u0205\u020e\u021d\u0229\u0232\u023a\u0243\u024c\u0254\u0259"+ + "\u0261\u0263\u0268\u026f\u0276\u027f\u0286\u028f\u0294\u0299\u02a3\u02a9"+ + "\u02b1\u02b3\u02be\u02c5\u02d0\u02d5\u02d7\u02e3\u02fb\u0306\u0311\u0316"+ + "\u031c\u031f\u0324\u0339"; public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); static { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java index b9da6c86db845..b6d5af9d90870 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java @@ -980,6 +980,18 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener { *

    The default implementation does nothing.

    */ @Override public void exitLikeListExpression(EsqlBaseParser.LikeListExpressionContext ctx) { } + /** + * {@inheritDoc} + * + *

    The default implementation does nothing.

    + */ + @Override public void enterRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx) { } + /** + * {@inheritDoc} + * + *

    The default implementation does nothing.

    + */ + @Override public void exitRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx) { } /** * {@inheritDoc} * diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java index abae6c0ab1e46..ab8ed34810ddf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java @@ -580,6 +580,13 @@ public class EsqlBaseParserBaseVisitor extends AbstractParseTreeVisitor im * {@link #visitChildren} on {@code ctx}.

    */ @Override public T visitLikeListExpression(EsqlBaseParser.LikeListExpressionContext ctx) { return visitChildren(ctx); } + /** + * {@inheritDoc} + * + *

    The default implementation returns the result of calling + * {@link #visitChildren} on {@code ctx}.

    + */ + @Override public T visitRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx) { return visitChildren(ctx); } /** * {@inheritDoc} * diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java index b91810506d10a..0ab831b10b68e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java @@ -855,6 +855,18 @@ public interface EsqlBaseParserListener extends ParseTreeListener { * @param ctx the parse tree */ void exitLikeListExpression(EsqlBaseParser.LikeListExpressionContext ctx); + /** + * Enter a parse tree produced by the {@code rlikeListExpression} + * labeled alternative in {@link EsqlBaseParser#regexBooleanExpression}. + * @param ctx the parse tree + */ + void enterRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx); + /** + * Exit a parse tree produced by the {@code rlikeListExpression} + * labeled alternative in {@link EsqlBaseParser#regexBooleanExpression}. + * @param ctx the parse tree + */ + void exitRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx); /** * Enter a parse tree produced by {@link EsqlBaseParser#matchBooleanExpression}. 
* @param ctx the parse tree diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java index 73a849aac7bf5..b9d15fcd37b76 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java @@ -518,6 +518,13 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor { * @return the visitor result */ T visitLikeListExpression(EsqlBaseParser.LikeListExpressionContext ctx); + /** + * Visit a parse tree produced by the {@code rlikeListExpression} + * labeled alternative in {@link EsqlBaseParser#regexBooleanExpression}. + * @param ctx the parse tree + * @return the visitor result + */ + T visitRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx); /** * Visit a parse tree produced by {@link EsqlBaseParser#matchBooleanExpression}. 
* @param ctx the parse tree diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java index cec23786f84dc..dc60a6dbbfa0a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePatternList; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPatternList; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -44,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; @@ -748,7 +750,7 @@ public Expression visitRlikeExpression(EsqlBaseParser.RlikeExpressionContext ctx RLike rLike = new RLike(source, left, new RLikePattern(BytesRefs.toString(patternLiteral.fold(FoldContext.small())))); return ctx.NOT() == null ? 
rLike : new Not(source, rLike); } catch (InvalidArgumentException e) { - throw new ParsingException(source, "Invalid pattern for LIKE [{}]: [{}]", patternLiteral, e.getMessage()); + throw new ParsingException(source, "Invalid pattern for RLIKE [{}]: [{}]", patternLiteral, e.getMessage()); } } @@ -781,6 +783,21 @@ public Expression visitLikeListExpression(EsqlBaseParser.LikeListExpressionConte return ctx.NOT() == null ? e : new Not(source, e); } + @Override + public Expression visitRlikeListExpression(EsqlBaseParser.RlikeListExpressionContext ctx) { + Source source = source(ctx); + Expression left = expression(ctx.valueExpression()); + List<RLikePattern> rLikePatterns = ctx.string() + .stream() + .map(x -> new RLikePattern(BytesRefs.toString(visitString(x).fold(FoldContext.small())))) + .toList(); + // for now we will use the old RLike function for the one-argument case to allow compatibility in mixed-version deployments + Expression e = rLikePatterns.size() == 1 + ? new RLike(source, left, rLikePatterns.getFirst()) + : new RLikeList(source, left, new RLikePatternList(rLikePatterns)); + return ctx.NOT() == null ?
e : new Not(source, e); + } + @Override public Order visitOrderExpression(EsqlBaseParser.OrderExpressionContext ctx) { return new Order( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 74309fa0bdb85..9e232bd8f02f9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -391,21 +391,22 @@ public PlanFactory visitWhereCommand(EsqlBaseParser.WhereCommandContext ctx) { public PlanFactory visitLimitCommand(EsqlBaseParser.LimitCommandContext ctx) { Source source = source(ctx); Object val = expression(ctx.constant()).fold(FoldContext.small() /* TODO remove me */); - if (val instanceof Integer i) { - if (i < 0) { - throw new ParsingException(source, "Invalid value for LIMIT [" + i + "], expecting a non negative integer"); - } + if (val instanceof Integer i && i >= 0) { return input -> new Limit(source, new Literal(source, i, DataType.INTEGER), input); - } else { - throw new ParsingException( - source, - "Invalid value for LIMIT [" - + BytesRefs.toString(val) - + ": " - + (expression(ctx.constant()).dataType() == KEYWORD ? 
"String" : val.getClass().getSimpleName()) - + "], expecting a non negative integer" - ); } + + String valueType = expression(ctx.constant()).dataType().typeName(); + + throw new ParsingException( + source, + "value of [" + + source.text() + + "] must be a non negative integer, found value [" + + ctx.constant().getText() + + "] type [" + + valueType + + "]" + ); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java index 11e9a57064e5b..7307fd8efad39 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; +import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; import java.io.IOException; import java.util.ArrayList; @@ -295,23 +296,43 @@ public BiConsumer postAnalysisPlanVerification() { * retaining the originating cluster and restructing pages for routing, which might be complicated. */ private static void checkRemoteEnrich(LogicalPlan plan, Failures failures) { - boolean[] agg = { false }; - boolean[] enrichCoord = { false }; + // First look for remote ENRICH, and then look at its children. Going over the whole plan once is trickier as remote ENRICHs can be + // in separate FORK branches which are valid by themselves. + plan.forEachUp(Enrich.class, enrich -> checkForPlansForbiddenBeforeRemoteEnrich(enrich, failures)); + } + + /** + * For a given remote {@link Enrich}, check if there are any forbidden plans upstream. 
+ */ + private static void checkForPlansForbiddenBeforeRemoteEnrich(Enrich enrich, Failures failures) { + if (enrich.mode != Mode.REMOTE) { + return; + } + + // TODO: shouldn't we also include FORK? Everything downstream from FORK should be coordinator-only. + // https://github.com/elastic/elasticsearch/issues/131445 + boolean[] aggregate = { false }; + boolean[] coordinatorOnlyEnrich = { false }; + boolean[] lookupJoin = { false }; - plan.forEachUp(UnaryPlan.class, u -> { + enrich.forEachUp(LogicalPlan.class, u -> { if (u instanceof Aggregate) { - agg[0] = true; - } else if (u instanceof Enrich enrich && enrich.mode() == Enrich.Mode.COORDINATOR) { - enrichCoord[0] = true; - } - if (u instanceof Enrich enrich && enrich.mode() == Enrich.Mode.REMOTE) { - if (agg[0]) { - failures.add(fail(enrich, "ENRICH with remote policy can't be executed after STATS")); - } - if (enrichCoord[0]) { - failures.add(fail(enrich, "ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); - } + aggregate[0] = true; + } else if (u instanceof Enrich upstreamEnrich && upstreamEnrich.mode() == Enrich.Mode.COORDINATOR) { + coordinatorOnlyEnrich[0] = true; + } else if (u instanceof LookupJoin) { + lookupJoin[0] = true; } }); + + if (aggregate[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after STATS")); + } + if (coordinatorOnlyEnrich[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); + } + if (lookupJoin[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after LOOKUP JOIN")); + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java index 6d0991a24a36c..8b1bebad97cef 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java @@ -18,13 +18,10 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Objects; @@ -191,27 +188,7 @@ public List output() { @Override protected AttributeSet computeReferences() { - return mode.isInputPartial() - ? AttributeSet.of(intermediateAttributes) - : Aggregate.computeReferences(aggregates, groupings).subtract(AttributeSet.of(ordinalAttributes())); - } - - /** Returns the attributes that can be loaded from ordinals -- no explicit extraction is needed */ - public List ordinalAttributes() { - List orginalAttributs = new ArrayList<>(groupings.size()); - // Ordinals can be leveraged just for a single grouping. If there are multiple groupings, fields need to be laoded for the - // hash aggregator. - // CATEGORIZE requires the standard hash aggregator as well. - if (groupings().size() == 1 && groupings.get(0).anyMatch(e -> e instanceof Categorize) == false) { - var leaves = new HashSet<>(); - aggregates.stream().filter(a -> groupings.contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves())); - groupings.forEach(g -> { - if (leaves.contains(g) == false) { - orginalAttributs.add((Attribute) g); - } - }); - } - return orginalAttributs; + return mode.isInputPartial() ? 
AttributeSet.of(intermediateAttributes) : Aggregate.computeReferences(aggregates, groupings); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index a78a15bf1ca48..77df90f23b10c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -18,6 +18,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; @@ -28,10 +29,14 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; 
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; @@ -56,6 +61,21 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper this.analysisRegistry = analysisRegistry; } + private static Bucket findBucket(AggregateExec aggregateExec, NameId bucketId) { + Holder<Bucket> foundBucket = new Holder<>(); + aggregateExec.forEachExpressionDown(NamedExpression.class, ne -> { + if (ne.id().equals(bucketId)) { + if (ne.children().size() > 0 && ne.children().get(0) instanceof Bucket bucket) { + foundBucket.set(bucket); + } else if (ne.children().size() > 0 && ne.children().get(0) instanceof Round round) { + // TODO: Why is this hack needed? + foundBucket.set((Bucket) round.field()); + } + } + }); + return foundBucket.get(); + } + @Override public final PhysicalOperation groupingPhysicalOperation( AggregateExec aggregateExec, @@ -99,6 +119,7 @@ public final PhysicalOperation groupingPhysicalOperation( List aggregatorFactories = new ArrayList<>(); List groupSpecs = new ArrayList<>(aggregateExec.groupings().size()); for (Expression group : aggregateExec.groupings()) { + Bucket bucket = findBucket(aggregateExec, ((ReferenceAttribute) group).id()); Attribute groupAttribute = Expressions.attribute(group); // In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different. Attribute sourceGroupAttribute = (aggregatorMode.isInputPartial() == false @@ -143,7 +164,7 @@ else if (aggregatorMode.isOutputPartial()) { } layout.append(groupAttributeLayout); Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id()); - groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group)); + groupSpecs.add(new GroupSpec(groupInput == null ?
null : groupInput.channel(), sourceGroupAttribute, group, bucket)); } if (aggregatorMode == AggregatorMode.FINAL) { @@ -174,16 +195,6 @@ else if (aggregatorMode.isOutputPartial()) { groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(), context ); - // ordinal grouping - } else if (groupSpecs.size() == 1 && groupSpecs.get(0).channel == null) { - operatorFactory = ordinalGroupingOperatorFactory( - source, - aggregateExec, - aggregatorFactories, - groupSpecs.get(0).attribute, - groupSpecs.get(0).elementType(), - context - ); } else { operatorFactory = new HashAggregationOperatorFactory( groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(), @@ -348,13 +359,20 @@ private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFu * @param attribute The attribute, source of this group * @param expression The expression being used to group */ - private record GroupSpec(Integer channel, Attribute attribute, Expression expression) { + private record GroupSpec(Integer channel, Attribute attribute, Expression expression, @Nullable Bucket bucket) { BlockHash.GroupSpec toHashGroupSpec() { if (channel == null) { throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead"); } - return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null); + Expression unwrappedExpression = Alias.unwrap(expression); + if (unwrappedExpression instanceof Categorize categorize) { + return new BlockHash.GroupSpec(channel, elementType(), categorize.categorizeDef()); + } else if (bucket != null && bucket.emitEmptyBuckets() != null) { + return new BlockHash.GroupSpec(channel, elementType(), bucket.createEmptyBucketGenerator()); + } else { + return new BlockHash.GroupSpec(channel, elementType()); + } } ElementType elementType() { @@ -362,18 +380,6 @@ ElementType elementType() { } } - /** - * Build a grouping operator that operates on ordinals if possible. 
- */ - public abstract Operator.OperatorFactory ordinalGroupingOperatorFactory( - PhysicalOperation source, - AggregateExec aggregateExec, - List aggregatorFactories, - Attribute attrSource, - ElementType groupType, - LocalExecutionPlannerContext context - ); - public abstract Operator.OperatorFactory timeSeriesAggregatorOperatorFactory( TimeSeriesAggregateExec ts, AggregatorMode aggregatorMode, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 3f403d3e4fcd2..e0b570267899b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -22,7 +22,6 @@ import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneCountOperator; import org.elasticsearch.compute.lucene.LuceneOperator; import org.elasticsearch.compute.lucene.LuceneSliceQueue; @@ -32,7 +31,6 @@ import org.elasticsearch.compute.lucene.read.TimeSeriesExtractFieldOperator; import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator; import org.elasticsearch.core.AbstractRefCounted; @@ -66,8 +64,7 @@ import org.elasticsearch.xpack.esql.core.type.KeywordEsField; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; import 
org.elasticsearch.xpack.esql.core.type.PotentiallyUnmappedKeywordEsField; -import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; -import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.Sort; import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; @@ -89,7 +86,6 @@ import static org.elasticsearch.common.lucene.search.Queries.newNonNestedFilter; import static org.elasticsearch.compute.lucene.LuceneSourceOperator.NO_LIMIT; -import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.NONE; public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProviders { private static final Logger logger = LogManager.getLogger(EsPhysicalOperationProviders.class); @@ -143,17 +139,17 @@ public boolean hasReferences() { } private final List shardContexts; - private final DataPartitioning defaultDataPartitioning; + private final PhysicalSettings physicalSettings; public EsPhysicalOperationProviders( FoldContext foldContext, List shardContexts, AnalysisRegistry analysisRegistry, - DataPartitioning defaultDataPartitioning + PhysicalSettings physicalSettings ) { super(foldContext, analysisRegistry); this.shardContexts = shardContexts; - this.defaultDataPartitioning = defaultDataPartitioning; + this.physicalSettings = physicalSettings; } @Override @@ -178,7 +174,10 @@ public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fi // TODO: consolidate with ValuesSourceReaderOperator return source.with(new TimeSeriesExtractFieldOperator.Factory(fields, shardContexts), layout.build()); } else { - return source.with(new ValuesSourceReaderOperator.Factory(fields, readers, docChannel), layout.build()); + return source.with( + new 
ValuesSourceReaderOperator.Factory(physicalSettings.valuesLoadingJumboSize(), fields, readers, docChannel), + layout.build() + ); } } @@ -202,7 +201,7 @@ private BlockLoader getBlockLoaderFor(int shardId, Attribute attr, MappedFieldTy Expression conversion = unionTypes.getConversionExpressionForIndex(indexName); return conversion == null ? BlockLoader.CONSTANT_NULLS - : new TypeConvertingBlockLoader(blockLoader, (AbstractConvertFunction) conversion); + : new TypeConvertingBlockLoader(blockLoader, (EsqlScalarFunction) conversion); } return blockLoader; } @@ -281,7 +280,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, luceneFactory = new LuceneTopNSourceOperator.Factory( shardContexts, querySupplier(esQueryExec.query()), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), context.pageSize(rowEstimatedSize), limit, @@ -292,7 +291,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, luceneFactory = new LuceneSourceOperator.Factory( shardContexts, querySupplier(esQueryExec.query()), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), context.pageSize(rowEstimatedSize), limit, @@ -344,45 +343,12 @@ public LuceneCountOperator.Factory countSource(LocalExecutionPlannerContext cont return new LuceneCountOperator.Factory( shardContexts, querySupplier(queryBuilder), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), limit == null ? 
NO_LIMIT : (Integer) limit.fold(context.foldCtx()) ); } - @Override - public final Operator.OperatorFactory ordinalGroupingOperatorFactory( - LocalExecutionPlanner.PhysicalOperation source, - AggregateExec aggregateExec, - List aggregatorFactories, - Attribute attrSource, - ElementType groupElementType, - LocalExecutionPlannerContext context - ) { - var sourceAttribute = FieldExtractExec.extractSourceAttributesFrom(aggregateExec.child()); - int docChannel = source.layout.get(sourceAttribute.id()).channel(); - List vsShardContexts = shardContexts.stream() - .map( - s -> new ValuesSourceReaderOperator.ShardContext( - s.searcher().getIndexReader(), - s::newSourceLoader, - s.storedFieldsSequentialProportion() - ) - ) - .toList(); - // The grouping-by values are ready, let's group on them directly. - // Costin: why are they ready and not already exposed in the layout? - return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory( - shardIdx -> getBlockLoaderFor(shardIdx, attrSource, NONE), - vsShardContexts, - groupElementType, - docChannel, - attrSource.name(), - aggregatorFactories, - context.pageSize(aggregateExec.estimatedRowSize()) - ); - } - @Override public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory( TimeSeriesAggregateExec ts, @@ -542,9 +508,9 @@ private static class TypeConvertingBlockLoader implements BlockLoader { private final BlockLoader delegate; private final TypeConverter typeConverter; - protected TypeConvertingBlockLoader(BlockLoader delegate, AbstractConvertFunction convertFunction) { + protected TypeConvertingBlockLoader(BlockLoader delegate, EsqlScalarFunction convertFunction) { this.delegate = delegate; - this.typeConverter = TypeConverter.fromConvertFunction(convertFunction); + this.typeConverter = TypeConverter.fromScalarFunction(convertFunction); } @Override @@ -566,8 +532,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block 
read(BlockFactory factory, Docs docs) throws IOException { - Block block = reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + Block block = reader.read(factory, docs, offset); return typeConverter.convert((org.elasticsearch.compute.data.Block) block); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index ad6cb42f7f835..28204e2572842 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -528,7 +528,7 @@ private PhysicalOperation planEval(EvalExec eval, LocalExecutionPlannerContext c PhysicalOperation source = plan(eval.child(), context); for (Alias field : eval.fields()) { - var evaluatorSupplier = EvalMapper.toEvaluator(context.foldCtx(), field.child(), source.layout); + var evaluatorSupplier = EvalMapper.toEvaluator(context.foldCtx(), field.child(), source.layout, context.shardContexts); Layout.Builder layout = source.layout.builder(); layout.append(field.toAttribute()); source = source.with(new EvalOperatorFactory(evaluatorSupplier), layout.build()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java new file mode 100644 index 0000000000000..4276eeaf39f9b --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.planner; + +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.MemorySizeValue; +import org.elasticsearch.compute.lucene.DataPartitioning; +import org.elasticsearch.monitor.jvm.JvmInfo; + +/** + * Values for cluster level settings used in physical planning. + */ +public class PhysicalSettings { + public static final Setting<DataPartitioning> DEFAULT_DATA_PARTITIONING = Setting.enumSetting( + DataPartitioning.class, + "esql.default_data_partitioning", + DataPartitioning.AUTO, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting<ByteSizeValue> VALUES_LOADING_JUMBO_SIZE = new Setting<>("esql.values_loading_jumbo_size", settings -> { + long proportional = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() / 1024; + return ByteSizeValue.ofBytes(Math.max(proportional, ByteSizeValue.ofMb(1).getBytes())).getStringRep(); + }, + s -> MemorySizeValue.parseBytesSizeValueOrHeapRatio(s, "esql.values_loading_jumbo_size"), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private volatile DataPartitioning defaultDataPartitioning; + private volatile ByteSizeValue valuesLoadingJumboSize; + + /** + * Ctor for prod that listens for updates from the {@link ClusterService}. + */ + public PhysicalSettings(ClusterService clusterService) { + clusterService.getClusterSettings().initializeAndWatch(DEFAULT_DATA_PARTITIONING, v -> this.defaultDataPartitioning = v); + clusterService.getClusterSettings().initializeAndWatch(VALUES_LOADING_JUMBO_SIZE, v -> this.valuesLoadingJumboSize = v); + } + + /** + * Ctor for testing.
+ */ + public PhysicalSettings(DataPartitioning defaultDataPartitioning, ByteSizeValue valuesLoadingJumboSize) { + this.defaultDataPartitioning = defaultDataPartitioning; + this.valuesLoadingJumboSize = valuesLoadingJumboSize; + } + + public DataPartitioning defaultDataPartitioning() { + return defaultDataPartitioning; + } + + public ByteSizeValue valuesLoadingJumboSize() { + return valuesLoadingJumboSize; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java index 4dea8a50b5c17..31b95bcc5f860 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; -import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; class TypeConverter { private final String evaluatorName; @@ -28,7 +28,7 @@ private TypeConverter(String evaluatorName, ExpressionEvaluator convertEvaluator this.convertEvaluator = convertEvaluator; } - public static TypeConverter fromConvertFunction(AbstractConvertFunction convertFunction) { + public static TypeConverter fromScalarFunction(EsqlScalarFunction convertFunction) { DriverContext driverContext1 = new DriverContext( BigArrays.NON_RECYCLING_INSTANCE, new org.elasticsearch.compute.data.BlockFactory( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java index 4d1d65d63932d..bf6f0b89efbec 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java @@ -87,7 +87,7 @@ private PhysicalPlan mapUnary(UnaryPlan unary) { PhysicalPlan mappedChild = map(unary.child()); // - // TODO - this is hard to follow and needs reworking + // TODO - this is hard to follow, causes bugs and needs reworking // https://github.com/elastic/elasticsearch/issues/115897 // if (unary instanceof Enrich enrich && enrich.mode() == Enrich.Mode.REMOTE) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java index 5064b2bbd101a..4e8a89d024b71 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java @@ -166,7 +166,7 @@ private void updateExecutionInfo(EsqlExecutionInfo executionInfo, String cluster builder.setTook(executionInfo.tookSoFar()); } if (v.getStatus() == EsqlExecutionInfo.Cluster.Status.RUNNING) { - builder.setFailures(resp.failures); + builder.addFailures(resp.failures); if (executionInfo.isStopped() || resp.failedShards > 0 || resp.failures.isEmpty() == false) { builder.setStatus(EsqlExecutionInfo.Cluster.Status.PARTIAL); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 39e3503b5fdd9..d12799ab8b170 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.plugin; +import org.elasticsearch.ExceptionsHelper; 
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.search.SearchRequest; @@ -18,7 +19,6 @@ import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.operator.DriverCompletionInfo; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.FailureCollector; @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.session.EsqlCCSUtils; @@ -137,8 +138,7 @@ public class ComputeService { private final DataNodeComputeHandler dataNodeComputeHandler; private final ClusterComputeHandler clusterComputeHandler; private final ExchangeService exchangeService; - - private volatile DataPartitioning defaultDataPartitioning; + private final PhysicalSettings physicalSettings; @SuppressWarnings("this-escape") public ComputeService( @@ -177,7 +177,7 @@ public ComputeService( esqlExecutor, dataNodeComputeHandler ); - clusterService.getClusterSettings().initializeAndWatch(EsqlPlugin.DEFAULT_DATA_PARTITIONING, v -> this.defaultDataPartitioning = v); + this.physicalSettings = new PhysicalSettings(clusterService); } public void execute( @@ -375,9 +375,10 @@ public void executePlan( var computeListener = new ComputeListener( transportService.getThreadPool(), cancelQueryOnFailure, - listener.map(completionInfo -> { + listener.delegateFailureAndWrap((l, completionInfo) -> { + failIfAllShardsFailed(execInfo, 
collectedPages); execInfo.markEndQuery(); // TODO: revisit this time recording model as part of INLINESTATS improvements - return new Result(outputAttributes, collectedPages, completionInfo, execInfo); + l.onResponse(new Result(outputAttributes, collectedPages, completionInfo, execInfo)); }) ) ) { @@ -395,9 +396,13 @@ public void executePlan( var builder = new EsqlExecutionInfo.Cluster.Builder(v).setTook(tookTime); if (v.getStatus() == EsqlExecutionInfo.Cluster.Status.RUNNING) { final Integer failedShards = execInfo.getCluster(LOCAL_CLUSTER).getFailedShards(); - var status = localClusterWasInterrupted.get() || (failedShards != null && failedShards > 0) - ? EsqlExecutionInfo.Cluster.Status.PARTIAL - : EsqlExecutionInfo.Cluster.Status.SUCCESSFUL; + // Set the local cluster status (including the final driver) to partial if the query was stopped + // or encountered resolution or execution failures. + var status = localClusterWasInterrupted.get() + || (failedShards != null && failedShards > 0) + || v.getFailures().isEmpty() == false + ? 
EsqlExecutionInfo.Cluster.Status.PARTIAL + : EsqlExecutionInfo.Cluster.Status.SUCCESSFUL; builder.setStatus(status); } return builder.build(); @@ -445,7 +450,7 @@ public void executePlan( .setSuccessfulShards(r.getSuccessfulShards()) .setSkippedShards(r.getSkippedShards()) .setFailedShards(r.getFailedShards()) - .setFailures(r.failures) + .addFailures(r.failures) .build() ); dataNodesListener.onResponse(r.getCompletionInfo()); @@ -455,7 +460,7 @@ public void executePlan( LOCAL_CLUSTER, (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setStatus( EsqlExecutionInfo.Cluster.Status.PARTIAL - ).setFailures(List.of(new ShardSearchFailure(e))).build() + ).addFailures(List.of(new ShardSearchFailure(e))).build() ); dataNodesListener.onResponse(DriverCompletionInfo.EMPTY); } else { @@ -536,6 +541,47 @@ private static void updateExecutionInfoAfterCoordinatorOnlyQuery(EsqlExecutionIn } } + /** + * If all of target shards excluding the skipped shards failed from the local or remote clusters, then we should fail the entire query + * regardless of the partial_results configuration or skip_unavailable setting. This behavior doesn't fully align with the search API, + * which doesn't consider the failures from the remote clusters when skip_unavailable is true. 
+ */ + static void failIfAllShardsFailed(EsqlExecutionInfo execInfo, List finalResults) { + // do not fail if any final result has results + if (finalResults.stream().anyMatch(p -> p.getPositionCount() > 0)) { + return; + } + int totalFailedShards = 0; + for (EsqlExecutionInfo.Cluster cluster : execInfo.clusterInfo.values()) { + final Integer successfulShards = cluster.getSuccessfulShards(); + if (successfulShards != null && successfulShards > 0) { + return; + } + if (cluster.getFailedShards() != null) { + totalFailedShards += cluster.getFailedShards(); + } + } + if (totalFailedShards == 0) { + return; + } + final var failureCollector = new FailureCollector(); + for (EsqlExecutionInfo.Cluster cluster : execInfo.clusterInfo.values()) { + var failedShards = cluster.getFailedShards(); + if (failedShards != null && failedShards > 0) { + assert cluster.getFailures().isEmpty() == false : "expected failures for cluster [" + cluster.getClusterAlias() + "]"; + for (ShardSearchFailure failure : cluster.getFailures()) { + if (failure.getCause() instanceof Exception e) { + failureCollector.unwrapAndCollect(e); + } else { + assert false : "unexpected failure: " + new AssertionError(failure.getCause()); + failureCollector.unwrapAndCollect(failure); + } + } + } + } + ExceptionsHelper.reThrowIfNotNull(failureCollector.getFailure()); + } + void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener listener) { listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts())); List contexts = new ArrayList<>(context.searchContexts().size()); @@ -561,7 +607,7 @@ public SourceProvider createSourceProvider() { context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis(), - defaultDataPartitioning + physicalSettings ); try { LocalExecutionPlanner planner = new LocalExecutionPlanner( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java index 7cba5eeb56278..f2f5b6b640311 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java @@ -21,7 +21,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockFactoryProvider; -import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneOperator; import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator; import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperatorStatus; @@ -75,6 +74,7 @@ import org.elasticsearch.xpack.esql.io.stream.ExpressionQueryBuilder; import org.elasticsearch.xpack.esql.io.stream.PlanStreamWrapperQueryBuilder; import org.elasticsearch.xpack.esql.plan.PlanWritables; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.querylog.EsqlQueryLog; import org.elasticsearch.xpack.esql.session.IndexResolver; @@ -160,14 +160,6 @@ public class EsqlPlugin extends Plugin implements ActionPlugin, ExtensiblePlugin Setting.Property.Dynamic ); - public static final Setting DEFAULT_DATA_PARTITIONING = Setting.enumSetting( - DataPartitioning.class, - "esql.default_data_partitioning", - DataPartitioning.AUTO, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - /** * Tuning parameter for deciding when to use the "merge" stored field loader. 
* Think of it as "how similar to a sequential block of documents do I have to @@ -263,7 +255,8 @@ public List> getSettings() { ESQL_QUERYLOG_THRESHOLD_INFO_SETTING, ESQL_QUERYLOG_THRESHOLD_WARN_SETTING, ESQL_QUERYLOG_INCLUDE_USER_SETTING, - DEFAULT_DATA_PARTITIONING, + PhysicalSettings.DEFAULT_DATA_PARTITIONING, + PhysicalSettings.VALUES_LOADING_JUMBO_SIZE, STORED_FIELDS_SEQUENTIAL_PROPORTION, EsqlFlags.ESQL_STRING_LIKE_ON_INDEX ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index 345bf3b8767ef..bdd0e382c3fd3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import java.io.IOException; import java.util.Locale; @@ -45,7 +46,7 @@ public final class QueryPragmas implements Writeable { * the enum {@link DataPartitioning} which has more documentation. Not an * {@link Setting#enumSetting} because those can't have {@code null} defaults. * {@code null} here means "use the default from the cluster setting - * named {@link EsqlPlugin#DEFAULT_DATA_PARTITIONING}." + * named {@link PhysicalSettings#DEFAULT_DATA_PARTITIONING}." 
*/ public static final Setting DATA_PARTITIONING = Setting.simpleString("data_partitioning"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java index 2946af2ac5c23..b218b897121df 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -13,7 +13,9 @@ import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -27,15 +29,17 @@ public class KnnQuery extends Query { private final String field; private final float[] query; private final Map options; + private final List filterQueries; public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; - public KnnQuery(Source source, String field, float[] query, Map options) { + public KnnQuery(Source source, String field, float[] query, Map options, List filterQueries) { super(source); assert options != null; this.field = field; this.query = query; this.options = options; + this.filterQueries = new ArrayList<>(filterQueries); } @Override @@ -50,6 +54,9 @@ protected QueryBuilder asBuilder() { Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName()); KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity); + for (QueryBuilder filter : filterQueries) { + queryBuilder.addFilterQuery(filter); + } Number boost = (Number) options.get(BOOST_FIELD.getPreferredName()); if (boost != null) { queryBuilder.boost(boost.floatValue()); @@ -66,15 +73,17 @@ protected String innerToString() { public boolean equals(Object o) { if (super.equals(o) == false) 
return false; + if (o == null || getClass() != o.getClass()) return false; KnnQuery knnQuery = (KnnQuery) o; return Objects.equals(field, knnQuery.field) && Objects.deepEquals(query, knnQuery.query) - && Objects.equals(options, knnQuery.options); + && Objects.equals(options, knnQuery.options) + && Objects.equals(filterQueries, knnQuery.filterQueries); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options); + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options, filterQueries); } @Override @@ -86,4 +95,8 @@ public boolean scorable() { public boolean containsPlan() { return false; } + + public List filterQueries() { + return filterQueries; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java index 7df5a029d724e..e4a5423d35f8e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java @@ -175,7 +175,7 @@ protected final ExecutionInfo executeWithInfo(TreeType plan) { if (tf.hasChanged()) { hasChanged = true; if (log.isTraceEnabled()) { - log.trace("Rule {} applied\n{}", rule, NodeUtils.diffString(tf.before, tf.after)); + log.trace("Rule {} applied with change\n{}", rule, NodeUtils.diffString(tf.before, tf.after)); } } else { if (log.isTraceEnabled()) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index 0f66a839bb429..183dccf48d5ac 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -24,6 +24,7 @@ import java.time.Duration; import 
java.time.Instant; import java.time.ZoneId; +import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Locale; import java.util.Map; @@ -34,6 +35,7 @@ public class Configuration implements Writeable { public static final int QUERY_COMPRESS_THRESHOLD_CHARS = KB.toIntBytes(5); + public static final ZoneId DEFAULT_TZ = ZoneOffset.UTC; private final String clusterName; private final String username; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtils.java index 901057f4db61c..69d7d5999db63 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtils.java @@ -15,11 +15,9 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.compute.operator.DriverCompletionInfo; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexNotFoundException; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.IndicesExpressionGrouper; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.transport.ConnectTransportException; @@ -35,27 +33,36 @@ import org.elasticsearch.xpack.esql.plan.IndexPattern; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; public class EsqlCCSUtils { private EsqlCCSUtils() {} - static Map determineUnavailableRemoteClusters(List failures) { - Map unavailableRemotes = new HashMap<>(); + static 
Map> groupFailuresPerCluster(List failures) { + Map> perCluster = new HashMap<>(); for (FieldCapabilitiesFailure failure : failures) { - if (ExceptionsHelper.isRemoteUnavailableException(failure.getException())) { - for (String indexExpression : failure.getIndices()) { - if (indexExpression.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) > 0) { - unavailableRemotes.put(RemoteClusterAware.parseClusterAlias(indexExpression), failure); - } - } + String cluster = RemoteClusterAware.parseClusterAlias(failure.getIndices()[0]); + perCluster.computeIfAbsent(cluster, k -> new ArrayList<>()).add(failure); + } + return perCluster; + } + + static Map determineUnavailableRemoteClusters(Map> failures) { + Map unavailableRemotes = new HashMap<>(failures.size()); + for (var e : failures.entrySet()) { + if (Strings.isEmpty(e.getKey())) { + continue; + } + if (e.getValue().stream().allMatch(f -> ExceptionsHelper.isRemoteUnavailableException(f.getException()))) { + unavailableRemotes.put(e.getKey(), e.getValue().get(0)); } } return unavailableRemotes; @@ -136,8 +143,8 @@ static void updateExecutionInfoToReturnEmptyResult(EsqlExecutionInfo executionIn } else { builder.setStatus(EsqlExecutionInfo.Cluster.Status.SKIPPED); // add this exception to the failures list only if there is no failure already recorded there - if (v.getFailures() == null || v.getFailures().size() == 0) { - builder.setFailures(List.of(new ShardSearchFailure(exceptionForResponse))); + if (v.getFailures().isEmpty()) { + builder.addFailures(List.of(new ShardSearchFailure(exceptionForResponse))); } } return builder.build(); @@ -169,7 +176,11 @@ static String createIndexExpressionFromAvailableClusters(EsqlExecutionInfo execu } } - static void updateExecutionInfoWithUnavailableClusters(EsqlExecutionInfo execInfo, Map unavailable) { + static void updateExecutionInfoWithUnavailableClusters( + EsqlExecutionInfo execInfo, + Map> failures + ) { + Map unavailable = determineUnavailableRemoteClusters(failures); for 
(Map.Entry entry : unavailable.entrySet()) { String clusterAlias = entry.getKey(); boolean skipUnavailable = execInfo.getCluster(clusterAlias).isSkipUnavailable(); @@ -188,18 +199,17 @@ static void updateExecutionInfoWithUnavailableClusters(EsqlExecutionInfo execInf static void updateExecutionInfoWithClustersWithNoMatchingIndices( EsqlExecutionInfo executionInfo, IndexResolution indexResolution, - QueryBuilder filter + boolean usedFilter ) { - Set clustersWithResolvedIndices = new HashSet<>(); - // determine missing clusters + // Get the clusters which are still running, and we will check whether they have any matching indices. + // NOTE: we assume that updateExecutionInfoWithUnavailableClusters() was already run and took care of unavailable clusters. + final Set clustersWithNoMatchingIndices = executionInfo.getClusterStates(Cluster.Status.RUNNING) + .map(Cluster::getClusterAlias) + .collect(Collectors.toSet()); for (String indexName : indexResolution.resolvedIndices()) { - clustersWithResolvedIndices.add(RemoteClusterAware.parseClusterAlias(indexName)); + clustersWithNoMatchingIndices.remove(RemoteClusterAware.parseClusterAlias(indexName)); } - Set clustersRequested = executionInfo.clusterAliases(); - Set clustersWithNoMatchingIndices = Sets.difference(clustersRequested, clustersWithResolvedIndices); - clustersWithNoMatchingIndices.removeAll(indexResolution.unavailableClusters().keySet()); - - /** + /* * Rules enforced at planning time around non-matching indices * 1. fail query if no matching indices on any cluster (VerificationException) - that is handled elsewhere * 2. fail query if a cluster has no matching indices *and* a concrete index was specified - handled here @@ -211,24 +221,20 @@ static void updateExecutionInfoWithClustersWithNoMatchingIndices( * Mark it as SKIPPED with 0 shards searched and took=0. 
*/ for (String c : clustersWithNoMatchingIndices) { - if (executionInfo.getCluster(c).getStatus() != Cluster.Status.RUNNING) { - // if cluster was already in a terminal state, we don't need to check it again - continue; - } final String indexExpression = executionInfo.getCluster(c).getIndexExpression(); if (concreteIndexRequested(executionInfo.getCluster(c).getIndexExpression())) { String error = Strings.format( "Unknown index [%s]", (c.equals(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) ? indexExpression : c + ":" + indexExpression) ); - if (executionInfo.isSkipUnavailable(c) == false || filter != null) { + if (executionInfo.isSkipUnavailable(c) == false || usedFilter) { if (fatalErrorMessage == null) { fatalErrorMessage = error; } else { fatalErrorMessage += "; " + error; } } - if (filter == null) { + if (usedFilter == false) { // We check for filter since the filter may be the reason why the index is missing, and then we don't want to mark yet markClusterWithFinalStateAndNoShards( executionInfo, @@ -238,10 +244,22 @@ static void updateExecutionInfoWithClustersWithNoMatchingIndices( ); } } else { + // We check for the valid resolution because if we have empty resolution it's still an error. if (indexResolution.isValid()) { - // no matching indices and no concrete index requested - just mark it as done, no error - // We check for the valid resolution because if we have empty resolution it's still an error. - markClusterWithFinalStateAndNoShards(executionInfo, c, Cluster.Status.SUCCESSFUL, null); + List failures = indexResolution.failures().getOrDefault(c, List.of()); + // No matching indices, no concrete index requested, and no error in field-caps; just mark as done. + if (failures.isEmpty()) { + markClusterWithFinalStateAndNoShards(executionInfo, c, Cluster.Status.SUCCESSFUL, null); + } else { + // skip reporting index_not_found exceptions to avoid spamming users with such errors + // when queries use a remote cluster wildcard, e.g., `*:my-logs*`. 
+ Exception nonIndexNotFound = failures.stream() + .map(FieldCapabilitiesFailure::getException) + .filter(ex -> ExceptionsHelper.unwrap(ex, IndexNotFoundException.class) == null) + .findAny() + .orElse(null); + markClusterWithFinalStateAndNoShards(executionInfo, c, Cluster.Status.SKIPPED, nonIndexNotFound); + } } } } @@ -252,7 +270,7 @@ static void updateExecutionInfoWithClustersWithNoMatchingIndices( // Filter-less version, mainly for testing where we don't need filter support static void updateExecutionInfoWithClustersWithNoMatchingIndices(EsqlExecutionInfo executionInfo, IndexResolution indexResolution) { - updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution, null); + updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution, false); } // visible for testing @@ -360,7 +378,7 @@ public static void markClusterWithFinalStateAndNoShards( .setSkippedShards(Objects.requireNonNullElse(v.getSkippedShards(), 0)) .setFailedShards(Objects.requireNonNullElse(v.getFailedShards(), 0)); if (ex != null) { - builder.setFailures(List.of(new ShardSearchFailure(ex))); + builder.addFailures(List.of(new ShardSearchFailure(ex))); } return builder.build(); }); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 40a859e3f5b58..df18051bcf721 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -7,9 +7,13 @@ package org.elasticsearch.xpack.esql.session; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.fieldcaps.FieldCapabilitiesFailure; 
+import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.Strings; @@ -20,15 +24,19 @@ import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverCompletionInfo; +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.mapper.IndexModeFieldMapper; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesExpressionGrouper; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.xpack.esql.VerificationException; @@ -111,7 +119,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -205,9 +212,10 @@ public void execute(EsqlQueryRequest request, EsqlExecutionInfo executionInfo, P analyzedPlan(parsed, executionInfo, request.filter(), new EsqlCCSUtils.CssPartialErrorsActionListener(executionInfo, listener) { @Override public void onResponse(LogicalPlan analyzedPlan) { + LogicalPlan optimizedPlan = optimizedPlan(analyzedPlan); preMapper.preMapper( - analyzedPlan, - listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, optimizedPlan(p), l)) + optimizedPlan, + listener.delegateFailureAndWrap((l, p) -> 
executeOptimizedPlan(request, executionInfo, planRunner, p, l)) ); } }); @@ -331,6 +339,53 @@ private LogicalPlan parse(String query, QueryParams params) { return parsed; } + /** + * Associates errors that occurred during field-caps with the cluster info in the execution info. + * - Skips clusters that are no longer running, as they have already been marked as successful, skipped, or failed. + * - If allow_partial_results or skip_unavailable is enabled, stores the failures in the cluster info but allows execution to continue. + * - Otherwise, aborts execution with the failures. + */ + static void handleFieldCapsFailures( + boolean allowPartialResults, + EsqlExecutionInfo executionInfo, + Map> failures + ) throws Exception { + FailureCollector failureCollector = new FailureCollector(); + for (var e : failures.entrySet()) { + String clusterAlias = e.getKey(); + EsqlExecutionInfo.Cluster cluster = executionInfo.getCluster(clusterAlias); + if (cluster.getStatus() != EsqlExecutionInfo.Cluster.Status.RUNNING) { + assert cluster.getStatus() != EsqlExecutionInfo.Cluster.Status.SUCCESSFUL : "can't mark a cluster success with failures"; + continue; + } + if (allowPartialResults == false && executionInfo.isSkipUnavailable(clusterAlias) == false) { + for (FieldCapabilitiesFailure failure : e.getValue()) { + failureCollector.unwrapAndCollect(failure.getException()); + } + } else if (cluster.getFailures().isEmpty()) { + var shardFailures = e.getValue().stream().map(f -> { + ShardId shardId = null; + if (ExceptionsHelper.unwrapCause(f.getException()) instanceof ElasticsearchException es) { + shardId = es.getShardId(); + } + if (shardId != null) { + return new ShardSearchFailure(f.getException(), new SearchShardTarget(null, shardId, clusterAlias)); + } else { + return new ShardSearchFailure(f.getException()); + } + }).toList(); + executionInfo.swapCluster( + clusterAlias, + (k, curr) -> new EsqlExecutionInfo.Cluster.Builder(cluster).addFailures(shardFailures).build() + ); + } + } 
+ Exception failure = failureCollector.getFailure(); + if (failure != null) { + throw failure; + } + } + public void analyzedPlan( LogicalPlan parsed, EsqlExecutionInfo executionInfo, @@ -342,7 +397,8 @@ public void analyzedPlan( return; } - Function analyzeAction = (l) -> { + CheckedFunction analyzeAction = (l) -> { + handleFieldCapsFailures(configuration.allowPartialResults(), executionInfo, l.indices.failures()); Analyzer analyzer = new Analyzer( new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution, l.inferenceResolution), verifier @@ -402,8 +458,8 @@ public void analyzedPlan( try { // the order here is tricky - if the cluster has been filtered and later became unavailable, // do we want to declare it successful or skipped? For now, unavailability takes precedence. - EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, result.indices.unavailableClusters()); - EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, result.indices, null); + EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, result.indices.failures()); + EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, result.indices, false); plan = analyzeAction.apply(result); } catch (Exception e) { l.onFailure(e); @@ -467,7 +523,7 @@ private PreAnalysisResult receiveLookupIndexResolution( EsqlExecutionInfo executionInfo, IndexResolution lookupIndexResolution ) { - EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, lookupIndexResolution.unavailableClusters()); + EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, lookupIndexResolution.failures()); if (lookupIndexResolution.isValid() == false) { // If the index resolution is invalid, don't bother with the rest of the analysis return result.addLookupIndexResolution(index, lookupIndexResolution); @@ -566,12 +622,7 @@ private IndexResolution checkSingleIndex( if (localIndexNames.size() == 
1) { String indexName = localIndexNames.iterator().next(); EsIndex newIndex = new EsIndex(index, lookupIndexResolution.get().mapping(), Map.of(indexName, IndexMode.LOOKUP)); - return IndexResolution.valid( - newIndex, - newIndex.concreteIndices(), - lookupIndexResolution.getUnavailableShards(), - lookupIndexResolution.unavailableClusters() - ); + return IndexResolution.valid(newIndex, newIndex.concreteIndices(), lookupIndexResolution.failures()); } // validate remotes to be able to handle multiple indices in LOOKUP JOIN validateRemoteVersions(executionInfo); @@ -665,11 +716,7 @@ private void preAnalyzeMainIndices( result.fieldNames, requestFilter, listener.delegateFailure((l, indexResolution) -> { - if (configuration.allowPartialResults() == false && indexResolution.getUnavailableShards().isEmpty() == false) { - l.onFailure(indexResolution.getUnavailableShards().iterator().next()); - } else { - l.onResponse(result.withIndexResolution(indexResolution)); - } + l.onResponse(result.withIndexResolution(indexResolution)); }) ); } @@ -694,7 +741,7 @@ private boolean allCCSClustersSkipped( ActionListener logicalPlanListener ) { IndexResolution indexResolution = result.indices; - EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, indexResolution.unavailableClusters()); + EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, indexResolution.failures()); if (executionInfo.isCrossClusterSearch() && executionInfo.getClusterStates(EsqlExecutionInfo.Cluster.Status.RUNNING).findAny().isEmpty()) { // for a CCS, if all clusters have been marked as SKIPPED, nothing to search so send a sentinel Exception @@ -708,7 +755,7 @@ private boolean allCCSClustersSkipped( } private static void analyzeAndMaybeRetry( - Function analyzeAction, + CheckedFunction analyzeAction, QueryBuilder requestFilter, PreAnalysisResult result, EsqlExecutionInfo executionInfo, @@ -724,7 +771,7 @@ private static void analyzeAndMaybeRetry( if (result.indices.isValid() || 
requestFilter != null) { // We won't run this check with no filter and no valid indices since this may lead to false positive - missing index report // when the resolution result is not valid for a different reason. - EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, result.indices, requestFilter); + EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, result.indices, requestFilter != null); } plan = analyzeAction.apply(result); } catch (Exception e) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/IndexResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/IndexResolver.java index d2f79ceb1316f..16401574b0f58 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/IndexResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/IndexResolver.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.esql.session; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.NoShardAvailableActionException; -import org.elasticsearch.action.fieldcaps.FieldCapabilitiesFailure; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesIndexResponse; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; @@ -149,17 +147,6 @@ public static IndexResolution mergedMappings(String indexPattern, FieldCapabilit } } - Map unavailableRemotes = EsqlCCSUtils.determineUnavailableRemoteClusters( - fieldCapsResponse.getFailures() - ); - - Set unavailableShards = new HashSet<>(); - for (FieldCapabilitiesFailure failure : fieldCapsResponse.getFailures()) { - if (failure.getException() instanceof NoShardAvailableActionException e) { - unavailableShards.add(e); - } - } - Map concreteIndices = Maps.newMapWithExpectedSize(fieldCapsResponse.getIndexResponses().size()); for (FieldCapabilitiesIndexResponse ir : 
fieldCapsResponse.getIndexResponses()) { concreteIndices.put(ir.getIndexName(), ir.getIndexMode()); @@ -171,7 +158,8 @@ public static IndexResolution mergedMappings(String indexPattern, FieldCapabilit } // If all the mappings are empty we return an empty set of resolved indices to line up with QL var index = new EsIndex(indexPattern, rootFields, allEmpty ? Map.of() : concreteIndices, partiallyUnmappedFields); - return IndexResolution.valid(index, concreteIndices.keySet(), unavailableShards, unavailableRemotes); + var failures = EsqlCCSUtils.groupFailuresPerCluster(fieldCapsResponse.getFailures()); + return IndexResolution.valid(index, concreteIndices.keySet(), failures); } private static Map> collectFieldCaps(FieldCapabilitiesResponse fieldCapsResponse) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java index c308b317529ca..89aa2402248b8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchContextStats.java @@ -18,6 +18,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.NumericUtils; import org.elasticsearch.index.mapper.ConstantFieldType; import org.elasticsearch.index.mapper.DocCountFieldMapper.DocCountFieldType; import org.elasticsearch.index.mapper.IdFieldMapper; @@ -29,7 +30,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute.FieldName; -import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Holder; import java.io.IOException; import java.util.LinkedHashMap; @@ -51,7 +52,11 @@ public 
class SearchContextStats implements SearchStats { private final List contexts; - private record FieldConfig(boolean exists, boolean hasExactSubfield, boolean indexed, boolean hasDocValues) {} + private record FieldConfig(boolean exists, boolean hasExactSubfield, boolean indexed, boolean hasDocValues, MappedFieldType fieldType) { + FieldConfig(boolean exists, boolean hasExactSubfield, boolean indexed, boolean hasDocValues) { + this(exists, hasExactSubfield, indexed, hasDocValues, null); + } + } private static class FieldStats { private Long count; @@ -93,11 +98,18 @@ private FieldConfig makeFieldConfig(String field) { boolean hasExactSubfield = true; boolean indexed = true; boolean hasDocValues = true; + boolean mixedFieldType = false; + MappedFieldType fieldType = null; // Extract the field type, it will be used by min/max later. // even if there are deleted documents, check the existence of a field // since if it's missing, deleted documents won't change that for (SearchExecutionContext context : contexts) { if (context.isFieldMapped(field)) { - var type = context.getFieldType(field); + MappedFieldType type = context.getFieldType(field); + if (fieldType == null) { + fieldType = type; + } else if (mixedFieldType == false && fieldType.typeName().equals(type.typeName()) == false) { + mixedFieldType = true; + } exists |= true; indexed &= type.isIndexed(); hasDocValues &= type.hasDocValues(); @@ -115,7 +127,7 @@ private FieldConfig makeFieldConfig(String field) { // if it does not exist on any context, no other settings are valid return new FieldConfig(false, false, false, false); } else { - return new FieldConfig(exists, hasExactSubfield, indexed, hasDocValues); + return new FieldConfig(exists, hasExactSubfield, indexed, hasDocValues, mixedFieldType ? 
null : fieldType); } } @@ -185,49 +197,57 @@ public long count(FieldName field, BytesRef value) { } @Override - public byte[] min(FieldName field, DataType dataType) { + public Object min(FieldName field) { var stat = cache.computeIfAbsent(field.string(), this::makeFieldStats); + // Consolidate min for indexed date fields only, skip the others and mixed-typed fields. + MappedFieldType fieldType = stat.config.fieldType; + if (fieldType == null || stat.config.indexed == false || fieldType instanceof DateFieldType == false) { + return null; + } if (stat.min == null) { - var min = new byte[][] { null }; + var min = new long[] { Long.MAX_VALUE }; + Holder foundMinValue = new Holder<>(false); doWithContexts(r -> { - byte[] localMin = PointValues.getMinPackedValue(r, field.string()); - // TODO: how to compare with the previous min - if (localMin != null) { - if (min[0] == null) { - min[0] = localMin; - } else { - throw new EsqlIllegalArgumentException("Don't know how to compare with previous min"); + byte[] minPackedValue = PointValues.getMinPackedValue(r, field.string()); + if (minPackedValue != null && minPackedValue.length == 8) { + long minValue = NumericUtils.sortableBytesToLong(minPackedValue, 0); + if (minValue <= min[0]) { + min[0] = minValue; + foundMinValue.set(true); } } return true; }, true); - stat.min = min[0]; + stat.min = foundMinValue.get() ? min[0] : null; } - // return stat.min; - return null; + return stat.min; } @Override - public byte[] max(FieldName field, DataType dataType) { + public Object max(FieldName field) { var stat = cache.computeIfAbsent(field.string(), this::makeFieldStats); + // Consolidate max for indexed date fields only, skip the others and mixed-typed fields. 
+ MappedFieldType fieldType = stat.config.fieldType; + if (fieldType == null || stat.config.indexed == false || fieldType instanceof DateFieldType == false) { + return null; + } if (stat.max == null) { - var max = new byte[][] { null }; + var max = new long[] { Long.MIN_VALUE }; + Holder foundMaxValue = new Holder<>(false); doWithContexts(r -> { - byte[] localMax = PointValues.getMaxPackedValue(r, field.string()); - // TODO: how to compare with the previous max - if (localMax != null) { - if (max[0] == null) { - max[0] = localMax; - } else { - throw new EsqlIllegalArgumentException("Don't know how to compare with previous max"); + byte[] maxPackedValue = PointValues.getMaxPackedValue(r, field.string()); + if (maxPackedValue != null && maxPackedValue.length == 8) { + long maxValue = NumericUtils.sortableBytesToLong(maxPackedValue, 0); + if (maxValue >= max[0]) { + max[0] = maxValue; + foundMaxValue.set(true); } } return true; }, true); - stat.max = max[0]; + stat.max = foundMaxValue.get() ? max[0] : null; } - // return stat.max; - return null; + return stat.max; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchStats.java index ff1701104eca9..5c7ab1fdd6242 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/SearchStats.java @@ -10,7 +10,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute.FieldName; -import org.elasticsearch.xpack.esql.core.type.DataType; /** * Interface for determining information about fields in the index. 
@@ -33,9 +32,9 @@ public interface SearchStats { long count(FieldName field, BytesRef value); - byte[] min(FieldName field, DataType dataType); + Object min(FieldName field); - byte[] max(FieldName field, DataType dataType); + Object max(FieldName field); boolean isSingleValue(FieldName field); @@ -90,12 +89,12 @@ public long count(FieldName field, BytesRef value) { } @Override - public byte[] min(FieldName field, DataType dataType) { + public Object min(FieldName field) { return null; } @Override - public byte[] max(FieldName field, DataType dataType) { + public Object max(FieldName field) { return null; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index bdf2ba39edc66..62280a38ba608 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -277,6 +277,10 @@ public final void test() throws Throwable { "can't use match in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MATCH_OPERATOR_COLON.capabilityName()) ); + assumeFalse( + "can't use score function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.SCORE_FUNCTION.capabilityName()) + ); assumeFalse( "can't load metrics in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.METRICS_COMMAND.capabilityName()) @@ -299,7 +303,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfoTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfoTests.java index 111d86669af22..19899e62ca057 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfoTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfoTests.java @@ -57,7 +57,7 @@ public void testHasMetadataPartial() { assertFalse(info.hasMetadataToReport()); info.swapCluster(key, (k, v) -> { EsqlExecutionInfo.Cluster.Builder builder = new EsqlExecutionInfo.Cluster.Builder(v); - builder.setFailures(List.of(new ShardSearchFailure(new IllegalStateException("shard failure")))); + builder.addFailures(List.of(new ShardSearchFailure(new IllegalStateException("shard failure")))); return builder.build(); }); assertTrue(info.hasMetadataToReport()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index cbb825ca9581b..fbfa18dccc477 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -36,9 +36,9 @@ import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.RANGE_TYPE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution; public final class AnalyzerTestUtils { @@ -61,27 +61,36 @@ public static Analyzer analyzer(IndexResolution indexResolution, Map lookupResolution, Verifier verifier) { + return 
analyzer(indexResolution, lookupResolution, defaultEnrichResolution(), verifier); + } + + public static Analyzer analyzer( + IndexResolution indexResolution, + Map lookupResolution, + EnrichResolution enrichResolution, + Verifier verifier + ) { + return analyzer(indexResolution, lookupResolution, enrichResolution, verifier, TEST_CFG); + } + + public static Analyzer analyzer( + IndexResolution indexResolution, + Map lookupResolution, + EnrichResolution enrichResolution, + Verifier verifier, + Configuration config + ) { return new Analyzer( new AnalyzerContext( - EsqlTestUtils.TEST_CFG, + config, new EsqlFunctionRegistry(), indexResolution, lookupResolution, - defaultEnrichResolution(), + enrichResolution, defaultInferenceResolution() ), verifier @@ -89,17 +98,7 @@ public static Analyzer analyzer(IndexResolution indexResolution, Map expectedElems) { + var plan = analyze(String.format(Locale.ROOT, """ + from test | eval similarity = %s + """, similarityFunction), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("similarity", alias.name()); + var similarity = as(alias.child(), VectorSimilarityFunction.class); + var left = as(similarity.left(), FieldAttribute.class); + assertEquals("vector", left.name()); + var right = as(similarity.right(), Literal.class); + assertThat(right.dataType(), is(DENSE_VECTOR)); + assertThat(right.value(), equalTo(expectedElems)); + } + + public void testNoDenseVectorFailsSimilarityFunction() { + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)"); + } + } + + private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) { + var query = String.format(Locale.ROOT, "row a = 1 | eval similarity = %s", similarityFunction); + VerificationException error = expectThrows(VerificationException.class, () -> 
analyze(query)); + assertThat( + error.getMessage(), + containsString("second argument of [" + similarityFunction + "] must be" + " [dense_vector], found value [0.342] type [double]") + ); } public void testRateRequiresCounterTypes() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java index 1cda4bd599d61..54071ac86d59f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java @@ -19,7 +19,10 @@ import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.parser.EsqlParser; +import org.elasticsearch.xpack.esql.parser.ParserUtils; import org.elasticsearch.xpack.esql.parser.ParsingException; +import org.elasticsearch.xpack.esql.parser.QueryParam; +import org.elasticsearch.xpack.esql.parser.QueryParams; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter; @@ -157,9 +160,34 @@ public void testJoinTwiceOnTheSameField_TwoLookups() { } public void testInvalidLimit() { - assertEquals("1:13: Invalid value for LIMIT [foo: String], expecting a non negative integer", error("row a = 1 | limit \"foo\"")); - assertEquals("1:13: Invalid value for LIMIT [1.2: Double], expecting a non negative integer", error("row a = 1 | limit 1.2")); - assertEquals("1:13: Invalid value for LIMIT [-1], expecting a non negative integer", error("row a = 1 | limit -1")); + assertLimitWithAndWithoutParams("foo", "\"foo\"", DataType.KEYWORD); + assertLimitWithAndWithoutParams(1.2, "1.2", DataType.DOUBLE); + assertLimitWithAndWithoutParams(-1, "-1", DataType.INTEGER); + assertLimitWithAndWithoutParams(true, "true", 
DataType.BOOLEAN); + assertLimitWithAndWithoutParams(false, "false", DataType.BOOLEAN); + assertLimitWithAndWithoutParams(null, "null", DataType.NULL); + } + + private void assertLimitWithAndWithoutParams(Object value, String valueText, DataType type) { + assertEquals( + "1:13: value of [limit " + + valueText + + "] must be a non negative integer, found value [" + + valueText + + "] type [" + + type.typeName() + + "]", + error("row a = 1 | limit " + valueText) + ); + + assertEquals( + "1:13: value of [limit ?param] must be a non negative integer, found value [?param] type [" + type.typeName() + "]", + error( + "row a = 1 | limit ?param", + new QueryParams(List.of(new QueryParam("param", value, type, ParserUtils.ParamClassification.VALUE))) + ) + ); + } public void testInvalidSample() { @@ -181,13 +209,20 @@ public void testInvalidSample() { ); } - private String error(String query) { - ParsingException e = expectThrows(ParsingException.class, () -> defaultAnalyzer.analyze(parser.createStatement(query, TEST_CFG))); + private String error(String query, QueryParams params) { + ParsingException e = expectThrows( + ParsingException.class, + () -> defaultAnalyzer.analyze(parser.createStatement(query, params, TEST_CFG)) + ); String message = e.getMessage(); assertTrue(message.startsWith("line ")); return message.substring("line ".length()); } + private String error(String query) { + return error(query, new QueryParams()); + } + private static IndexResolution loadIndexResolution(String name) { return IndexResolution.valid(new EsIndex(INDEX_NAME, LoadMapping.loadMapping(name))); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index fc38af5569b98..b74f6b99db2e1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.parser.QueryParam; import org.elasticsearch.xpack.esql.parser.QueryParams; +import org.elasticsearch.xpack.esql.plan.logical.Enrich; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -38,9 +39,13 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadEnrichPolicyResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; @@ -574,6 +579,21 @@ public void testInvalidBucketCalls() { + "found value [\"5\"] type [keyword]" ) ); + + assertThat( + error("from test | stats max(emp_no) by bucket(hire_date, 5, true)"), + containsString( + "function expects exactly four arguments when the first one is of type [DATETIME] and the second of type [INTEGER]" + ) + ); + + assertThat( + error("from test | stats max(emp_no) by bucket(hire_date, 1 week, true)"), + containsString( + "function expects exactly two or four arguments when the first one is of type [DATETIME] and the second of type " + + "[DATE_PERIOD]" + ) + ); } public void testAggsWithInvalidGrouping() { @@ -1213,7 +1233,9 @@ public void testWeightedAvg() { public void testMatchInsideEval() throws 
Exception { assertEquals( - "1:36: [:] operator is only supported in WHERE and STATS commands\n" + "1:36: [:] operator is only supported in WHERE and STATS commands" + + (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled() ? ", or in EVAL within score(.) function" : "") + + "\n" + "line 1:36: [:] operator cannot operate on [title], which is not a field from an index mapping", error("row title = \"brown fox\" | eval x = title:\"fox\" ") ); @@ -1237,7 +1259,7 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1370,20 +1392,29 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } + } private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType) throws Exception { assertThat( error("from test | eval y = " + functionInvocation, fullTextAnalyzer), - containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands") + containsString( + "[" + + functionName + + "] " + + functionType + + " is only supported in WHERE and STATS commands" + + (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled() ? ", or in EVAL within score(.) 
function" : "") + ) ); assertThat( error("from test | sort " + functionInvocation + " asc", fullTextAnalyzer), containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands") + ); assertThat( error("from test | stats max_id = max(id) by " + functionInvocation, fullTextAnalyzer), @@ -1392,7 +1423,14 @@ private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, Strin if ("KQL".equals(functionName) || "QSTR".equals(functionName)) { assertThat( error("row a = " + functionInvocation, fullTextAnalyzer), - containsString("[" + functionName + "] " + functionType + " is only supported in WHERE and STATS commands") + containsString( + "[" + + functionName + + "] " + + functionType + + " is only supported in WHERE and STATS commands" + + (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled() ? ", or in EVAL within score(.) function" : "") + ) ); } } @@ -1409,7 +1447,7 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1474,7 +1512,7 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1545,7 +1583,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if 
(EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -1954,6 +1992,57 @@ public void testCategorizeWithFilteredAggregations() { ); } + public void testCategorizeInvalidOptionsField() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + assertEquals( + "1:31: second argument of [CATEGORIZE(last_name, first_name)] must be a map expression, received [first_name]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, first_name)") + ); + assertEquals( + "1:31: Invalid option [blah] in [CATEGORIZE(last_name, { \"blah\": 42 })], " + + "expected one of [analyzer, output_format, similarity_threshold]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"blah\": 42 })") + ); + } + + public void testCategorizeOptionOutputFormat() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"regex\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"REGEX\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"tokens\" })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"ToKeNs\" })"); + assertEquals( + "1:31: invalid output format [blah], expecting one of [REGEX, TOKENS]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"blah\" })") + ); + assertEquals( + "1:31: invalid output format [42], expecting one of [REGEX, TOKENS]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": 42 })") + ); + } + + public void testCategorizeOptionSimilarityThreshold() { + assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled()); + + 
query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 1 })"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 100 })"); + assertEquals( + "1:31: invalid similarity threshold [0], expecting a number between 1 and 100, inclusive", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 0 })") + ); + assertEquals( + "1:31: invalid similarity threshold [101], expecting a number between 1 and 100, inclusive", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": 101 })") + ); + assertEquals( + "1:31: Invalid option [similarity_threshold] in [CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })], " + + "cannot cast [blah] to [integer]", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"similarity_threshold\": \"blah\" })") + ); + } + public void testChangePoint() { assumeTrue("change_point must be enabled", EsqlCapabilities.Cap.CHANGE_POINT.isEnabled()); var airports = AnalyzerTestUtils.analyzer(loadMapping("mapping-airports.json", "airports")); @@ -2073,7 +2162,7 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2161,7 +2250,7 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); 
checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); @@ -2187,7 +2276,7 @@ public void testFullTextFunctionsConstantArg() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } @@ -2218,7 +2307,7 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } @@ -2274,7 +2363,140 @@ public void testRemoteLookupJoinIsDisabled() { () -> query("FROM test,remote:test | EVAL language_code = languages | LOOKUP JOIN languages_lookup ON language_code") ); assertThat(e.getMessage(), containsString("remote clusters are not supported with LOOKUP JOIN")); + } + public void testRemoteEnrichAfterLookupJoin() { + EnrichResolution enrichResolution = new EnrichResolution(); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.REMOTE, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + var analyzer = AnalyzerTestUtils.analyzer( + loadMapping("mapping-default.json", "test"), + defaultLookupResolution(), + enrichResolution, + TEST_VERIFIER + ); + + String lookupCommand = randomBoolean() ? 
"LOOKUP JOIN test_lookup ON languages" : "LOOKUP JOIN languages_lookup ON language_code"; + + query(Strings.format(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | %s + """, lookupCommand), analyzer); + + String err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | ENRICH _remote:languages ON language_code + """, lookupCommand), analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + + err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | ENRICH _remote:languages ON language_code + | %s + """, lookupCommand, lookupCommand), analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + + err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | EVAL x = 1 + | MV_EXPAND language_code + | ENRICH _remote:languages ON language_code + """, lookupCommand), analyzer); + assertThat(err, containsString("6:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + } + + public void testRemoteEnrichAfterCoordinatorOnlyPlans() { + EnrichResolution enrichResolution = new EnrichResolution(); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.REMOTE, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.COORDINATOR, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + var analyzer = AnalyzerTestUtils.analyzer( + loadMapping("mapping-default.json", "test"), + defaultLookupResolution(), + enrichResolution, + TEST_VERIFIER + ); + + query(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | STATS count(*) BY language_name + """, analyzer); + + String err = error(""" + FROM test + | EVAL 
language_code = languages + | STATS count(*) BY language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after STATS")); + + err = error(""" + FROM test + | EVAL language_code = languages + | STATS count(*) BY language_code + | EVAL x = 1 + | MV_EXPAND language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("6:3: ENRICH with remote policy can't be executed after STATS")); + + query(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | ENRICH _coordinator:languages ON language_code + """, analyzer); + + err = error(""" + FROM test + | EVAL language_code = languages + | ENRICH _coordinator:languages ON language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); + + err = error(""" + FROM test + | EVAL language_code = languages + | ENRICH _coordinator:languages ON language_code + | EVAL x = 1 + | MV_EXPAND language_name + | DISSECT language_name "%{foo}" + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("7:3: ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); } private void checkFullTextFunctionsInStats(String functionInvocation) { @@ -2292,6 +2514,20 @@ private void checkFullTextFunctionsInStats(String functionInvocation) { ); } + public void testVectorSimilarityFunctionsNullArgs() throws Exception { + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first"); + checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second"); + } + } + + private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String 
argOrdinal) throws Exception { + assertThat( + error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer), + containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]") + ); + } + private void query(String query) { query(query, defaultAnalyzer); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java index dba6facde2b25..96b651ff76efa 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java @@ -72,7 +72,16 @@ static String functionSignature(FunctionDefinition definition) throws IOExceptio // BUCKET requires optional args to be optional together, so we need custom code to do that var nextArg = args.get(++i); assert nextArg.optional(); - Sequence seq = new Sequence(new Literal(argName), new Syntax(","), new Literal(nextArg.name)); + var nexterArg = args.get(++i); + assert nexterArg.optional(); + // TODO: Should it be possible to be able to specify "emitEmptyBuckets" but not "from" and "to"? 
+ Sequence seq = new Sequence( + new Literal(argName), + new Syntax(","), + new Literal(nextArg.name), + new Syntax(","), + new Literal(nexterArg.name) + ); argExpressions.add(new Repetition(seq, 0, 1)); } else if (i < args.size() - 1 && args.get(i + 1).optional() == false) { // Special case with leading optional args diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateTests.java index 8943b6549e502..7da479ea28dba 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateTests.java @@ -95,8 +95,22 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier if (dataRows.size() < 2) { matcher = Matchers.nullValue(); } else { - // TODO: check the value? 
- matcher = Matchers.allOf(Matchers.greaterThanOrEqualTo(0.0), Matchers.lessThan(Double.POSITIVE_INFINITY)); + var maxrate = switch (fieldTypedData.type().widenSmallNumeric()) { + case INTEGER, COUNTER_INTEGER -> dataRows.stream().mapToInt(v -> (Integer) v).max().orElse(0); + case LONG, COUNTER_LONG -> dataRows.stream().mapToLong(v -> (Long) v).max().orElse(0L); + case DOUBLE, COUNTER_DOUBLE -> dataRows.stream().mapToDouble(v -> (Double) v).max().orElse(0.0); + default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); + }; + var minrate = switch (fieldTypedData.type().widenSmallNumeric()) { + case INTEGER, COUNTER_INTEGER -> dataRows.stream().mapToInt(v -> (Integer) v).min().orElse(0); + case LONG, COUNTER_LONG -> dataRows.stream().mapToLong(v -> (Long) v).min().orElse(0L); + case DOUBLE, COUNTER_DOUBLE -> dataRows.stream().mapToDouble(v -> (Double) v).min().orElse(0.0); + default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); + }; + // If the minrate is greater than 0, we need to adjust the maxrate accordingly + minrate = Math.min(minrate, 0); + maxrate = Math.max(maxrate, maxrate - minrate); + matcher = Matchers.allOf(Matchers.greaterThanOrEqualTo(minrate), Matchers.lessThanOrEqualTo(maxrate)); } return new TestCaseSupplier.TestCase( List.of(fieldTypedData, timestampsField), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 4a5708b398b18..595eb58118a09 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -51,7 +51,7 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", 
EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); } private static List testCaseSuppliers() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java new file mode 100644 index 0000000000000..74b2dffe2e4c4 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("score") +public class ScoreTests extends AbstractMatchFullTextFunctionTests { + + @BeforeClass + public static void init() { + assumeTrue("can run this only when score() function is enabled", EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled()); + } + + public 
ScoreTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + suppliers.add( + new TestCaseSupplier( + List.of(BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(randomBoolean(), BOOLEAN, "query")), + equalTo("ScoreEvaluator" + ScoreTests.class.getSimpleName()), + DOUBLE, + equalTo(true) + ) + ) + ); + + return parameterSuppliersFromTypedData(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Score(source, args.getFirst()); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java index 5fe270a4cce42..c30c842bad3b0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java @@ -25,7 +25,8 @@ public static Bucket createRandomBucket() { Expression buckets = randomChild(); Expression from = randomChild(); Expression to = randomChild(); - return new Bucket(source, field, buckets, from, to); + Expression emitEmptyBuckets = randomChild(); + return new Bucket(source, field, buckets, from, to, emitEmptyBuckets); } @Override @@ -35,12 +36,14 @@ protected Bucket mutateInstance(Bucket instance) throws IOException { Expression buckets = instance.buckets(); Expression from = instance.from(); Expression to = instance.to(); - switch (between(0, 3)) { + Expression emitEmptyBuckets = instance.emitEmptyBuckets(); + switch (between(0, 4)) { case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); case 1 -> buckets 
= randomValueOtherThan(buckets, AbstractExpressionSerializationTests::randomChild); case 2 -> from = randomValueOtherThan(from, AbstractExpressionSerializationTests::randomChild); case 3 -> to = randomValueOtherThan(to, AbstractExpressionSerializationTests::randomChild); + case 4 -> emitEmptyBuckets = randomValueOtherThan(emitEmptyBuckets, AbstractExpressionSerializationTests::randomChild); } - return new Bucket(source, field, buckets, from, to); + return new Bucket(source, field, buckets, from, to, emitEmptyBuckets); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java index f01b06c23e8a8..cf4f7e0f8f015 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java @@ -319,10 +319,14 @@ private static Matcher resultsMatcher(List t protected Expression build(Source source, List args) { Expression from = null; Expression to = null; + Expression emitEmptyBuckets = null; if (args.size() > 2) { from = args.get(2); to = args.get(3); } - return new Bucket(source, args.get(0), args.get(1), from, to); + if (args.size() > 4) { + emitEmptyBuckets = args.get(4); + } + return new Bucket(source, args.get(0), args.get(1), from, to, emitEmptyBuckets); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java index f674f9b2c3d72..97d5b8e3ece96 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeErrorTests.java @@ -27,7 +27,7 @@ protected List cases() { @Override protected Expression build(Source source, List args) { - return new Categorize(source, args.get(0)); + return new Categorize(source, args.get(0), args.size() > 1 ? args.get(1) : null); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java index f69bb7eb3e7bb..296d624ee1777 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java @@ -61,7 +61,7 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new Categorize(source, args.get(0)); + return new Categorize(source, args.get(0), args.size() > 1 ? args.get(1) : null); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/EmptyBucketTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/EmptyBucketTests.java new file mode 100644 index 0000000000000..1ffa01319f4d0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/EmptyBucketTests.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.Rounding; +import org.elasticsearch.common.time.DateUtils; +import org.elasticsearch.index.mapper.DateFieldMapper; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; + +public class EmptyBucketTests extends AbstractScalarFunctionTestCase { + + public EmptyBucketTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + dateCase(suppliers, "fixed date"); + return parameterSuppliersFromTypedData(suppliers); + } + + private static void dateCase(List suppliers, String name) { + DataType fromType = DataType.DATETIME; + DataType toType = DataType.DATETIME; + long date = DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parseMillis("2023-02-17T09:00:00.00Z"); + suppliers.add(new TestCaseSupplier(name, List.of(DataType.DATETIME, DataType.INTEGER, fromType, toType), () -> { + List args = new ArrayList<>(); + args.add(new TestCaseSupplier.TypedData(date, DataType.DATETIME, "field")); + // TODO more "from" and "to" and "buckets" + args.add(new TestCaseSupplier.TypedData(50, DataType.INTEGER, "buckets").forceLiteral()); + 
args.add(dateBound("from", fromType, "2023-02-01T00:00:00.00Z")); + args.add(dateBound("to", toType, "2023-03-01T09:00:00.00Z")); + return new TestCaseSupplier.TestCase( + args, + "DateTruncDatetimeEvaluator[fieldVal=Attribute[channel=0], " + "rounding=Rounding[DAY_OF_MONTH in Z][fixed to midnight]]", + DataType.DATETIME, + resultsMatcher(args) + ); + })); + } + + private static TestCaseSupplier.TypedData dateBound(String name, DataType type, String date) { + Object value; + if (type == DataType.DATETIME) { + value = DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parseMillis(date); + } else { + value = new BytesRef(date); + } + return new TestCaseSupplier.TypedData(value, type, name).forceLiteral(); + } + + private static Matcher resultsMatcher(List typedData) { + if (typedData.get(0).type() == DataType.DATETIME) { + long millis = ((Number) typedData.get(0).data()).longValue(); + long expected = Rounding.builder(Rounding.DateTimeUnit.DAY_OF_MONTH).build().prepareForUnknown().round(millis); + LogManager.getLogger(getTestClass()).info("Expected: " + Instant.ofEpochMilli(expected)); + LogManager.getLogger(getTestClass()).info("Input: " + Instant.ofEpochMilli(millis)); + return equalTo(expected); + } + if (typedData.get(0).type() == DataType.DATE_NANOS) { + long nanos = ((Number) typedData.get(0).data()).longValue(); + long expected = DateUtils.toNanoSeconds( + Rounding.builder(Rounding.DateTimeUnit.DAY_OF_MONTH).build().prepareForUnknown().round(DateUtils.toMilliSeconds(nanos)) + ); + LogManager.getLogger(getTestClass()).info("Expected: " + DateUtils.toInstant(expected)); + LogManager.getLogger(getTestClass()).info("Input: " + DateUtils.toInstant(nanos)); + return equalTo(expected); + } + return equalTo(((Number) typedData.get(0).data()).doubleValue()); + } + + @Override + protected Expression build(Source source, List args) { + Expression from = null; + Expression to = null; + Expression emitEmptyBuckets = null; + if (args.size() > 2) { + from = args.get(2); + to = 
args.get(3); + } + if (args.size() > 4) { + emitEmptyBuckets = args.get(4); + } + return new Bucket(source, args.get(0), args.get(1), from, to, emitEmptyBuckets); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithStaticTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithStaticTests.java new file mode 100644 index 0000000000000..ddde306deed7a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithStaticTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.capabilities.TranslationAware; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.querydsl.query.WildcardQuery; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; + +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; + +public class EndsWithStaticTests extends ESTestCase { + public void testLuceneQuery_AllLiterals_NonTranslatable() { + EndsWith function = new EndsWith(Source.EMPTY, Literal.keyword(Source.EMPTY, "test"), Literal.keyword(Source.EMPTY, 
"test")); + + ESTestCase.assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); + } + + public void testLuceneQuery_NonFoldableSuffix_NonTranslatable() { + EndsWith function = new EndsWith( + Source.EMPTY, + new FieldAttribute(Source.EMPTY, "field", new EsField("field", DataType.KEYWORD, Map.of(), true)), + new FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, Map.of(), true)) + ); + + assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); + } + + public void testLuceneQuery_NonFoldableSuffix_Translatable() { + EndsWith function = new EndsWith( + Source.EMPTY, + new FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, Map.of(), true)), + Literal.keyword(Source.EMPTY, "a*b?c\\") + ); + + assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.YES)); + + Query query = function.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); + + assertThat(query, equalTo(new WildcardQuery(Source.EMPTY, "field", "*a\\*b\\?c\\\\", false, false))); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithTests.java index 0efd8daaacaa0..c41b1e14257ee 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithTests.java @@ -11,23 +11,15 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.core.expression.Expression; -import 
org.elasticsearch.xpack.esql.core.expression.FieldAttribute; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.querydsl.query.WildcardQuery; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; -import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.hamcrest.Matcher; import java.util.LinkedList; import java.util.List; -import java.util.Map; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; @@ -106,34 +98,4 @@ private static TestCaseSupplier.TestCase testCase( protected Expression build(Source source, List args) { return new EndsWith(source, args.get(0), args.get(1)); } - - public void testLuceneQuery_AllLiterals_NonTranslatable() { - var function = new EndsWith(Source.EMPTY, Literal.keyword(Source.EMPTY, "test"), Literal.keyword(Source.EMPTY, "test")); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); - } - - public void testLuceneQuery_NonFoldableSuffix_NonTranslatable() { - var function = new EndsWith( - Source.EMPTY, - new FieldAttribute(Source.EMPTY, "field", new EsField("field", DataType.KEYWORD, Map.of(), true)), - new FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, Map.of(), true)) - ); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); - } - - public void testLuceneQuery_NonFoldableSuffix_Translatable() { - var function = new EndsWith( - Source.EMPTY, - new FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, 
Map.of(), true)), - Literal.keyword(Source.EMPTY, "a*b?c\\") - ); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.YES)); - - var query = function.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); - - assertThat(query, equalTo(new WildcardQuery(Source.EMPTY, "field", "*a\\*b\\?c\\\\", false, false))); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListErrorTests.java new file mode 100644 index 0000000000000..0e2fa024bda58 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListErrorTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; + +import java.util.List; +import java.util.Set; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class RLikeListErrorTests extends ErrorsForCasesWithoutExamplesTestCase { + @Override + protected List cases() { + return paramsToSuppliers(RLikeListTests.parameters()); + } + + @Override + protected Stream> testCandidates(List cases, Set> valid) { + /* + * We can't support certain signatures, and it's safe not to test them because + * you can't even build them.... The building comes directly from the parser + * and can only make certain types. 
+ */ + return super.testCandidates(cases, valid).filter(sig -> sig.get(1) == DataType.KEYWORD) + .filter(sig -> sig.size() > 2 && sig.get(2) == DataType.BOOLEAN); + } + + @Override + protected Expression build(Source source, List args) { + return RLikeTests.buildRLike(logger, source, args); + } + + @Override + protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) { + return equalTo(typeErrorMessage(false, validPerPosition, signature, (v, p) -> "string")); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListSerializationTests.java new file mode 100644 index 0000000000000..ff2dd31e2c832 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListSerializationTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePatternList; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class RLikeListSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected RLikeList createTestInstance() { + Source source = randomSource(); + Expression child = randomChild(); + return new RLikeList(source, child, generateRandomPatternList()); + } + + @Override + protected RLikeList mutateInstance(RLikeList instance) throws IOException { + Source source = instance.source(); + Expression child = instance.field(); + List patterns = new ArrayList<>(instance.pattern().patternList()); + int childToModify = randomIntBetween(0, patterns.size() - 1); + RLikePattern pattern = patterns.get(childToModify); + if (randomBoolean()) { + child = randomValueOtherThan(child, AbstractExpressionSerializationTests::randomChild); + } else { + pattern = randomValueOtherThan(pattern, () -> new RLikePattern(randomAlphaOfLength(4))); + } + patterns.set(childToModify, pattern); + return new RLikeList(source, child, new RLikePatternList(patterns)); + } + + private RLikePatternList generateRandomPatternList() { + int numChildren = randomIntBetween(1, 10); // Ensure at least one child + List patterns = new ArrayList<>(numChildren); + for (int i = 0; i < numChildren; i++) { + RLikePattern pattern = new RLikePattern(randomAlphaOfLength(4)); + patterns.add(pattern); + } + return new RLikePatternList(patterns); + } +} diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListTests.java new file mode 100644 index 0000000000000..d18c81502117d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeListTests.java @@ -0,0 +1,206 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePatternList; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; +import org.junit.AfterClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.function.Function; +import java.util.function.Predicate; 
+import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.util.TestUtils.randomCasing; +import static org.elasticsearch.xpack.esql.expression.function.DocsV3Support.renderNegatedOperator; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; + +public class RLikeListTests extends AbstractScalarFunctionTestCase { + public RLikeListTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + final Function escapeString = str -> { + for (String syntax : new String[] { "\\", ".", "?", "+", "*", "|", "{", "}", "[", "]", "(", ")", "\"", "<", ">", "#", "&" }) { + str = str.replace(syntax, "\\" + syntax); + } + return str; + }; + return parameters(escapeString, () -> randomAlphaOfLength(1) + "?"); + } + + static Iterable parameters(Function escapeString, Supplier optionalPattern) { + List cases = new ArrayList<>(); + cases.add( + new TestCaseSupplier( + "null", + List.of(DataType.NULL, DataType.KEYWORD, DataType.BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(null, DataType.NULL, "e"), + new TestCaseSupplier.TypedData(new BytesRef(randomAlphaOfLength(10)), DataType.KEYWORD, "pattern").forceLiteral(), + new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "caseInsensitive").forceLiteral() + ), + "LiteralsEvaluator[lit=null]", + DataType.BOOLEAN, + nullValue() + ) + ) + ); + casesForString(cases, "empty string", () -> "", false, escapeString, optionalPattern); + casesForString(cases, "single ascii character", () -> randomAlphaOfLength(1), true, escapeString, optionalPattern); + casesForString(cases, "ascii string", () -> randomAlphaOfLengthBetween(2, 100), true, escapeString, optionalPattern); + casesForString(cases, "3 bytes, 1 code point", () -> "☕", false, escapeString, optionalPattern); + 
casesForString(cases, "6 bytes, 2 code points", () -> "❗️", false, escapeString, optionalPattern); + casesForString(cases, "100 random code points", () -> randomUnicodeOfCodepointLength(100), true, escapeString, optionalPattern); + return parameterSuppliersFromTypedData(cases); + } + + record TextAndPattern(String text, String pattern) {} + + private static void casesForString( + List cases, + String title, + Supplier textSupplier, + boolean canGenerateDifferent, + Function escapeString, + Supplier optionalPattern + ) { + cases(cases, title + " matches self", () -> { + String text = textSupplier.get(); + return new TextAndPattern(text, escapeString.apply(text)); + }, true); + cases(cases, title + " matches self case insensitive", () -> { + // RegExp doesn't support case-insensitive matching for Unicodes whose length changes when the case changes. + // Example: a case-insensitive ES regexp query for the pattern `weiß` won't match the value `WEISS` (but will match `WEIß`). + // Or `ʼn` (U+0149) vs. `ʼN` (U+02BC U+004E). 
+ String text, caseChanged; + for (text = textSupplier.get(), caseChanged = randomCasing(text); text.length() != caseChanged.length();) { + text = textSupplier.get(); + caseChanged = randomCasing(text); + } + return new TextAndPattern(caseChanged, escapeString.apply(text)); + }, true, true); + cases(cases, title + " doesn't match self with trailing", () -> { + String text = textSupplier.get(); + return new TextAndPattern(text, escapeString.apply(text) + randomAlphaOfLength(1)); + }, false); + cases(cases, title + " doesn't match self with trailing case insensitive", () -> { + String text = textSupplier.get(); + return new TextAndPattern(randomCasing(text), escapeString.apply(text) + randomAlphaOfLength(1)); + }, true, false); + cases(cases, title + " matches self with optional trailing", () -> { + String text = randomAlphaOfLength(1); + return new TextAndPattern(text, escapeString.apply(text) + optionalPattern.get()); + }, true); + cases(cases, title + " matches self with optional trailing case insensitive", () -> { + String text = randomAlphaOfLength(1); + return new TextAndPattern(randomCasing(text), escapeString.apply(text) + optionalPattern.get()); + }, true, true); + if (canGenerateDifferent) { + cases(cases, title + " doesn't match different", () -> { + String text = textSupplier.get(); + String different = escapeString.apply(randomValueOtherThan(text, textSupplier)); + return new TextAndPattern(text, different); + }, false); + cases(cases, title + " doesn't match different case insensitive", () -> { + String text = textSupplier.get(); + Predicate predicate = t -> t.toLowerCase(Locale.ROOT).equals(text.toLowerCase(Locale.ROOT)); + String different = escapeString.apply(randomValueOtherThanMany(predicate, textSupplier)); + return new TextAndPattern(text, different); + }, true, false); + } + } + + private static void cases(List cases, String title, Supplier textAndPattern, boolean expected) { + cases(cases, title, textAndPattern, false, expected); + } + + 
private static void cases( + List cases, + String title, + Supplier textAndPattern, + boolean caseInsensitive, + boolean expected + ) { + for (DataType type : DataType.stringTypes()) { + cases.add(new TestCaseSupplier(title + " with " + type.esType(), List.of(type, DataType.KEYWORD, DataType.BOOLEAN), () -> { + TextAndPattern v = textAndPattern.get(); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(new BytesRef(v.text), type, "e"), + new TestCaseSupplier.TypedData(new BytesRef(v.pattern), DataType.KEYWORD, "pattern").forceLiteral(), + new TestCaseSupplier.TypedData(caseInsensitive, DataType.BOOLEAN, "caseInsensitive").forceLiteral() + ), + startsWith("AutomataMatchEvaluator[input=Attribute[channel=0], pattern=digraph Automaton {\n"), + DataType.BOOLEAN, + equalTo(expected) + ); + })); + if (caseInsensitive == false) { + cases.add(new TestCaseSupplier(title + " with " + type.esType(), List.of(type, DataType.KEYWORD), () -> { + TextAndPattern v = textAndPattern.get(); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(new BytesRef(v.text), type, "e"), + new TestCaseSupplier.TypedData(new BytesRef(v.pattern), DataType.KEYWORD, "pattern").forceLiteral() + ), + startsWith("AutomataMatchEvaluator[input=Attribute[channel=0], pattern=digraph Automaton {\n"), + DataType.BOOLEAN, + equalTo(expected) + ); + })); + } + } + } + + @Override + protected Expression build(Source source, List args) { + return buildRLikeList(logger, source, args); + } + + static Expression buildRLikeList(Logger logger, Source source, List args) { + Expression expression = args.get(0); + Literal pattern = (Literal) args.get(1); + Literal caseInsensitive = args.size() > 2 ? (Literal) args.get(2) : null; + String patternString = ((BytesRef) pattern.fold(FoldContext.small())).utf8ToString(); + boolean caseInsensitiveBool = caseInsensitive != null ? 
(boolean) caseInsensitive.fold(FoldContext.small()) : false; + logger.info("pattern={} caseInsensitive={}", patternString, caseInsensitiveBool); + + return caseInsensitiveBool + ? new RLikeList(source, expression, new RLikePatternList(List.of(new RLikePattern(patternString))), true) + : (randomBoolean() + ? new RLikeList(source, expression, new RLikePatternList(List.of(new RLikePattern(patternString)))) + : new RLikeList(source, expression, new RLikePatternList(List.of(new RLikePattern(patternString))), false)); + } + + @AfterClass + public static void renderNotRLike() throws Exception { + renderNegatedOperator(constructorWithFunctionInfo(RLike.class), "RLIKE", d -> d, getTestClass()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithStaticTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithStaticTests.java new file mode 100644 index 0000000000000..105ce6a9e4142 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithStaticTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.capabilities.TranslationAware; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.querydsl.query.WildcardQuery; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; + +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; + +public class StartsWithStaticTests extends ESTestCase { + + public void testLuceneQuery_AllLiterals_NonTranslatable() { + var function = new StartsWith(Source.EMPTY, Literal.keyword(Source.EMPTY, "test"), Literal.keyword(Source.EMPTY, "test")); + + assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); + } + + public void testLuceneQuery_NonFoldablePrefix_NonTranslatable() { + var function = new StartsWith( + Source.EMPTY, + new FieldAttribute(Source.EMPTY, "field", new EsField("field", DataType.KEYWORD, Map.of(), true)), + new FieldAttribute(Source.EMPTY, "field", new EsField("prefix", DataType.KEYWORD, Map.of(), true)) + ); + + assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); + } + + public void testLuceneQuery_NonFoldablePrefix_Translatable() { + var function = new StartsWith( + Source.EMPTY, + new FieldAttribute(Source.EMPTY, "field", new EsField("prefix", DataType.KEYWORD, Map.of(), true)), + Literal.keyword(Source.EMPTY, "a*b?c\\") + ); + + assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.YES)); + + var 
query = function.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); + + assertThat(query, equalTo(new WildcardQuery(Source.EMPTY, "field", "a\\*b\\?c\\\\*", false, false))); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithTests.java index 67fb9f0c41f26..e1d02472fca43 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithTests.java @@ -11,22 +11,14 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.querydsl.query.WildcardQuery; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; -import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; @@ -67,33 +59,4 @@ protected Expression build(Source source, List args) { return new StartsWith(source, args.get(0), args.get(1)); } - public 
void testLuceneQuery_AllLiterals_NonTranslatable() { - var function = new StartsWith(Source.EMPTY, Literal.keyword(Source.EMPTY, "test"), Literal.keyword(Source.EMPTY, "test")); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); - } - - public void testLuceneQuery_NonFoldablePrefix_NonTranslatable() { - var function = new StartsWith( - Source.EMPTY, - new FieldAttribute(Source.EMPTY, "field", new EsField("field", DataType.KEYWORD, Map.of(), true)), - new FieldAttribute(Source.EMPTY, "field", new EsField("prefix", DataType.KEYWORD, Map.of(), true)) - ); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.NO)); - } - - public void testLuceneQuery_NonFoldablePrefix_Translatable() { - var function = new StartsWith( - Source.EMPTY, - new FieldAttribute(Source.EMPTY, "field", new EsField("prefix", DataType.KEYWORD, Map.of(), true)), - Literal.keyword(Source.EMPTY, "a*b?c\\") - ); - - assertThat(function.translatable(LucenePushdownPredicates.DEFAULT), equalTo(TranslationAware.Translatable.YES)); - - var query = function.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); - - assertThat(query, equalTo(new WildcardQuery(Source.EMPTY, "field", "a\\*b\\?c\\\\*", false, false))); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java new file mode 100644 index 0000000000000..329eba63046f4 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.hamcrest.Matchers.equalTo; + +public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase { + + protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @Before + public void checkCapability() { + assumeTrue("Similarity function is not enabled", capability().isEnabled()); + } + + /** + * Get the capability of the vector similarity function to check + */ + protected abstract EsqlCapabilities.Cap capability(); + + protected static Iterable similarityParameters( + String className, + VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction + ) { + + final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]"; + + List suppliers = new ArrayList<>(); + + // Basic test with two dense vectors + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List left = randomDenseVector(dimensions); + List right = randomDenseVector(dimensions); + float[] leftArray = listToFloatArray(left); + float[] rightArray = 
listToFloatArray(right); + double expected = similarityFunction.calculateSimilarity(leftArray, rightArray); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), + new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") + ), + evaluatorName, + DOUBLE, + equalTo(expected) // Random vectors should have cosine similarity close to 0 + ); + })); + + return parameterSuppliersFromTypedData(suppliers); + } + + private static float[] listToFloatArray(List floatList) { + float[] floatArray = new float[floatList.size()]; + for (int i = 0; i < floatList.size(); i++) { + floatArray[i] = floatList.get(i); + } + return floatArray; + } + + protected double calculateSimilarity(List left, List right) { + return 0; + } + + /** + * @return A random dense vector for testing + * @param dimensions + */ + private static List randomDenseVector(int dimensions) { + List vector = new ArrayList<>(); + for (int i = 0; i < dimensions; i++) { + vector.add(randomFloat()); + } + return vector; + } + + @Override + protected Matcher allNullsMatcher() { + // A null value on the left or right vector. Similarity is 0 + return equalTo(0.0); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java new file mode 100644 index 0000000000000..32ba95ee0af27 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.List; +import java.util.function.Supplier; + +@FunctionName("v_cosine") +public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase { + + public CosineSimilarityTests(@Name("TestCase") Supplier testCaseSupplier) { + super(testCaseSupplier); + } + + @ParametersFactory + public static Iterable parameters() { + return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION); + } + + protected EsqlCapabilities.Cap capability() { + return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION; + } + + @Override + protected Expression build(Source source, List args) { + return new CosineSimilarity(source, args.get(0), args.get(1)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InStaticTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InStaticTests.java new file mode 100644 index 0000000000000..b2fa9f4221769 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InStaticTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.querydsl.query.TermsQuery; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; + +import java.util.Arrays; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class InStaticTests extends ESTestCase { + private static final Literal ONE = L(1); + private static final Literal TWO = L(2); + private static final Literal THREE = L(3); + + public void testInWithContainedValue() { + In in = new In(EMPTY, TWO, Arrays.asList(ONE, TWO, THREE)); + assertTrue((Boolean) in.fold(FoldContext.small())); + } + + public void testInWithNotContainedValue() { + In in = new In(EMPTY, THREE, Arrays.asList(ONE, TWO)); + assertFalse((Boolean) in.fold(FoldContext.small())); + } + + public void testHandleNullOnLeftValue() { + In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE)); + assertNull(in.fold(FoldContext.small())); + in = new In(EMPTY, NULL, Arrays.asList(ONE, NULL, THREE)); + assertNull(in.fold(FoldContext.small())); + + } + + public void testHandleNullsOnRightValue() { + In in = new In(EMPTY, THREE, Arrays.asList(ONE, NULL, THREE)); + assertTrue((Boolean) in.fold(FoldContext.small())); + in = new In(EMPTY, ONE, Arrays.asList(TWO, NULL, 
THREE)); + assertNull(in.fold(FoldContext.small())); + } + + public void testConvertedNull() { + In in = new In( + EMPTY, + new FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, Map.of(), true)), + Arrays.asList(ONE, new Literal(Source.EMPTY, null, randomFrom(DataType.types())), THREE) + ); + var query = in.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); + assertEquals(new TermsQuery(EMPTY, "field", Set.of(1, 3)), query); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java index 449389accc37b..f56dcb220b6ca 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java @@ -13,32 +13,19 @@ import org.elasticsearch.geo.GeometryTestUtils; import org.elasticsearch.geo.ShapeTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.querydsl.query.TermsQuery; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; -import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.junit.AfterClass; import 
java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Locale; -import java.util.Map; -import java.util.Set; import java.util.function.Supplier; import java.util.stream.IntStream; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; -import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; -import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; @@ -54,49 +41,6 @@ public InTests(@Name("TestCase") Supplier testCaseSup this.testCase = testCaseSupplier.get(); } - private static final Literal ONE = L(1); - private static final Literal TWO = L(2); - private static final Literal THREE = L(3); - - public void testInWithContainedValue() { - In in = new In(EMPTY, TWO, Arrays.asList(ONE, TWO, THREE)); - assertTrue((Boolean) in.fold(FoldContext.small())); - } - - public void testInWithNotContainedValue() { - In in = new In(EMPTY, THREE, Arrays.asList(ONE, TWO)); - assertFalse((Boolean) in.fold(FoldContext.small())); - } - - public void testHandleNullOnLeftValue() { - In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE)); - assertNull(in.fold(FoldContext.small())); - in = new In(EMPTY, NULL, Arrays.asList(ONE, NULL, THREE)); - assertNull(in.fold(FoldContext.small())); - - } - - public void testHandleNullsOnRightValue() { - In in = new In(EMPTY, THREE, Arrays.asList(ONE, NULL, THREE)); - assertTrue((Boolean) in.fold(FoldContext.small())); - in = new In(EMPTY, ONE, Arrays.asList(TWO, NULL, THREE)); - assertNull(in.fold(FoldContext.small())); - } - - private static Literal L(Object value) { - return of(EMPTY, value); - } - - public void testConvertedNull() { - In in = new In( - EMPTY, - new 
FieldAttribute(Source.EMPTY, "field", new EsField("suffix", DataType.KEYWORD, Map.of(), true)), - Arrays.asList(ONE, new Literal(Source.EMPTY, null, randomFrom(DataType.types())), THREE) - ); - var query = in.asQuery(LucenePushdownPredicates.DEFAULT, TranslatorHandler.TRANSLATOR_HANDLER); - assertEquals(new TermsQuery(EMPTY, "field", Set.of(1, 3)), query); - } - @ParametersFactory public static Iterable parameters() { List suppliers = new ArrayList<>(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index 7dbd625c18455..0cf86378a0f70 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.core.expression.Alias; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.core.type.InvalidMappedField; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; @@ -43,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.index.EsIndex; import 
org.elasticsearch.xpack.esql.index.IndexResolution; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferIsNotNull; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -59,10 +62,12 @@ import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.hamcrest.Matchers; import org.junit.BeforeClass; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -88,6 +93,9 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -780,7 +788,7 @@ public void testGroupingByMissingFields() { as(eval.child(), EsRelation.class); } - public void testPlanSanityCheck() throws Exception { + public void testVerifierOnMissingReferences() throws Exception { var plan = localPlan(""" from test | stats a = min(salary) by emp_no @@ -806,6 +814,103 @@ public void testPlanSanityCheck() throws Exception { assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references [salary")); } + private 
LocalLogicalPlanOptimizer getCustomRulesLocalLogicalPlanOptimizer(List> batches) { + LocalLogicalOptimizerContext context = new LocalLogicalOptimizerContext( + EsqlTestUtils.TEST_CFG, + FoldContext.small(), + TEST_SEARCH_STATS + ); + LocalLogicalPlanOptimizer customOptimizer = new LocalLogicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return customOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + var plan = localPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(UP) { + + @Override + protected LogicalPlan rule(Aggregate plan, LocalLogicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new Eval(plan.source(), plan, List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral))); + } + return plan; + } + + } + ); + LocalLogicalPlanOptimizer customRulesLocalLogicalPlanOptimizer = getCustomRulesLocalLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalLogicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), 
containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() { + var plan = localPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(DOWN) { + @Override + protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Limit limit = as(plan, Limit.class); + Limit newLimit = new Limit(plan.source(), limit.limit(), limit.child()) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + + } + ); + LocalLogicalPlanOptimizer customRulesLocalLogicalPlanOptimizer = getCustomRulesLocalLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalLogicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + private IsNotNull isNotNull(Expression field) { return new IsNotNull(EMPTY, field); } @@ -818,22 +923,16 @@ private LocalRelation asEmptyRelation(Object o) { private LogicalPlan plan(String query, Analyzer analyzer) { var analyzed = analyzer.analyze(parser.createStatement(query, EsqlTestUtils.TEST_CFG)); - // 
System.out.println(analyzed); - var optimized = logicalOptimizer.optimize(analyzed); - // System.out.println(optimized); - return optimized; + return logicalOptimizer.optimize(analyzed); } - private LogicalPlan plan(String query) { + protected LogicalPlan plan(String query) { return plan(query, analyzer); } - private LogicalPlan localPlan(LogicalPlan plan, SearchStats searchStats) { + protected LogicalPlan localPlan(LogicalPlan plan, SearchStats searchStats) { var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small(), searchStats); - // System.out.println(plan); - var localPlan = new LocalLogicalPlanOptimizer(localContext).localOptimize(plan); - // System.out.println(localPlan); - return localPlan; + return new LocalLogicalPlanOptimizer(localContext).localOptimize(plan); } private LogicalPlan localPlan(String query) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index a7035f555f593..cd6371e4d4d5e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -34,12 +34,14 @@ import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.EsqlTestUtils.TestSearchStats; +import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.Verifier; import org.elasticsearch.xpack.esql.core.expression.Alias; +import 
org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; @@ -63,6 +65,8 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -98,6 +102,7 @@ import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.rule.Rule; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchContextStats; import org.elasticsearch.xpack.esql.stats.SearchStats; @@ -108,6 +113,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Locale; @@ -1371,7 +1377,7 @@ public void testMultiMatchOptionsPushDown() { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); String query = """ from test @@ -1836,6 +1842,308 @@ public void 
testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) { aggExec.forEachDown(EsQueryExec.class, esQueryExec -> { assertNull(esQueryExec.query()); }); } + public void testKnnPrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testKnnPrefiltersWithMultipleFilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) + | where integer > 10 + | where keyword == "test" + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + var integerFilter = wrapWithSingleQuery(query, 
unscore(rangeQuery("integer").gt(10)), "integer", new Source(3, 8, "integer > 10")); + var keywordFilter = wrapWithSingleQuery( + query, + unscore(termQuery("keyword", "test")), + "keyword", + new Source(4, 8, "keyword == \"test\"") + ); + QueryBuilder expectedFilterQueryBuilder = boolQuery().must(integerFilter).must(keywordFilter); + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(integerFilter).must(keywordFilter); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testPushDownConjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The filter condition should be pushed down to both the KNN query and the main query + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void 
testPushDownNegatedConjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and NOT integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The filter condition should be pushed down to both the KNN query and the main query + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), + "integer", + new Source(2, 45, "NOT integer > 10") + ); + + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testNotPushDownDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The disjunction should not be pushed down to the KNN query + KnnVectorQueryBuilder 
knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + QueryBuilder rangeQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 44, "integer > 10") + ); + + var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testNotPushDownKnnWithNonPushablePrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) AND integer > 10) and ((keyword == "test") or length(text) > 10)) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var secondLimit = as(field.child(), LimitExec.class); + var filter = as(secondLimit.child(), FilterExec.class); + var and = as(filter.condition(), And.class); + var knn = as(and.left(), Knn.class); + assertEquals("(keyword == \"test\") or length(text) > 10", knn.filterExpressions().get(0).toString()); + assertEquals("integer > 10", knn.filterExpressions().get(1).toString()); + + var fieldExtract = as(filter.child(), FieldExtractExec.class); + var queryExec = as(fieldExtract.child(), EsQueryExec.class); + + // The query should only contain the pushable condition + QueryBuilder integerGtQuery = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 47, "integer > 10") + ); + + assertEquals(integerGtQuery.toString(), queryExec.query().toString()); + } + + public void testPushDownComplexNegationsToKnnPrefilter() { + assumeTrue("knn must be enabled", 
EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) + and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var queryExec = as(fieldExtract.child(), EsQueryExec.class); + + QueryBuilder notKeywordQuery = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), + "keyword", + new Source(3, 12, "keyword == \"test\"") + ); + QueryBuilder notKeywordFilter = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), + "keyword", + new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))") + ); + + QueryBuilder notIntegerGt10 = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), + "integer", + new Source(2, 46, "NOT integer > 10") + ); + + KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + + firstKnn.addFilterQuery(notKeywordFilter); + secondKnn.addFilterQuery(notIntegerGt10); + + // Build the main boolean query structure + BoolQueryBuilder expectedQuery = boolQuery().must(notKeywordQuery) // NOT (keyword == "test") + .must(unscore(boolQuery().mustNot(secondKnn))) + .must(boolQuery().should(firstKnn).should(notIntegerGt10)); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testMultipleKnnQueriesInPrefilters() { + assumeTrue("knn must be 
enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + // Integer range query (right side of first OR) + QueryBuilder integerRangeQuery = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 46, "integer > 10") + ); + + // Second KNN query (right side of second OR) + KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + + // Keyword term query (left side of second OR) + QueryBuilder keywordQuery = wrapWithSingleQuery( + query, + unscore(termQuery("keyword", "test")), + "keyword", + new Source(2, 66, "keyword == \"test\"") + ); + + // First OR (knn1 OR integer > 10) + var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery); + // Second OR (keyword == "test" OR knn2) + var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery); + firstKnnQuery.addFilterQuery(keywordQuery); + secondKnnQuery.addFilterQuery(integerRangeQuery); + + // Top-level AND combining both ORs + var expectedQuery = boolQuery().must(firstOr).must(secondOr); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + public void testParallelizeTimeSeriesPlan() { assumeTrue("requires snapshot builds", Build.current().isSnapshot()); var query = "TS k8s | STATS 
max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)"; @@ -2084,20 +2392,119 @@ public void testVerifierOnMissingReferences() throws Exception { // We want to verify that the localOptimize detects the missing attribute. // However, it also throws an error in one of the rules before we get to the verifier. // So we use an implementation of LocalPhysicalPlanOptimizer that does not have any rules. + LocalPhysicalPlanOptimizer optimizerWithNoRules = getCustomRulesLocalPhysicalPlanOptimizer(List.of()); + Exception e = expectThrows(IllegalStateException.class, () -> optimizerWithNoRules.localOptimize(topNExec)); + assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references [missing attr")); + } + + private LocalPhysicalPlanOptimizer getCustomRulesLocalPhysicalPlanOptimizer(List> batches) { LocalPhysicalOptimizerContext context = new LocalPhysicalOptimizerContext( new EsqlFlags(true), config, FoldContext.small(), SearchStats.EMPTY ); - LocalPhysicalPlanOptimizer optimizerWithNoopExecute = new LocalPhysicalPlanOptimizer(context) { + LocalPhysicalPlanOptimizer localPhysicalPlanOptimizer = new LocalPhysicalPlanOptimizer(context) { @Override protected List> batches() { - return List.of(); + return batches; } }; - Exception e = expectThrows(IllegalStateException.class, () -> optimizerWithNoopExecute.localOptimize(topNExec)); - assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references [missing attr")); + return localPhysicalPlanOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + + PhysicalPlan plan = plannerOptimizer.plan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder 
appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, LocalPhysicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new EvalExec( + plan.source(), + plan, + List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral)) + ); + } + return plan; + } + } + ); + LocalPhysicalPlanOptimizer customRulesLocalPhysicalPlanOptimizer = getCustomRulesLocalPhysicalPlanOptimizer( + List.of(customRuleBatch) + ); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalPhysicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() throws Exception { + + PhysicalPlan plan = plannerOptimizer.plan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, 
LocalPhysicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + LimitExec limit = as(plan, LimitExec.class); + LimitExec newLimit = new LimitExec( + plan.source(), + limit.child(), + new Literal(Source.EMPTY, 1000, INTEGER), + randomEstimatedRowSize() + ) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + } + ); + LocalPhysicalPlanOptimizer customRulesLocalPhysicalPlanOptimizer = getCustomRulesLocalPhysicalPlanOptimizer( + List.of(customRuleBatch) + ); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalPhysicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); } private boolean isMultiTypeEsField(Expression e) { @@ -2234,4 +2641,33 @@ public String esqlQuery() { return "qstr(\"" + fieldName() + ": " + queryString() + "\")"; } } + + private class KnnFunctionTestCase extends FullTextFunctionTestCase { + + final int k; + + KnnFunctionTestCase() { + super(Knn.class, "dense_vector", randomVector()); + k = randomIntBetween(1, 10); + } + + private static Object randomVector() { + int numDims = randomIntBetween(10, 20); + float[] vector = new float[numDims]; + for (int i = 0; i < numDims; i++) { + vector[i] = randomFloat(); + } + return vector; + } + + @Override + public QueryBuilder queryBuilder() { + return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null); + } + + @Override + public String esqlQuery() { + return "knn(" + fieldName() + ", " + Arrays.toString(((float[]) queryString())) + ", " + k + ")"; + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index b3892a0cb2cbe..a0dd67105097d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -74,6 +74,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; @@ -131,6 +132,7 @@ import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import java.time.Duration; import java.util.ArrayList; @@ -178,6 +180,8 @@ import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static 
org.hamcrest.Matchers.contains; @@ -2901,7 +2905,7 @@ public void testPruneRedundantSortClausesUsingAlias() { public void testInsist_fieldDoesNotExist_createsUnmappedFieldInRelation() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - LogicalPlan plan = optimizedPlan("FROM test | INSIST_🐔 foo"); + LogicalPlan plan = optimizedPlan("FROM test | INSIST_\uD83D\uDC14 foo"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); @@ -2912,7 +2916,7 @@ public void testInsist_fieldDoesNotExist_createsUnmappedFieldInRelation() { public void testInsist_multiIndexFieldPartiallyExistsAndIsKeyword_castsAreNotSupported() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - var plan = planMultiIndex("FROM multi_index | INSIST_🐔 partial_type_keyword"); + var plan = planMultiIndex("FROM multi_index | INSIST_\uD83D\uDC14 partial_type_keyword"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); var relation = as(limit.child(), EsRelation.class); @@ -2923,7 +2927,7 @@ public void testInsist_multiIndexFieldPartiallyExistsAndIsKeyword_castsAreNotSup public void testInsist_multipleInsistClauses_insistsAreFolded() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - var plan = planMultiIndex("FROM multi_index | INSIST_🐔 partial_type_keyword | INSIST_🐔 foo"); + var plan = planMultiIndex("FROM multi_index | INSIST_\uD83D\uDC14 partial_type_keyword | INSIST_\uD83D\uDC14 foo"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); var relation = as(limit.child(), EsRelation.class); @@ -5560,7 +5564,7 @@ public void testPushShadowingGeneratingPlanPastProject() { List initialGeneratedExprs = ((GeneratingPlan) initialPlan).generatedAttributes(); LogicalPlan optimizedPlan = testCase.rule.apply(initialPlan); - Failures inconsistencies = 
LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -5611,7 +5615,7 @@ public void testPushShadowingGeneratingPlanPastRenamingProject() { List initialGeneratedExprs = ((GeneratingPlan) initialPlan).generatedAttributes(); LogicalPlan optimizedPlan = testCase.rule.apply(initialPlan); - Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -5667,7 +5671,7 @@ public void testPushShadowingGeneratingPlanPastRenamingProjectWithResolution() { // This ensures that our generating plan doesn't use invalid references, resp. that any rename from the Project has // been propagated into the generating plan. 
- Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -7855,4 +7859,267 @@ public void testSampleNoPushDownChangePoint() { var topN = as(changePoint.child(), TopN.class); var source = as(topN.child(), EsRelation.class); } + + public void testPushDownConjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + var knn = as(and.left(), Knn.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(1)); + var prefilter = as(filterExpressions.get(0), GreaterThan.class); + assertThat(and.right(), equalTo(prefilter)); + var esRelation = as(filter.child(), EsRelation.class); + } + + public void testPushDownMultipleFiltersToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) + | where integer > 10 + | where keyword == "test" + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var firstAnd = as(filter.condition(), And.class); + var knn = as(firstAnd.left(), Knn.class); + var prefilterAnd = as(firstAnd.right(), And.class); + as(prefilterAnd.left(), GreaterThan.class); + as(prefilterAnd.right(), Equals.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(1)); + assertThat(prefilterAnd, 
equalTo(filterExpressions.get(0))); + } + + public void testNotPushDownDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var or = as(filter.condition(), Or.class); + var knn = as(or.left(), Knn.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(0)); + } + + public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + and + and + or + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + keyword == "test" + or + short < 5 + double > 5.0 + */ + // Both conjunctions are pushed down to knn prefilters, disjunctions are not + var query = """ + from test + | where + ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0)) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + var leftAnd = as(and.left(), And.class); + var rightOr = as(and.right(), Or.class); + var leftOr = as(leftAnd.left(), Or.class); + var knn = as(leftOr.left(), Knn.class); + var rightOrPrefilter = as(knn.filterExpressions().get(0), Or.class); + assertThat(rightOr, equalTo(rightOrPrefilter)); + var leftAndPrefilter = as(knn.filterExpressions().get(1), Equals.class); + assertThat(leftAnd.right(), equalTo(leftAndPrefilter)); + } + + public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + or + or + and + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + keyword == 
"test" + and + short < 5 + double > 5.0 + */ + // Just the conjunction is pushed down to knn prefilters, disjunctions are not + var query = """ + from test + | where + ((knn(dense_vector, [0, 1, 2], 10) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0)) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var or = as(filter.condition(), Or.class); + var leftOr = as(or.left(), Or.class); + var leftAnd = as(leftOr.left(), And.class); + var knn = as(leftAnd.left(), Knn.class); + var rightAndPrefilter = as(knn.filterExpressions().get(0), GreaterThan.class); + assertThat(leftAnd.right(), equalTo(rightAndPrefilter)); + } + + public void testMultipleKnnQueriesInPrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + and + or + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + or + keyword == "test" + knn(dense_vector, [4, 5, 6], 10) + */ + var query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + + // First OR (knn1 OR integer > 10) + var firstOr = as(and.left(), Or.class); + var firstKnn = as(firstOr.left(), Knn.class); + var integerGt = as(firstOr.right(), GreaterThan.class); + + // Second OR (keyword == "test" OR knn2) + var secondOr = as(and.right(), Or.class); + as(secondOr.left(), Equals.class); + var secondKnn = as(secondOr.right(), Knn.class); + + // First KNN should have the second OR as its filter + List firstKnnFilters = firstKnn.filterExpressions(); + assertThat(firstKnnFilters.size(), equalTo(1)); + assertTrue(firstKnnFilters.contains(secondOr.left())); + + // Second KNN should have the first OR as its filter + List 
secondKnnFilters = secondKnn.filterExpressions(); + assertThat(secondKnnFilters.size(), equalTo(1)); + assertTrue(secondKnnFilters.contains(firstOr.right())); + } + + private LogicalPlanOptimizer getCustomRulesLogicalPlanOptimizer(List> batches) { + LogicalOptimizerContext context = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small()); + LogicalPlanOptimizer customOptimizer = new LogicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return customOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + var plan = optimizedPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(UP) { + @Override + protected LogicalPlan rule(Aggregate plan, LogicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new Eval(plan.source(), plan, List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral))); + } + return plan; + } + + } + ); + LogicalPlanOptimizer customRulesLogicalPlanOptimizer = getCustomRulesLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLogicalPlanOptimizer.optimize(plan)); + 
assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() { + var plan = optimizedPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(DOWN) { + @Override + protected LogicalPlan rule(LogicalPlan plan, LogicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Limit limit = as(plan, Limit.class); + Limit newLimit = new Limit(plan.source(), limit.limit(), limit.child()) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + + } + ); + LogicalPlanOptimizer customRulesLogicalPlanOptimizer = getCustomRulesLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLogicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index 
6d7f818d922ca..6850e052eda9e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.compute.aggregation.AggregatorMode; @@ -64,11 +65,13 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent; @@ -128,12 +131,14 @@ import org.elasticsearch.xpack.esql.plan.physical.LimitExec; import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec; import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec; +import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import 
org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.planner.mapper.Mapper; import org.elasticsearch.xpack.esql.plugin.EsqlFlags; @@ -141,6 +146,7 @@ import org.elasticsearch.xpack.esql.querydsl.query.EqualsSyntheticSourceDelegate; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.querydsl.query.SpatialRelatesQuery; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.junit.Before; @@ -186,13 +192,16 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_SHAPE; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.util.TestUtils.stripThrough; import static org.elasticsearch.xpack.esql.parser.ExpressionBuilder.MAX_EXPRESSION_DEPTH; import static org.elasticsearch.xpack.esql.parser.LogicalPlanBuilder.MAX_QUERY_DEPTH; +import static org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests.randomEstimatedRowSize; import static org.elasticsearch.xpack.esql.planner.mapper.MapperUtils.hasScoreAttribute; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsInRelativeOrder; import static org.hamcrest.Matchers.containsString; import static 
org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -625,16 +634,16 @@ public void testTripleExtractorPerField() { } /** - * Expected - * LimitExec[10000[INTEGER]] - * \_AggregateExec[[],[AVG(salary{f}#14) AS x],FINAL] - * \_AggregateExec[[],[AVG(salary{f}#14) AS x],PARTIAL] - * \_FilterExec[ROUND(emp_no{f}#9) > 10[INTEGER]] - * \_TopNExec[[Order[last_name{f}#13,ASC,LAST]],10[INTEGER]] - * \_ExchangeExec[] - * \_ProjectExec[[salary{f}#14, first_name{f}#10, emp_no{f}#9, last_name{f}#13]] -- project away _doc - * \_FieldExtractExec[salary{f}#14, first_name{f}#10, emp_no{f}#9, last_n..] -- local field extraction - * \_EsQueryExec[test], query[][_doc{f}#16], limit[10], sort[[last_name]] + *LimitExec[10000[INTEGER],8] + * \_AggregateExec[[],[SUM(salary{f}#13460,true[BOOLEAN]) AS x#13454],FINAL,[$$x$sum{r}#13466, $$x$seen{r}#13467],8] + * \_AggregateExec[[],[SUM(salary{f}#13460,true[BOOLEAN]) AS x#13454],INITIAL,[$$x$sum{r}#13466, $$x$seen{r}#13467],8] + * \_FilterExec[ROUND(emp_no{f}#13455) > 10[INTEGER]] + * \_TopNExec[[Order[last_name{f}#13459,ASC,LAST]],10[INTEGER],58] + * \_ExchangeExec[[emp_no{f}#13455, last_name{f}#13459, salary{f}#13460],false] + * \_ProjectExec[[emp_no{f}#13455, last_name{f}#13459, salary{f}#13460]] -- project away _doc + * \_FieldExtractExec[emp_no{f}#13455, last_name{f}#13459, salary{f}#1346..] 
<[],[]> -- local field extraction + * \_EsQueryExec[test], indexMode[standard], query[][_doc{f}#13482], limit[10], + * sort[[FieldSort[field=last_name{f}#13459, direction=ASC, nulls=LAST]]] estimatedRowSize[74] */ public void testExtractorForField() { var plan = physicalPlan(""" @@ -658,7 +667,7 @@ public void testExtractorForField() { var exchange = asRemoteExchange(topN.child()); var project = as(exchange.child(), ProjectExec.class); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("salary", "emp_no", "last_name")); + assertThat(names(extract.attributesToExtract()), contains("emp_no", "last_name", "salary")); var source = source(extract.child()); assertThat(source.limit(), is(topN.limit())); assertThat(source.sorts(), is(fieldSorts(topN.order()))); @@ -773,7 +782,17 @@ public void testExtractorsOverridingFields() { assertThat(names(extract.attributesToExtract()), contains("emp_no")); } - public void testDoNotExtractGroupingFields() { + /** + * LimitExec[1000[INTEGER],58] + * \_AggregateExec[[first_name{f}#3520],[SUM(salary{f}#3524,true[BOOLEAN]) AS x#3518, first_name{f}#3520],FINAL,[first_name{f}#3520, + * $$x$sum{r}#3530, $$x$seen{r}#3531],58] + * \_ExchangeExec[[first_name{f}#3520, $$x$sum{r}#3530, $$x$seen{r}#3531],true] + * \_AggregateExec[[first_name{f}#3520],[SUM(salary{f}#3524,true[BOOLEAN]) AS x#3518, first_name{f}#3520],INITIAL,[first_name{f}#352 + * 0, $$x$sum{r}#3546, $$x$seen{r}#3547],58] + * \_FieldExtractExec[first_name{f}#3520, salary{f}#3524] + * \_EsQueryExec[test], indexMode[standard], query[][_doc{f}#3548], limit[], sort[] estimatedRowSize[58] + */ + public void testDoExtractGroupingFields() { var plan = physicalPlan(""" from test | stats x = sum(salary) by first_name @@ -791,12 +810,12 @@ public void testDoNotExtractGroupingFields() { assertThat(aggregate.groupings(), hasSize(1)); var extract = as(aggregate.child(), FieldExtractExec.class); - 
assertThat(names(extract.attributesToExtract()), equalTo(List.of("salary"))); + assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name", "salary"))); var source = source(extract.child()); // doc id and salary are ints. salary isn't extracted. // TODO salary kind of is extracted. At least sometimes it is. should it count? - assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2)); + assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2 + 50)); } public void testExtractGroupingFieldsIfAggd() { @@ -2209,7 +2228,7 @@ public void testNoPushDownChangeCase() { * ages{f}#6, last_name{f}#7, long_noidx{f}#13, salary{f}#8],false] * \_ProjectExec[[_meta_field{f}#9, emp_no{f}#3, first_name{f}#4, gender{f}#5, hire_date{f}#10, job{f}#11, job.raw{f}#12, langu * ages{f}#6, last_name{f}#7, long_noidx{f}#13, salary{f}#8]] - * \_FieldExtractExec[_meta_field{f}#9, emp_no{f}#3, first_name{f}#4, gen..]<[],[]> + * \_FieldExtractExec[_meta_field{f}#9, emp_no{f}#3, first_name{f}#4, gen..]<[],[]> * \_EsQueryExec[test], indexMode[standard], query[{"esql_single_value":{"field":"first_name","next":{"regexp":{"first_name": * {"value":"foo*","flags_value":65791,"case_insensitive":true,"max_determinized_states":10000,"boost":0.0}}}, * "source":"TO_LOWER(first_name) RLIKE \"foo*\"@2:9"}}][_doc{f}#25], limit[1000], sort[] estimatedRowSize[332] @@ -2330,10 +2349,10 @@ public void testPushDownUpperCaseChangeLike() { * uages{f}#7, last_name{f}#8, long_noidx{f}#14, salary{f}#9],false] * \_ProjectExec[[_meta_field{f}#10, emp_no{f}#4, first_name{f}#5, gender{f}#6, hire_date{f}#11, job{f}#12, job.raw{f}#13, lang * uages{f}#7, last_name{f}#8, long_noidx{f}#14, salary{f}#9]] - * \_FieldExtractExec[_meta_field{f}#10, gender{f}#6, hire_date{f}#11, jo..]<[],[]> + * \_FieldExtractExec[_meta_field{f}#10, gender{f}#6, hire_date{f}#11, jo..]<[],[]> * \_LimitExec[1000[INTEGER]] * \_FilterExec[LIKE(first_name{f}#5, "FOO*", true) OR 
IN(1[INTEGER],2[INTEGER],3[INTEGER],emp_no{f}#4 + 1[INTEGER])] - * \_FieldExtractExec[first_name{f}#5, emp_no{f}#4]<[],[]> + * \_FieldExtractExec[first_name{f}#5, emp_no{f}#4]<[],[]> * \_EsQueryExec[test], indexMode[standard], query[][_doc{f}#26], limit[], sort[] estimatedRowSize[332] */ public void testChangeCaseAsInsensitiveWildcardLikeNotPushedDown() { @@ -2448,22 +2467,17 @@ public void testPushDownEvalFilter() { /** * - * ProjectExec[[last_name{f}#21 AS name, first_name{f}#18 AS last_name, last_name{f}#21 AS first_name]] - * \_TopNExec[[Order[last_name{f}#21,ASC,LAST]],10[INTEGER],0] - * \_ExchangeExec[[last_name{f}#21, first_name{f}#18],false] - * \_ProjectExec[[last_name{f}#21, first_name{f}#18]] - * \_FieldExtractExec[last_name{f}#21, first_name{f}#18][] - * \_EsQueryExec[test], indexMode[standard], query[{ - * "bool":{"must":[ - * {"esql_single_value":{ - * "field":"last_name", - * "next":{"range":{"last_name":{"gt":"B","boost":1.0}}}, - * "source":"first_name > \"B\"@3:9" - * }}, - * {"exists":{"field":"first_name","boost":1.0}} - * ],"boost":1.0}}][_doc{f}#40], limit[10], sort[[ - * FieldSort[field=last_name{f}#21, direction=ASC, nulls=LAST] - * ]] estimatedRowSize[116] + * ProjectExec[[last_name{f}#13858 AS name#13841, first_name{f}#13855 AS last_name#13844, last_name{f}#13858 AS first_name#13 + * 847]] + * \_TopNExec[[Order[last_name{f}#13858,ASC,LAST]],10[INTEGER],100] + * \_ExchangeExec[[first_name{f}#13855, last_name{f}#13858],false] + * \_ProjectExec[[first_name{f}#13855, last_name{f}#13858]] + * \_FieldExtractExec[first_name{f}#13855, last_name{f}#13858]<[],[]> + * \_EsQueryExec[test], indexMode[standard], query[ + * {"bool":{"must":[{"esql_single_value":{"field":"last_name","next": + * {"range":{"last_name":{"gt":"B","boost":0.0}}},"source":"first_name > \"B\"@3:9"}}, + * {"exists":{"field":"first_name","boost":0.0}}],"boost":1.0}} + * ][_doc{f}#13879], limit[10], sort[[FieldSort[field=last_name{f}#13858, direction=ASC, nulls=LAST]]] 
estimatedRowSize[116] * */ public void testPushDownEvalSwapFilter() { @@ -2484,7 +2498,7 @@ public void testPushDownEvalSwapFilter() { var extract = as(project.child(), FieldExtractExec.class); assertThat( extract.attributesToExtract().stream().map(Attribute::name).collect(Collectors.toList()), - contains("last_name", "first_name") + contains("first_name", "last_name") ); // Now verify the correct Lucene push-down of both the filter and the sort @@ -2597,7 +2611,7 @@ public void testDissect() { * uages{f}#7, last_name{f}#8, long_noidx{f}#14, salary{f}#9, _index{m}#2],false] * \_ProjectExec[[_meta_field{f}#10, emp_no{f}#4, first_name{f}#5, gender{f}#6, hire_date{f}#11, job{f}#12, job.raw{f}#13, lang * uages{f}#7, last_name{f}#8, long_noidx{f}#14, salary{f}#9, _index{m}#2]] - * \_FieldExtractExec[_meta_field{f}#10, emp_no{f}#4, first_name{f}#5, ge..]<[],[]> + * \_FieldExtractExec[_meta_field{f}#10, emp_no{f}#4, first_name{f}#5, ge..]<[],[]> * \_EsQueryExec[test], indexMode[standard], query[{"wildcard":{"_index":{"wildcard":"test*","boost":0.0}}}][_doc{f}#27], * limit[1000], sort[] estimatedRowSize[382] * @@ -2866,7 +2880,7 @@ public void testFieldExtractWithoutSourceAttributes() { ) ); - var e = expectThrows(VerificationException.class, () -> physicalPlanOptimizer.verify(badPlan)); + var e = expectThrows(VerificationException.class, () -> physicalPlanOptimizer.verify(badPlan, verifiedPlan.output())); assertThat( e.getMessage(), containsString( @@ -2881,7 +2895,7 @@ public void testVerifierOnMissingReferences() { | stats s = sum(salary) by emp_no | where emp_no > 10 """); - + final var planBeforeModification = plan; plan = plan.transformUp( AggregateExec.class, a -> new AggregateExec( @@ -2895,7 +2909,7 @@ public void testVerifierOnMissingReferences() { ) ); final var finalPlan = plan; - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan)); + var e = expectThrows(IllegalStateException.class, () -> 
physicalPlanOptimizer.verify(finalPlan, planBeforeModification.output())); assertThat(e.getMessage(), containsString(" > 10[INTEGER]]] optimized incorrectly due to missing references [emp_no{f}#")); } @@ -2913,7 +2927,7 @@ public void testVerifierOnMissingReferencesWithBinaryPlans() throws Exception { var planWithInvalidJoinLeftSide = plan.transformUp(LookupJoinExec.class, join -> join.replaceChildren(join.right(), join.right())); - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinLeftSide)); + var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinLeftSide, plan.output())); assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references from left hand side [languages")); var planWithInvalidJoinRightSide = plan.transformUp( @@ -2930,7 +2944,7 @@ public void testVerifierOnMissingReferencesWithBinaryPlans() throws Exception { ) ); - e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinRightSide)); + e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinRightSide, plan.output())); assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references from right hand side [language_code")); } @@ -2940,7 +2954,7 @@ public void testVerifierOnDuplicateOutputAttributes() { | stats s = sum(salary) by emp_no | where emp_no > 10 """); - + final var planBeforeModification = plan; plan = plan.transformUp(AggregateExec.class, a -> { List intermediates = new ArrayList<>(a.intermediateAttributes()); intermediates.add(intermediates.get(0)); @@ -2955,7 +2969,7 @@ public void testVerifierOnDuplicateOutputAttributes() { ); }); final var finalPlan = plan; - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan)); + var e = expectThrows(IllegalStateException.class, () -> 
physicalPlanOptimizer.verify(finalPlan, planBeforeModification.output())); assertThat( e.getMessage(), containsString("Plan [LimitExec[1000[INTEGER],null]] optimized incorrectly due to duplicate output attribute emp_no{f}#") @@ -3166,6 +3180,56 @@ public void testProjectAwayAllColumnsWhenOnlyTheCountMattersInStats() { assertThat(Expressions.names(esQuery.attrs()), contains("_doc")); } + /** + * LimitExec[1000[INTEGER],336] + * \_MvExpandExec[foo_1{r}#4236,foo_1{r}#4253] + * \_TopNExec[[Order[emp_no{f}#4242,ASC,LAST]],1000[INTEGER],336] + * \_ExchangeExec[[_meta_field{f}#4248, emp_no{f}#4242, first_name{f}#4243, gender{f}#4244, hire_date{f}#4249, job{f}#4250, job. + * raw{f}#4251, languages{f}#4245, last_name{f}#4246, long_noidx{f}#4252, salary{f}#4247, foo_1{r}#4236, foo_2{r}#4238], + * false] + * \_ProjectExec[[_meta_field{f}#4248, emp_no{f}#4242, first_name{f}#4243, gender{f}#4244, hire_date{f}#4249, job{f}#4250, job. + * raw{f}#4251, languages{f}#4245, last_name{f}#4246, long_noidx{f}#4252, salary{f}#4247, foo_1{r}#4236, foo_2{r}#4238]] + * \_FieldExtractExec[_meta_field{f}#4248, emp_no{f}#4242, first_name{f}#..]<[],[]> + * \_EvalExec[[1[INTEGER] AS foo_1#4236, 1[INTEGER] AS foo_2#4238]] + * \_EsQueryExec[test], indexMode[standard], query[][_doc{f}#4268], limit[1000], sort[[FieldSort[field=emp_no{f}#4242, + * direction=ASC, nulls=LAST]]] estimatedRowSize[352] + */ + public void testProjectAwayMvExpandColumnOrder() { + var plan = optimizedPlan(physicalPlan(""" + from test + | eval foo_1 = 1, foo_2 = 1 + | sort emp_no + | mv_expand foo_1 + """)); + var limit = as(plan, LimitExec.class); + var mvExpand = as(limit.child(), MvExpandExec.class); + var topN = as(mvExpand.child(), TopNExec.class); + var exchange = as(topN.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + + assertThat( + Expressions.names(project.projections()), + containsInRelativeOrder( + "_meta_field", + "emp_no", + "first_name", + "gender", + "hire_date", + 
"job", + "job.raw", + "languages", + "last_name", + "long_noidx", + "salary", + "foo_1", + "foo_2" + ) + ); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var eval = as(fieldExtract.child(), EvalExec.class); + EsQueryExec esQuery = as(eval.child(), EsQueryExec.class); + } + /** * ProjectExec[[a{r}#5]] * \_EvalExec[[__a_SUM@81823521{r}#15 / __a_COUNT@31645621{r}#16 AS a]] @@ -4045,7 +4109,7 @@ public void testSpatialTypesAndStatsUseDocValuesMultiAggregationsGrouped() { * \_AggregateExec[[scalerank{f}#16],[SPATIALCENTROID(location{f}#18) AS centroid, COUNT([2a][KEYWORD]) AS count],FINAL,58] * \_ExchangeExec[[scalerank{f}#16, xVal{r}#19, xDel{r}#20, yVal{r}#21, yDel{r}#22, count{r}#23, count{r}#24, seen{r}#25],true] * \_AggregateExec[[scalerank{f}#16],[SPATIALCENTROID(location{f}#18) AS centroid, COUNT([2a][KEYWORD]) AS count],PARTIAL,58] - * \_FieldExtractExec[location{f}#18][location{f}#18] + * \_FieldExtractExec[scalerank{f}#16][location{f}#18][location{f}#18] * \_EsQueryExec[airports], query[][_doc{f}#42], limit[], sort[] estimatedRowSize[54] * * Note the FieldExtractExec has 'location' set for stats: FieldExtractExec[location{f}#9][location{f}#9] @@ -5474,6 +5538,7 @@ public void testPushSpatialDistanceEvalWithSimpleStatsToSource() { * \_ExchangeExec[[country{f}#21, count{r}#24, seen{r}#25, xVal{r}#26, xDel{r}#27, yVal{r}#28, yDel{r}#29, count{r}#30],true] * \_AggregateExec[[country{f}#21],[COUNT([2a][KEYWORD]) AS count, SPATIALCENTROID(location{f}#20) AS centroid, country{f}#21],INIT * IAL,[country{f}#21, count{r}#49, seen{r}#50, xVal{r}#51, xDel{r}#52, yVal{r}#53, yDel{r}#54, count{r}#55],79] + * \_FieldExtractExec[country{f}#15254] * \_EvalExec[[STDISTANCE(location{f}#20,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) * AS distance]] * \_FieldExtractExec[location{f}#20][location{f}#20] @@ -5550,7 +5615,8 @@ public void testPushSpatialDistanceEvalWithStatsToSource() { var exchangeExec = as(aggExec.child(), 
ExchangeExec.class); var aggExec2 = as(exchangeExec.child(), AggregateExec.class); // TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown - var evalExec = as(aggExec2.child(), EvalExec.class); + var extract = as(aggExec2.child(), FieldExtractExec.class); + var evalExec = as(extract.child(), EvalExec.class); var stDistance = as(evalExec.fields().get(0).child(), StDistance.class); assertThat("Expect distance function to expect doc-values", stDistance.leftDocValues(), is(true)); var source = assertChildIsGeoPointExtract(evalExec, FieldExtractPreference.DOC_VALUES); @@ -5662,16 +5728,15 @@ public void testPushTopNWithFilterToSource() { } /** - * ProjectExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327]] - * \_TopNExec[[Order[abbrev{f}#12321,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327],false] - * \_ProjectExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327]] - * \_FieldExtractExec[abbrev{f}#12321, name{f}#12322, location{f}#12325, ..][] + * ProjectExec[[abbrev{f}#4474, name{f}#4475, location{f}#4478, country{f}#4479, city{f}#4480]] + * \_TopNExec[[Order[abbrev{f}#4474,ASC,LAST]],5[INTEGER],221] + * \_ExchangeExec[[abbrev{f}#4474, city{f}#4480, country{f}#4479, location{f}#4478, name{f}#4475],false] + * \_ProjectExec[[abbrev{f}#4474, city{f}#4480, country{f}#4479, location{f}#4478, name{f}#4475]] + * \_FieldExtractExec[abbrev{f}#4474, city{f}#4480, country{f}#4479, loca..]<[],[]> * \_EsQueryExec[airports], - * indexMode[standard], - * query[][_doc{f}#12337], - * limit[5], - * sort[[FieldSort[field=abbrev{f}#12321, direction=ASC, nulls=LAST]]] estimatedRowSize[237] + * indexMode[standard], + * query[][_doc{f}#4490], + * limit[5], sort[[FieldSort[field=abbrev{f}#4474, direction=ASC, nulls=LAST]]] estimatedRowSize[237] */ public void testPushTopNKeywordToSource() { var 
optimized = optimizedPlan(physicalPlan(""" @@ -5686,9 +5751,9 @@ public void testPushTopNKeywordToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "location", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "location", "name")); var source = source(extract.child()); assertThat(source.limit(), is(topN.limit())); assertThat(source.sorts(), is(fieldSorts(topN.order()))); @@ -5704,13 +5769,13 @@ public void testPushTopNKeywordToSource() { /** * - * ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, abbrev{f}#12 AS code]] - * \_TopNExec[[Order[abbrev{f}#12,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18],false] - * \_ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18]] - * \_FieldExtractExec[abbrev{f}#12, name{f}#13, location{f}#16, country{f..][] - * \_EsQueryExec[airports], indexMode[standard], query[][_doc{f}#29], limit[5], - * sort[[FieldSort[field=abbrev{f}#12, direction=ASC, nulls=LAST]]] estimatedRowSize[237] + * ProjectExec[[abbrev{f}#7828, name{f}#7829, location{f}#7832, country{f}#7833, city{f}#7834, abbrev{f}#7828 AS code#7820]] + * \_TopNExec[[Order[abbrev{f}#7828,ASC,LAST]],5[INTEGER],221] + * \_ExchangeExec[[abbrev{f}#7828, city{f}#7834, country{f}#7833, location{f}#7832, name{f}#7829],false] + * \_ProjectExec[[abbrev{f}#7828, city{f}#7834, country{f}#7833, location{f}#7832, name{f}#7829]] + * \_FieldExtractExec[abbrev{f}#7828, city{f}#7834, country{f}#7833, loca..]<[],[]> + * 
\_EsQueryExec[airports], indexMode[standard], query[][_doc{f}#7845], limit[5], + * sort[[FieldSort[field=abbrev{f}#7828, direction=ASC, nulls=LAST]]] estimatedRowSize[237] * */ public void testPushTopNAliasedKeywordToSource() { @@ -5728,9 +5793,9 @@ public void testPushTopNAliasedKeywordToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "location", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "location", "name")); var source = source(extract.child()); assertThat(source.limit(), is(topN.limit())); assertThat(source.sorts(), is(fieldSorts(topN.order()))); @@ -5745,19 +5810,19 @@ public void testPushTopNAliasedKeywordToSource() { } /** - * ProjectExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17]] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#11, name{f}#12, country{f}#16, city{f}#17][] - * \_EvalExec[[STDISTANCE(location{f}#15,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) - * AS distance]] - * \_FieldExtractExec[location{f}#15][] + * ProjectExec[[abbrev{f}#7283, name{f}#7284, location{f}#7287, country{f}#7288, city{f}#7289]] + * \_TopNExec[[Order[distance{r}#7276,ASC,LAST]],5[INTEGER],229] + * \_ExchangeExec[[abbrev{f}#7283, city{f}#7289, country{f}#7288, location{f}#7287, name{f}#7284, 
distance{r}#7276],false] + * \_ProjectExec[[abbrev{f}#7283, city{f}#7289, country{f}#7288, location{f}#7287, name{f}#7284, distance{r}#7276]] + * \_FieldExtractExec[abbrev{f}#7283, city{f}#7289, country{f}#7288, name..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#7287,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distan + * ce#7276]] + * \_FieldExtractExec[location{f}#7287]<[],[]> * \_EsQueryExec[airports], - * indexMode[standard], - * query[][_doc{f}#28], - * limit[5], - * sort[[GeoDistanceSort[field=location{f}#15, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] + * indexMode[standard], + * query[][_doc{f}#7300], + * limit[5], + * sort[[GeoDistanceSort[field=location{f}#7287, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] */ public void testPushTopNDistanceToSource() { var optimized = optimizedPlan(physicalPlan(""" @@ -5773,9 +5838,9 @@ public void testPushTopNDistanceToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), is("distance")); @@ -5802,20 +5867,19 @@ public void testPushTopNDistanceToSource() { } /** - * ProjectExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, city{f}#14]] - * \_TopNExec[[Order[$$order_by$0$0{r}#16,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, 
city{f}#14, $$order_by$0$0{r}#16],false] - * \_ProjectExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, city{f}#14, $$order_by$0$0{r}#16]] - * \_FieldExtractExec[abbrev{f}#8, name{f}#9, country{f}#13, city{f}#14][] - * \_EvalExec[[ - * STDISTANCE(location{f}#12,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS $$order_by$0$0 - * ]] - * \_FieldExtractExec[location{f}#12][] + *ProjectExec[[abbrev{f}#5258, name{f}#5259, location{f}#5262, country{f}#5263, city{f}#5264]] + * \_TopNExec[[Order[$$order_by$0$0{r}#5266,ASC,LAST]],5[INTEGER],229] + * \_ExchangeExec[[abbrev{f}#5258, city{f}#5264, country{f}#5263, location{f}#5262, name{f}#5259, $$order_by$0$0{r}#5266],false] + * \_ProjectExec[[abbrev{f}#5258, city{f}#5264, country{f}#5263, location{f}#5262, name{f}#5259, $$order_by$0$0{r}#5266]] + * \_FieldExtractExec[abbrev{f}#5258, city{f}#5264, country{f}#5263, name..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#5262,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS $$orde + * r_by$0$0#5266]] + * \_FieldExtractExec[location{f}#5262]<[],[]> * \_EsQueryExec[airports], - * indexMode[standard], - * query[][_doc{f}#26], - * limit[5], - * sort[[GeoDistanceSort[field=location{f}#12, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] + * indexMode[standard], + * query[][_doc{f}#5276], + * limit[5], + * sort[[GeoDistanceSort[field=location{f}#5262, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] */ public void testPushTopNInlineDistanceToSource() { var optimized = optimizedPlan(physicalPlan(""" @@ -5835,15 +5899,15 @@ public void testPushTopNInlineDistanceToSource() { names(project.projections()), contains( equalTo("abbrev"), - equalTo("name"), - equalTo("location"), - equalTo("country"), equalTo("city"), + equalTo("country"), + equalTo("location"), + equalTo("name"), startsWith("$$order_by$0$") ) ); var extract = as(project.child(), FieldExtractExec.class); - 
assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), startsWith("$$order_by$0$")); @@ -5872,14 +5936,14 @@ public void testPushTopNInlineDistanceToSource() { /** * - * ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18]] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#12, name{f}#13, country{f}#17, city{f}#18][] - * \_EvalExec[[STDISTANCE(location{f}#16,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance - * ]] - * \_FieldExtractExec[location{f}#16][] + * ProjectExec[[abbrev{f}#361, name{f}#362, location{f}#365, country{f}#366, city{f}#367]] + * \_TopNExec[[Order[distance{r}#353,ASC,LAST]],5[INTEGER],229] + * \_ExchangeExec[[abbrev{f}#361, city{f}#367, country{f}#366, location{f}#365, name{f}#362, distance{r}#353],false] + * \_ProjectExec[[abbrev{f}#361, city{f}#367, country{f}#366, location{f}#365, name{f}#362, distance{r}#353]] + * \_FieldExtractExec[abbrev{f}#361, city{f}#367, country{f}#366, name{f}..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#365,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distanc + * e#353]] + * \_FieldExtractExec[location{f}#365]<[],[]> * \_EsQueryExec[airports], indexMode[standard], query[ * { * "geo_shape":{ @@ -5892,7 +5956,7 @@ public void testPushTopNInlineDistanceToSource() { * } * } * } - * }][_doc{f}#29], limit[5], sort[[GeoDistanceSort[field=location{f}#16, direction=ASC, lat=55.673, lon=12.565]]] 
estimatedRowSize[245] + * ][_doc{f}#378], limit[5], sort[[GeoDistanceSort[field=location{f}#365, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] * */ public void testPushTopNDistanceWithFilterToSource() { @@ -5910,9 +5974,9 @@ public void testPushTopNDistanceWithFilterToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), is("distance")); @@ -5948,48 +6012,25 @@ public void testPushTopNDistanceWithFilterToSource() { /** * - * ProjectExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20]] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#14, name{f}#15, country{f}#19, city{f}#20][] - * \_EvalExec[[STDISTANCE(location{f}#18,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) - * AS distance]] - * \_FieldExtractExec[location{f}#18][] - * \_EsQueryExec[airports], indexMode[standard], query[{ - * "bool":{ - * "filter":[ - * { - * "esql_single_value":{ - * "field":"scalerank", - * "next":{"range":{"scalerank":{"lt":6,"boost":1.0}}}, - * "source":"scalerank lt 6@3:31" - * } - * }, - * { - * "bool":{ - 
* "must":[ - * {"geo_shape":{ - * "location":{ - * "relation":"INTERSECTS", - * "shape":{"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]} - * } - * }}, - * {"geo_shape":{ - * "location":{ - * "relation":"DISJOINT", - * "shape":{"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]} - * } - * }} - * ], - * "boost":1.0 - * } - * } - * ], - * "boost":1.0 - * }}][_doc{f}#31], limit[5], sort[[ - * GeoDistanceSort[field=location{f}#18, direction=ASC, lat=55.673, lon=12.565] - * ]] estimatedRowSize[245] + * ProjectExec[[abbrev{f}#6367, name{f}#6368, location{f}#6371, country{f}#6372, city{f}#6373]] + * \_TopNExec[[Order[distance{r}#6357,ASC,LAST]],5[INTEGER],229] + * \_ExchangeExec[[abbrev{f}#6367, city{f}#6373, country{f}#6372, location{f}#6371, name{f}#6368, distance{r}#6357],false] + * \_ProjectExec[[abbrev{f}#6367, city{f}#6373, country{f}#6372, location{f}#6371, name{f}#6368, distance{r}#6357]] + * \_FieldExtractExec[abbrev{f}#6367, city{f}#6373, country{f}#6372, name..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#6371,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distan + * ce#6357]] + * \_FieldExtractExec[location{f}#6371]<[],[]> + * \_EsQueryExec[airports], indexMode[standard], query[ + * {"bool":{"filter":[{"esql_single_value":{"field":"scalerank","next":{"range": + * {"scalerank":{"lt":6,"boost":0.0}}},"source":"scalerank < 6@3:31"}}, + * {"bool":{"must":[{"geo_shape": + * {"location":{"relation":"INTERSECTS","shape": + * {"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape": + * {"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]}}}}] + * ,"boost":1.0}}],"boost":1.0}} + * ][_doc{f}#6384], limit[5], sort[ + * [GeoDistanceSort[field=location{f}#6371, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] * */ public void 
testPushTopNDistanceWithCompoundFilterToSource() { @@ -6007,9 +6048,9 @@ public void testPushTopNDistanceWithCompoundFilterToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), is("distance")); @@ -6047,35 +6088,28 @@ public void testPushTopNDistanceWithCompoundFilterToSource() { /** * Tests that multiple sorts, including distance and a field, are pushed down to the source. 
* - * ProjectExec[[abbrev{f}#25, name{f}#26, location{f}#29, country{f}#30, city{f}#31, scalerank{f}#27, scale{r}#7]] - * \_TopNExec[[ - * Order[distance{r}#4,ASC,LAST], - * Order[scalerank{f}#27,ASC,LAST], - * Order[scale{r}#7,DESC,FIRST], - * Order[loc{r}#10,DESC,FIRST] - * ],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#25, name{f}#26, location{f}#29, country{f}#30, city{f}#31, scalerank{f}#27, scale{r}#7, - * distance{r}#4, loc{r}#10],false] - * \_ProjectExec[[abbrev{f}#25, name{f}#26, location{f}#29, country{f}#30, city{f}#31, scalerank{f}#27, scale{r}#7, - * distance{r}#4, loc{r}#10]] - * \_FieldExtractExec[abbrev{f}#25, name{f}#26, country{f}#30, city{f}#31][] - * \_EvalExec[[ - * STDISTANCE(location{f}#29,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance, - * 10[INTEGER] - scalerank{f}#27 AS scale, TOSTRING(location{f}#29) AS loc - * ]] - * \_FieldExtractExec[location{f}#29, scalerank{f}#27][] - * \_EsQueryExec[airports], indexMode[standard], query[{ - * "bool":{ - * "filter":[ - * {"esql_single_value":{"field":"scalerank","next":{...},"source":"scalerank < 6@3:31"}}, - * {"bool":{ - * "must":[ - * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, - * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} - * ],"boost":1.0}}],"boost":1.0}}][_doc{f}#44], limit[5], sort[[ - * GeoDistanceSort[field=location{f}#29, direction=ASC, lat=55.673, lon=12.565], - * FieldSort[field=scalerank{f}#27, direction=ASC, nulls=LAST] - * ]] estimatedRowSize[303] + * ProjectExec[[abbrev{f}#7429, name{f}#7430, location{f}#7433, country{f}#7434, city{f}#7435, scalerank{f}#7431, scale{r}#74 + * 11]] + * \_TopNExec[[Order[distance{r}#7408,ASC,LAST], Order[scalerank{f}#7431,ASC,LAST], Order[scale{r}#7411,DESC,FIRST], Order[l + * oc{r}#7414,DESC,FIRST]],5[INTEGER],287] + * \_ExchangeExec[[abbrev{f}#7429, city{f}#7435, country{f}#7434, location{f}#7433, name{f}#7430, scalerank{f}#7431, distance{r} + * #7408, scale{r}#7411, 
loc{r}#7414],false] + * \_ProjectExec[[abbrev{f}#7429, city{f}#7435, country{f}#7434, location{f}#7433, name{f}#7430, scalerank{f}#7431, distance{r} + * #7408, scale{r}#7411, loc{r}#7414]] + * \_FieldExtractExec[abbrev{f}#7429, city{f}#7435, country{f}#7434, name..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#7433,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distan + * ce#7408, 10[INTEGER] - scalerank{f}#7431 AS scale#7411, TOSTRING(location{f}#7433) AS loc#7414]] + * \_FieldExtractExec[location{f}#7433, scalerank{f}#7431]<[],[]> + * \_EsQueryExec[airports], indexMode[standard], query[ + * {"bool":{"filter":[{"esql_single_value":{"field":"scalerank","next": + * {"range":{"scalerank":{"lt":6,"boost":0.0}}},"source":"scalerank < 6@3:31"}}, + * {"bool":{"must":[{"geo_shape":{"location":{"relation":"INTERSECTS","shape": + * {"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape": + * {"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]}}}}], + * "boost":1.0}}],"boost":1.0}}][_doc{f}#7448], limit[5], sort[ + * [GeoDistanceSort[field=location{f}#7433, direction=ASC, lat=55.673, lon=12.565], + * FieldSort[field=scalerank{f}#7431, direction=ASC, nulls=LAST]]] estimatedRowSize[303] * */ public void testPushTopNDistanceAndPushableFieldWithCompoundFilterToSource() { @@ -6096,10 +6130,10 @@ public void testPushTopNDistanceAndPushableFieldWithCompoundFilterToSource() { project = as(exchange.child(), ProjectExec.class); assertThat( names(project.projections()), - contains("abbrev", "name", "location", "country", "city", "scalerank", "scale", "distance", "loc") + contains("abbrev", "city", "country", "location", "name", "scalerank", "distance", "scale", "loc") ); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + 
assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), is("distance")); @@ -6141,26 +6175,30 @@ public void testPushTopNDistanceAndPushableFieldWithCompoundFilterToSource() { /** * This test shows that if the filter contains a predicate on the same field that is sorted, we cannot push down the sort. * - * ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25 AS scale]] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scalerank{f}#25,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#23, name{f}#24, country{f}#28, city{f}#29][] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scalerank{f}#25,ASC,LAST]],5[INTEGER],208] - * \_FieldExtractExec[scalerank{f}#25][] - * \_FilterExec[SUBSTRING(position{r}#7,1[INTEGER],5[INTEGER]) == [50 4f 49 4e 54][KEYWORD]] - * \_EvalExec[[ - * STDISTANCE(location{f}#27,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance, - * TOSTRING(location{f}#27) AS position - * ]] - * \_FieldExtractExec[location{f}#27][] - * \_EsQueryExec[airports], indexMode[standard], query[{ - * "bool":{"filter":[ - * {"esql_single_value":{"field":"scalerank","next":{"range":{"scalerank":{"lt":6,"boost":1.0}}},"source":...}}, - * {"bool":{"must":[ - * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, - * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} - * ],"boost":1.0}}],"boost":1.0}}][_doc{f}#42], limit[], sort[] estimatedRowSize[87] + * ProjectExec[[abbrev{f}#4856, name{f}#4857, location{f}#4860, country{f}#4861, 
city{f}#4862, scalerank{f}#4858 AS scale#484 + * 3]] + * \_TopNExec[[Order[distance{r}#4837,ASC,LAST], Order[scalerank{f}#4858,ASC,LAST]],5[INTEGER],233] + * \_ExchangeExec[[abbrev{f}#4856, city{f}#4862, country{f}#4861, location{f}#4860, name{f}#4857, scalerank{f}#4858, distance{r} + * #4837],false] + * \_ProjectExec[[abbrev{f}#4856, city{f}#4862, country{f}#4861, location{f}#4860, name{f}#4857, scalerank{f}#4858, distance{r} + * #4837]] + * \_FieldExtractExec[abbrev{f}#4856, city{f}#4862, country{f}#4861, name..]<[],[]> + * \_TopNExec[[Order[distance{r}#4837,ASC,LAST], Order[scalerank{f}#4858,ASC,LAST]],5[INTEGER],303] + * \_FieldExtractExec[scalerank{f}#4858]<[],[]> + * \_FilterExec[SUBSTRING(position{r}#4840,1[INTEGER],5[INTEGER]) == POINT[KEYWORD]] + * \_EvalExec[[STDISTANCE(location{f}#4860,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS + * distance#4837, TOSTRING(location{f}#4860) AS position#4840]] + * \_FieldExtractExec[location{f}#4860]<[],[]> + * \_EsQueryExec[airports], indexMode[standard], query[ + * {"bool":{"filter":[ + * {"esql_single_value": + * {"field":"scalerank","next":{"range":{"scalerank":{"lt":6,"boost":0.0}}},"source":"scale < 6@3:93"}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape": + * {"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape": + * {"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]}}}} + * ],"boost":1.0}}],"boost":1.0}}][_doc{f}#4875], limit[], sort[] estimatedRowSize[87] * */ public void testPushTopNDistanceAndNonPushableEvalWithCompoundFilterToSource() { @@ -6179,9 +6217,9 @@ public void testPushTopNDistanceAndNonPushableEvalWithCompoundFilterToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", 
"city", "scalerank", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "scalerank", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var topNChild = as(extract.child(), TopNExec.class); extract = as(topNChild.child(), FieldExtractExec.class); assertThat(names(extract.attributesToExtract()), contains("scalerank")); @@ -6216,27 +6254,25 @@ public void testPushTopNDistanceAndNonPushableEvalWithCompoundFilterToSource() { /** * This test shows that if the filter contains a predicate on the same field that is sorted, we cannot push down the sort. * - * ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10]] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scale{r}#10,ASC,LAST]],5[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#23, name{f}#24, country{f}#28, city{f}#29][] - * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scale{r}#10,ASC,LAST]],5[INTEGER],208] - * \_FilterExec[ - * SUBSTRING(position{r}#7,1[INTEGER],5[INTEGER]) == [50 4f 49 4e 54][KEYWORD] - * AND scale{r}#10 > 3[INTEGER] - * ] - * \_EvalExec[[ - * STDISTANCE(location{f}#27,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance, - * TOSTRING(location{f}#27) AS position, - * 10[INTEGER] - scalerank{f}#25 AS scale - * ]] - * \_FieldExtractExec[location{f}#27, scalerank{f}#25][] - * \_EsQueryExec[airports], indexMode[standard], query[{ - * "bool":{"must":[ - * 
{"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, - * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} - * ],"boost":1.0}}][_doc{f}#42], limit[], sort[] estimatedRowSize[91] + *ProjectExec[[abbrev{f}#1447, name{f}#1448, location{f}#1451, country{f}#1452, city{f}#1453, scalerank{r}#1434]] + * \_TopNExec[[Order[distance{r}#1428,ASC,LAST], Order[scalerank{r}#1434,ASC,LAST]],5[INTEGER],233] + * \_ExchangeExec[[abbrev{f}#1447, city{f}#1453, country{f}#1452, location{f}#1451, name{f}#1448, distance{r}#1428, scalerank{r} + * #1434],false] + * \_ProjectExec[[abbrev{f}#1447, city{f}#1453, country{f}#1452, location{f}#1451, name{f}#1448, distance{r}#1428, scalerank{r} + * #1434]] + * \_FieldExtractExec[abbrev{f}#1447, city{f}#1453, country{f}#1452, name..]<[],[]> + * \_TopNExec[[Order[distance{r}#1428,ASC,LAST], Order[scalerank{r}#1434,ASC,LAST]],5[INTEGER],303] + * \_FilterExec[SUBSTRING(position{r}#1431,1[INTEGER],5[INTEGER]) == POINT[KEYWORD] AND scalerank{r}#1434 > 3[INTEGER]] + * \_EvalExec[[STDISTANCE(location{f}#1451,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distan + * ce#1428, TOSTRING(location{f}#1451) AS position#1431, 10[INTEGER] - scalerank{f}#1449 AS scalerank#1434]] + * \_FieldExtractExec[location{f}#1451, scalerank{f}#1449]<[],[]> + * \_EsQueryExec[airports], indexMode[standard], query[ + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape": + * {"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape": + * {"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]}}}} + * ],"boost":1.0}}][_doc{f}#1466], limit[], sort[] estimatedRowSize[91] * */ public void testPushTopNDistanceAndNonPushableEvalsWithCompoundFilterToSource() { @@ -6255,9 +6291,9 @@ public void testPushTopNDistanceAndNonPushableEvalsWithCompoundFilterToSource() var exchange = 
asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "distance", "scalerank")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var topNChild = as(extract.child(), TopNExec.class); var filter = as(topNChild.child(), FilterExec.class); assertThat(filter.condition(), isA(And.class)); @@ -6332,9 +6368,9 @@ public void testPushTopNDistanceWithCompoundFilterToSourceAndDisjunctiveNonPusha var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "scalerank", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name")); var topNChild = as(extract.child(), TopNExec.class); var filter = as(topNChild.child(), FilterExec.class); assertThat(filter.condition(), isA(Or.class)); @@ -6361,28 +6397,29 @@ public void testPushTopNDistanceWithCompoundFilterToSourceAndDisjunctiveNonPusha /** * - * ProjectExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, city{f}#21]] - * \_TopNExec[[Order[scalerank{f}#17,ASC,LAST], Order[distance{r}#4,ASC,LAST]],15[INTEGER],0] - * \_ExchangeExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, 
city{f}#21, scalerank{f}#17, distance{r}#4],false] - * \_ProjectExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, city{f}#21, scalerank{f}#17, distance{r}#4]] - * \_FieldExtractExec[abbrev{f}#15, name{f}#16, country{f}#20, city{f}#21, ..][] - * \_EvalExec[[STDISTANCE(location{f}#19,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) - * AS distance]] - * \_FieldExtractExec[location{f}#19][] - * \_EsQueryExec[airports], indexMode[standard], query[{ - * "bool":{ - * "filter":[ - * {"esql_single_value":{"field":"scalerank",...,"source":"scalerank lt 6@3:31"}}, - * {"bool":{"must":[ - * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, - * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} - * ],"boost":1.0}} - * ],"boost":1.0 - * } - * }][_doc{f}#32], limit[], sort[[ - * FieldSort[field=scalerank{f}#17, direction=ASC, nulls=LAST], - * GeoDistanceSort[field=location{f}#19, direction=ASC, lat=55.673, lon=12.565] - * ]] estimatedRowSize[37] + * ProjectExec[[abbrev{f}#6090, name{f}#6091, location{f}#6094, country{f}#6095, city{f}#6096]] + * \_TopNExec[[Order[scalerank{f}#6092,ASC,LAST], Order[distance{r}#6079,ASC,LAST]],15[INTEGER],233] + * \_ExchangeExec[[abbrev{f}#6090, city{f}#6096, country{f}#6095, location{f}#6094, name{f}#6091, scalerank{f}#6092, distance{r} + * #6079],false] + * \_ProjectExec[[abbrev{f}#6090, city{f}#6096, country{f}#6095, location{f}#6094, name{f}#6091, scalerank{f}#6092, distance{r} + * #6079]] + * \_FieldExtractExec[abbrev{f}#6090, city{f}#6096, country{f}#6095, name..]<[],[]> + * \_EvalExec[[STDISTANCE(location{f}#6094,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distan + * ce#6079]] + * \_FieldExtractExec[location{f}#6094]<[],[]> + * \_EsQueryExec[airports], indexMode[standard], query[ + * {"bool":{"filter":[ + * {"esql_single_value":{"field":"scalerank","next":{"range": + * {"scalerank":{"lt":6,"boost":0.0}}},"source":"scalerank < 6@3:31"}}, + * 
{"bool":{"must":[ + * {"geo_shape": {"location":{"relation":"INTERSECTS","shape": + * {"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape": + * {"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]}}}} + * ],"boost":1.0}}],"boost":1.0}} + * ][_doc{f}#6107], limit[15], sort[ + * [FieldSort[field=scalerank{f}#6092, direction=ASC, nulls=LAST], + * GeoDistanceSort[field=location{f}#6094, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[249] * */ public void testPushCompoundTopNDistanceWithCompoundFilterToSource() { @@ -6401,9 +6438,9 @@ public void testPushCompoundTopNDistanceWithCompoundFilterToSource() { var exchange = asRemoteExchange(topN.child()); project = as(exchange.child(), ProjectExec.class); - assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + assertThat(names(project.projections()), contains("abbrev", "city", "country", "location", "name", "scalerank", "distance")); var extract = as(project.child(), FieldExtractExec.class); - assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city", "scalerank")); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "city", "country", "name", "scalerank")); var evalExec = as(extract.child(), EvalExec.class); var alias = as(evalExec.fields().get(0), Alias.class); assertThat(alias.name(), is("distance")); @@ -7858,7 +7895,12 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP null, null, null, - new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null, DataPartitioning.AUTO), + new EsPhysicalOperationProviders( + FoldContext.small(), + List.of(), + null, + new PhysicalSettings(DataPartitioning.AUTO, ByteSizeValue.ofMb(1)) + ), List.of() ); @@ -8041,7 +8083,7 @@ public void testNotEqualsPushdownToDelegate() { * ges{f}#5, 
last_name{f}#6, long_noidx{f}#12, salary{f}#7],false] * \_ProjectExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, hire_date{f}#9, job{f}#10, job.raw{f}#11, langua * ges{f}#5, last_name{f}#6, long_noidx{f}#12, salary{f}#7]] - * \_FieldExtractExec[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gen..]<[],[]> + * \_FieldExtractExec[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gen..]<[],[]> * \_EsQueryExec[test], indexMode[standard], * query[{"bool":{"filter":[{"sampling":{"probability":0.1,"seed":234,"hash":0}}],"boost":1.0}}] * [_doc{f}#24], limit[1000], sort[] estimatedRowSize[332] @@ -8110,7 +8152,8 @@ private static EsQueryExec assertChildIsExtractedAs( "Expect field attribute to be extracted as " + fieldExtractPreference, extract.attributesToExtract() .stream() - .allMatch(attr -> extract.fieldExtractPreference(attr) == fieldExtractPreference && attr.dataType() == dataType) + .filter(t -> t.dataType() == dataType) + .allMatch(attr -> extract.fieldExtractPreference(attr) == fieldExtractPreference) ); return source(extract.child()); } @@ -8277,6 +8320,107 @@ private QueryBuilder sv(QueryBuilder builder, String fieldName) { return sv.next(); } + private PhysicalPlanOptimizer getCustomRulesPhysicalPlanOptimizer(List> batches) { + PhysicalOptimizerContext context = new PhysicalOptimizerContext(config); + PhysicalPlanOptimizer PhysicalPlanOptimizer = new PhysicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return PhysicalPlanOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + + PhysicalPlan plan = physicalPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = 
new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, PhysicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new EvalExec( + plan.source(), + plan, + List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral)) + ); + } + return plan; + } + } + ); + PhysicalPlanOptimizer customRulesPhysicalPlanOptimizer = getCustomRulesPhysicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesPhysicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() throws Exception { + + PhysicalPlan plan = physicalPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, PhysicalOptimizerContext context) { + // We only want to apply 
it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + LimitExec limit = as(plan, LimitExec.class); + LimitExec newLimit = new LimitExec( + plan.source(), + limit.child(), + new Literal(Source.EMPTY, 1000, INTEGER), + randomEstimatedRowSize() + ) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + } + ); + PhysicalPlanOptimizer customRulesPhysicalPlanOptimizer = getCustomRulesPhysicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesPhysicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + @Override protected List filteredWarnings() { return withDefaultLimitWarning(super.filteredWarnings()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java index 96e26fbd37a4c..2db996651d9f4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java @@ -265,11 +265,11 @@ public void testNullFoldableDoesNotApplyToIsNullAndNotNull() { } public void testNullBucketGetsFolded() { - assertEquals(NULL, foldNull(new Bucket(EMPTY, NULL, NULL, NULL, NULL))); + assertEquals(NULL, foldNull(new Bucket(EMPTY, NULL, NULL, NULL, NULL, NULL))); } public void testNullCategorizeGroupingNotFolded() { - Categorize categorize = new Categorize(EMPTY, NULL); + Categorize categorize = new Categorize(EMPTY, NULL, NULL); assertEquals(categorize, foldNull(categorize)); } diff 
--git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java new file mode 100644 index 0000000000000..74ec1b71cf824 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; + +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; +import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizerTests; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.stats.SearchStats; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; + +public class LocalSubstituteSurrogateExpressionTests extends LocalLogicalPlanOptimizerTests { + + public void testSubstituteDateTruncInEvalWithRoundTo() { 
+ var plan = plan(""" + from test + | sort hire_date + | eval x = date_trunc(1 day, hire_date) + | keep emp_no, hire_date, x + | limit 5 + """); + + // create a SearchStats with min and max millis + Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z + Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z + SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); + + LogicalPlan localPlan = localPlan(plan, searchStats); + Project project = as(localPlan, Project.class); + TopN topN = as(project.child(), TopN.class); + Eval eval = as(topN.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + RoundTo roundTo = as(a.child(), RoundTo.class); + FieldAttribute fa = as(roundTo.field(), FieldAttribute.class); + assertEquals("hire_date", fa.name()); + assertEquals(DATETIME, fa.dataType()); + assertEquals(4, roundTo.points().size()); // 4 days + EsRelation relation = as(eval.child(), EsRelation.class); + } + + public void testSubstituteDateTruncInAggWithRoundTo() { + var plan = plan(""" + from test + | stats count(*) by x = date_trunc(1 day, hire_date) + """); + + // create a SearchStats with min and max millis + Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z + Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z + SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); + + LogicalPlan localPlan = localPlan(plan, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + RoundTo roundTo = as(a.child(), RoundTo.class); + FieldAttribute fa = as(roundTo.field(), 
FieldAttribute.class); + assertEquals("hire_date", fa.name()); + assertEquals(DATETIME, fa.dataType()); + assertEquals(4, roundTo.points().size()); // 4 days + EsRelation relation = as(eval.child(), EsRelation.class); + } + + public void testSubstituteBucketInAggWithRoundTo() { + var plan = plan(""" + from test + | stats count(*) by x = bucket(hire_date, 1 day) + """); + // create a SearchStats with min and max millis + Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z + Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z + SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); + + LogicalPlan localPlan = localPlan(plan, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + RoundTo roundTo = as(a.child(), RoundTo.class); + FieldAttribute fa = as(roundTo.field(), FieldAttribute.class); + assertEquals("hire_date", fa.name()); + assertEquals(DATETIME, fa.dataType()); + assertEquals(4, roundTo.points().size()); // 4 days + EsRelation relation = as(eval.child(), EsRelation.class); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index 1cd112114e027..31e0074a7fde1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -946,10 +946,6 @@ public void testBasicLimitCommand() { assertThat(limit.children().get(0).children().get(0), instanceOf(UnresolvedRelation.class)); } - public void testLimitConstraints() { - expectError("from text | 
limit -1", "line 1:13: Invalid value for LIMIT [-1], expecting a non negative integer"); - } - public void testBasicSortCommand() { LogicalPlan plan = statement("from text | where true | sort a+b asc nulls first, x desc nulls last | sort y asc | sort z desc"); assertThat(plan, instanceOf(OrderBy.class)); @@ -1231,7 +1227,7 @@ public void testLikeRLike() { assertEquals(".*bar.*", rlike.pattern().asJavaRegex()); expectError("from a | where foo like 12", "no viable alternative at input 'foo like 12'"); - expectError("from a | where foo rlike 12", "mismatched input '12'"); + expectError("from a | where foo rlike 12", "no viable alternative at input 'foo rlike 12'"); expectError( "from a | where foo like \"(?i)(^|[^a-zA-Z0-9_-])nmap($|\\\\.)\"", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 6749f03bedde7..b56f4a3a4898b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -19,6 +19,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneSourceOperator; @@ -340,7 +341,12 @@ private Configuration config() { } private EsPhysicalOperationProviders esPhysicalOperationProviders(List shardContexts) { - return new EsPhysicalOperationProviders(FoldContext.small(), shardContexts, null, DataPartitioning.AUTO); + return new EsPhysicalOperationProviders( + FoldContext.small(), + shardContexts, + null, + new 
PhysicalSettings(DataPartitioning.AUTO, ByteSizeValue.ofMb(1)) + ); } private List createShardContexts() throws IOException { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index a8916f140ea1f..1fc05b9d3f2db 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -9,10 +9,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.analysis.common.CommonAnalysisPlugin; -import org.elasticsearch.common.Randomness; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; @@ -30,7 +27,6 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.Operator; -import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator.SourceOperatorFactory; import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator; @@ -58,7 +54,6 @@ import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; -import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import 
org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; @@ -71,7 +66,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.Random; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -79,8 +73,6 @@ import java.util.function.Supplier; import java.util.stream.IntStream; -import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween; -import static java.util.stream.Collectors.joining; import static org.apache.lucene.tests.util.LuceneTestCase.createTempDir; import static org.elasticsearch.compute.aggregation.spatial.SpatialAggregationUtils.encodeLongitude; import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.DOC_VALUES; @@ -138,25 +130,6 @@ public PhysicalOperation timeSeriesSourceOperation(TimeSeriesSourceExec ts, Loca throw new UnsupportedOperationException("time-series source is not supported in CSV tests"); } - @Override - public Operator.OperatorFactory ordinalGroupingOperatorFactory( - PhysicalOperation source, - AggregateExec aggregateExec, - List aggregatorFactories, - Attribute attrSource, - ElementType groupElementType, - LocalExecutionPlannerContext context - ) { - int channelIndex = source.layout.numberOfChannels(); - return new TestOrdinalsGroupingAggregationOperatorFactory( - channelIndex, - aggregatorFactories, - groupElementType, - context.bigArrays(), - attrSource - ); - } - @Override public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory( TimeSeriesAggregateExec ts, @@ -332,7 +305,7 @@ private Block getBlockForMultiType(DocBlock indexDoc, MultiTypeEsField multiType } return switch (extractBlockForSingleDoc(indexDoc, ((FieldAttribute) conversion.field()).fieldName().string(), blockCopier)) { case BlockResultMissing unused -> getNullsBlock(indexDoc); - case BlockResultSuccess success -> 
TypeConverter.fromConvertFunction(conversion).convert(success.block); + case BlockResultSuccess success -> TypeConverter.fromScalarFunction(conversion).convert(success.block); }; } @@ -393,12 +366,13 @@ private class TestHashAggregationOperator extends HashAggregationOperator { private final Attribute attribute; TestHashAggregationOperator( + List groups, List aggregators, Supplier blockHash, Attribute attribute, DriverContext driverContext ) { - super(aggregators, blockHash, driverContext); + super(groups, aggregators, blockHash, driverContext); this.attribute = attribute; } @@ -408,58 +382,6 @@ protected Page wrapPage(Page page) { } } - /** - * Pretends to be the {@link OrdinalsGroupingOperator} but always delegates to the - * {@link HashAggregationOperator}. - */ - private class TestOrdinalsGroupingAggregationOperatorFactory implements Operator.OperatorFactory { - private final int groupByChannel; - private final List aggregators; - private final ElementType groupElementType; - private final BigArrays bigArrays; - private final Attribute attribute; - - TestOrdinalsGroupingAggregationOperatorFactory( - int channelIndex, - List aggregatorFactories, - ElementType groupElementType, - BigArrays bigArrays, - Attribute attribute - ) { - this.groupByChannel = channelIndex; - this.aggregators = aggregatorFactories; - this.groupElementType = groupElementType; - this.bigArrays = bigArrays; - this.attribute = attribute; - } - - @Override - public Operator get(DriverContext driverContext) { - Random random = Randomness.get(); - int pageSize = random.nextBoolean() ? 
randomIntBetween(random, 1, 16) : randomIntBetween(random, 1, 10 * 1024); - return new TestHashAggregationOperator( - aggregators, - () -> BlockHash.build( - List.of(new BlockHash.GroupSpec(groupByChannel, groupElementType)), - driverContext.blockFactory(), - pageSize, - false - ), - attribute, - driverContext - ); - } - - @Override - public String describe() { - return "TestHashAggregationOperator(mode = " - + "" - + ", aggs = " - + aggregators.stream().map(Describable::describe).collect(joining(", ")) - + ")"; - } - } - private Block extractBlockForColumn( DocBlock docBlock, DataType dataType, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtilsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtilsTests.java index 2d488d7e41ee8..e2338a12f6179 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtilsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlCCSUtilsTests.java @@ -152,7 +152,7 @@ public void testUpdateExecutionInfoWithUnavailableClusters() { executionInfo.swapCluster(REMOTE2_ALIAS, (k, v) -> new EsqlExecutionInfo.Cluster(REMOTE2_ALIAS, "mylogs1,mylogs2,logs*", true)); var failure = new FieldCapabilitiesFailure(new String[] { "logs-a" }, new NoSeedNodeLeftException("unable to connect")); - var unvailableClusters = Map.of(REMOTE1_ALIAS, failure, REMOTE2_ALIAS, failure); + var unvailableClusters = Map.of(REMOTE1_ALIAS, List.of(failure), REMOTE2_ALIAS, List.of(failure)); EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, unvailableClusters); assertThat(executionInfo.clusterAliases(), equalTo(Set.of(LOCAL_CLUSTER_ALIAS, REMOTE1_ALIAS, REMOTE2_ALIAS))); @@ -184,7 +184,7 @@ public void testUpdateExecutionInfoWithUnavailableClusters() { var failure = new FieldCapabilitiesFailure(new String[] { "logs-a" }, new NoSeedNodeLeftException("unable to connect")); RemoteTransportException e = 
expectThrows( RemoteTransportException.class, - () -> EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, Map.of(REMOTE2_ALIAS, failure)) + () -> EsqlCCSUtils.updateExecutionInfoWithUnavailableClusters(executionInfo, Map.of(REMOTE2_ALIAS, List.of(failure))) ); assertThat(e.status().getStatus(), equalTo(500)); assertThat( @@ -253,7 +253,7 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { ) ); - IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Set.of(), Map.of()); + IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Map.of()); EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution); @@ -296,8 +296,7 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { IndexMode.STANDARD ) ); - Map unavailableClusters = Map.of(); - IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Set.of(), unavailableClusters); + IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Map.of()); EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution); @@ -338,8 +337,8 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { ); // remote1 is unavailable var failure = new FieldCapabilitiesFailure(new String[] { "logs-a" }, new NoSeedNodeLeftException("unable to connect")); - Map unavailableClusters = Map.of(REMOTE1_ALIAS, failure); - IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Set.of(), unavailableClusters); + var failures = Map.of(REMOTE1_ALIAS, List.of(failure)); + IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), failures); EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution); @@ -349,9 +348,8 @@ public void 
testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { EsqlExecutionInfo.Cluster remote1Cluster = executionInfo.getCluster(REMOTE1_ALIAS); assertThat(remote1Cluster.getIndexExpression(), equalTo("*")); - // since remote1 is in the unavailable Map (passed to IndexResolution.valid), it's status will not be changed - // by updateExecutionInfoWithClustersWithNoMatchingIndices (it is handled in updateExecutionInfoWithUnavailableClusters) - assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.RUNNING)); + // since remote1 is in the failures Map (passed to IndexResolution.valid), + assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SKIPPED)); EsqlExecutionInfo.Cluster remote2Cluster = executionInfo.getCluster(REMOTE2_ALIAS); assertThat(remote2Cluster.getIndexExpression(), equalTo("mylogs1*,mylogs2*,logs*")); @@ -381,8 +379,8 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { ); var failure = new FieldCapabilitiesFailure(new String[] { "logs-a" }, new NoSeedNodeLeftException("unable to connect")); - Map unavailableClusters = Map.of(REMOTE1_ALIAS, failure); - IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Set.of(), unavailableClusters); + var failures = Map.of(REMOTE1_ALIAS, List.of(failure)); + IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), failures); EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution); EsqlExecutionInfo.Cluster localCluster = executionInfo.getCluster(LOCAL_CLUSTER_ALIAS); @@ -390,9 +388,8 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { assertClusterStatusAndShardCounts(localCluster, EsqlExecutionInfo.Cluster.Status.RUNNING); EsqlExecutionInfo.Cluster remote1Cluster = executionInfo.getCluster(REMOTE1_ALIAS); - // since remote1 is in the unavailable Map (passed to IndexResolution.valid), it's status will not be 
changed - // by updateExecutionInfoWithClustersWithNoMatchingIndices (it is handled in updateExecutionInfoWithUnavailableClusters) - assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.RUNNING)); + // skipped since remote1 is in the failures Map + assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SKIPPED)); EsqlExecutionInfo.Cluster remote2Cluster = executionInfo.getCluster(REMOTE2_ALIAS); assertThat(remote2Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SKIPPED)); @@ -430,8 +427,8 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { // remote1 is unavailable var failure = new FieldCapabilitiesFailure(new String[] { "logs-a" }, new NoSeedNodeLeftException("unable to connect")); - Map unavailableClusters = Map.of(REMOTE1_ALIAS, failure); - IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), Set.of(), unavailableClusters); + var failures = Map.of(REMOTE1_ALIAS, List.of(failure)); + IndexResolution indexResolution = IndexResolution.valid(esIndex, esIndex.concreteIndices(), failures); EsqlCCSUtils.updateExecutionInfoWithClustersWithNoMatchingIndices(executionInfo, indexResolution); @@ -441,9 +438,8 @@ public void testUpdateExecutionInfoWithClustersWithNoMatchingIndices() { EsqlExecutionInfo.Cluster remote1Cluster = executionInfo.getCluster(REMOTE1_ALIAS); assertThat(remote1Cluster.getIndexExpression(), equalTo("*")); - // since remote1 is in the unavailable Map (passed to IndexResolution.valid), it's status will not be changed - // by updateExecutionInfoWithClustersWithNoMatchingIndices (it is handled in updateExecutionInfoWithUnavailableClusters) - assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.RUNNING)); + // skipped since remote1 is in the failures Map + assertThat(remote1Cluster.getStatus(), equalTo(EsqlExecutionInfo.Cluster.Status.SKIPPED)); EsqlExecutionInfo.Cluster remote2Cluster = 
executionInfo.getCluster(REMOTE2_ALIAS); assertThat(remote2Cluster.getIndexExpression(), equalTo("mylogs1*,mylogs2*,logs*")); @@ -463,7 +459,9 @@ public void testDetermineUnavailableRemoteClusters() { ) ); - Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(failures); + Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters( + EsqlCCSUtils.groupFailuresPerCluster(failures) + ); assertThat(unavailableClusters.keySet(), equalTo(Set.of("remote1", "remote2"))); } @@ -473,7 +471,8 @@ public void testDetermineUnavailableRemoteClusters() { failures.add(new FieldCapabilitiesFailure(new String[] { "remote2:mylogs1" }, new NoSuchRemoteClusterException("remote2"))); failures.add(new FieldCapabilitiesFailure(new String[] { "remote2:mylogs1" }, new NoSeedNodeLeftException("no seed node"))); - Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(failures); + var groupedFailures = EsqlCCSUtils.groupFailuresPerCluster(failures); + Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(groupedFailures); assertThat(unavailableClusters.keySet(), equalTo(Set.of("remote2"))); } @@ -487,7 +486,8 @@ public void testDetermineUnavailableRemoteClusters() { new IllegalStateException("Unable to open any connections") ) ); - Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(failures); + var groupedFailures = EsqlCCSUtils.groupFailuresPerCluster(failures); + Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(groupedFailures); assertThat(unavailableClusters.keySet(), equalTo(Set.of("remote2"))); } @@ -495,14 +495,16 @@ public void testDetermineUnavailableRemoteClusters() { { List failures = new ArrayList<>(); failures.add(new FieldCapabilitiesFailure(new String[] { "remote1:mylogs1" }, new RuntimeException("foo"))); - Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(failures); + var groupedFailures = 
EsqlCCSUtils.groupFailuresPerCluster(failures); + Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(groupedFailures); assertThat(unavailableClusters.keySet(), equalTo(Set.of())); } // empty failures list { List failures = new ArrayList<>(); - Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(failures); + var groupedFailures = EsqlCCSUtils.groupFailuresPerCluster(failures); + Map unavailableClusters = EsqlCCSUtils.determineUnavailableRemoteClusters(groupedFailures); assertThat(unavailableClusters.keySet(), equalTo(Set.of())); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/DisabledSearchStats.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/DisabledSearchStats.java index 6d8f5ca925121..308d21da05c6d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/DisabledSearchStats.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/DisabledSearchStats.java @@ -9,7 +9,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute.FieldName; -import org.elasticsearch.xpack.esql.core.type.DataType; public class DisabledSearchStats implements SearchStats { @@ -49,12 +48,12 @@ public long count(FieldName field, BytesRef value) { } @Override - public byte[] min(FieldName field, DataType dataType) { + public Object min(FieldName field) { return null; } @Override - public byte[] max(FieldName field, DataType dataType) { + public Object max(FieldName field) { return null; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/SearchContextStatsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/SearchContextStatsTests.java new file mode 100644 index 0000000000000..2d8ecd26ad230 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/SearchContextStatsTests.java @@ -0,0 +1,170 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.stats; + +import org.apache.lucene.document.DoubleField; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FloatField; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.LongField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateNanosToLong; +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong; + +public class SearchContextStatsTests extends MapperServiceTestCase { + private final Directory directory = newDirectory(); + private SearchStats searchStats; + private List mapperServices; + private List readers; + private long minMillis, maxMillis, minNanos, maxNanos; + + @Before + public void setup() throws IOException { + int indexCount = randomIntBetween(1, 5); + List contexts = new ArrayList<>(indexCount); + mapperServices = new ArrayList<>(indexCount); + readers = new ArrayList<>(indexCount); + maxMillis = minMillis = dateTimeToLong("2025-01-01T00:00:01"); + maxNanos = minNanos = dateNanosToLong("2025-01-01T00:00:01"); 
+ + MapperServiceTestCase mapperHelper = new MapperServiceTestCase() { + }; + // create one or more index, so that there is one or more SearchExecutionContext in SearchStats + for (int i = 0; i < indexCount; i++) { + // Start with millis/nanos, numeric and keyword types in the index mapping, more data types can be covered later if needed. + // SearchContextStats returns min/max for millis and nanos only currently, null is returned for the other types min and max. + MapperService mapperService; + if (i == 0) { + mapperService = mapperHelper.createMapperService(""" + { + "doc": { "properties": { + "byteField": { "type": "byte" }, + "shortField": { "type": "short" }, + "intField": { "type": "integer" }, + "longField": { "type": "long" }, + "floatField": { "type": "float" }, + "doubleField": { "type": "double" }, + "dateField": { "type": "date" }, + "dateNanosField": { "type": "date_nanos" }, + "keywordField": { "type": "keyword" }, + "maybeMixedField": { "type": "long" } + }} + }"""); + } else { + mapperService = mapperHelper.createMapperService(""" + { + "doc": { "properties": { + "byteField": { "type": "byte" }, + "shortField": { "type": "short" }, + "intField": { "type": "integer" }, + "longField": { "type": "long" }, + "floatField": { "type": "float" }, + "doubleField": { "type": "double" }, + "dateField": { "type": "date" }, + "dateNanosField": { "type": "date_nanos" }, + "maybeMixedField": { "type": "date" } + }} + }"""); + } + mapperServices.add(mapperService); + + int perIndexDocumentCount = randomIntBetween(1, 5); + IndexReader reader; + try (RandomIndexWriter writer = new RandomIndexWriter(random(), directory)) { + List byteValues = randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomByte); + List shortValues = randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomShort); + List intValues = randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomInt); + List longValues = 
randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomLong); + List floatValues = randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomFloat); + List doubleValues = randomList(perIndexDocumentCount, perIndexDocumentCount, ESTestCase::randomDouble); + List keywordValues = randomList(perIndexDocumentCount, perIndexDocumentCount, () -> randomAlphaOfLength(5)); + + for (int j = 0; j < perIndexDocumentCount; j++) { + long millis = minMillis + (j == 0 ? 0 : randomInt(1000)); + long nanos = minNanos + (j == 0 ? 0 : randomInt(1000)); + maxMillis = Math.max(millis, maxMillis); + maxNanos = Math.max(nanos, maxNanos); + minMillis = Math.min(millis, minMillis); + minNanos = Math.min(nanos, minNanos); + writer.addDocument( + List.of( + new IntField("byteField", byteValues.get(j), Field.Store.NO), + new IntField("shortField", shortValues.get(j), Field.Store.NO), + new IntField("intField", intValues.get(j), Field.Store.NO), + new LongField("longField", longValues.get(j), Field.Store.NO), + new FloatField("floatField", floatValues.get(j), Field.Store.NO), + new DoubleField("doubleField", doubleValues.get(j), Field.Store.NO), + new LongField("dateField", millis, Field.Store.NO), + new LongField("dateNanosField", nanos, Field.Store.NO), + new StringField("keywordField", keywordValues.get(j), Field.Store.NO), + new LongField("maybeMixedField", millis, Field.Store.NO) + ) + ); + } + reader = writer.getReader(); + readers.add(reader); + } + // create SearchExecutionContext for each index + SearchExecutionContext context = mapperHelper.createSearchExecutionContext(mapperService, newSearcher(reader)); + contexts.add(context); + } + // create SearchContextStats + searchStats = SearchContextStats.from(contexts); + } + + public void testMinMax() { + List fields = List.of( + "byteField", + "shortField", + "intField", + "longField", + "floatField", + "doubleField", + "dateField", + "dateNanosField", + "keywordField" + ); + for (String field : fields) 
{ + Object min = searchStats.min(new FieldAttribute.FieldName(field)); + Object max = searchStats.max(new FieldAttribute.FieldName(field)); + if (field.startsWith("date") == false) { + assertNull(min); + assertNull(max); + } else if (field.equals("dateField")) { + assertEquals(minMillis, min); + assertEquals(maxMillis, max); + } else if (field.equals("dateNanosField")) { + assertEquals(minNanos, min); + assertEquals(maxNanos, max); + } + } + } + + @After + public void cleanup() throws IOException { + IOUtils.close(readers); + IOUtils.close(mapperServices); + IOUtils.close(directory); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index 5571729626fb9..9701fff2a5789 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -51,24 +51,6 @@ public void testAttachToDeployment() throws IOException { var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); - var updatedNumAllocations = randomIntBetween(1, 10); - var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING); - assertThat( - updatedEndpointConfig.get("service_settings"), - is( - Map.of( - "num_allocations", - updatedNumAllocations, - "num_threads", - 1, - "model_id", - "attach_to_deployment", - "deployment_id", - "existing_deployment" - ) - ) - ); - deleteModel(inferenceId); // assert deployment not stopped var stats = (List>) getTrainedModelStats(modelId).get("trained_model_stats"); @@ -128,24 +110,6 @@ public void 
testAttachWithModelId() throws IOException { var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); - var updatedNumAllocations = randomIntBetween(1, 10); - var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING); - assertThat( - updatedEndpointConfig.get("service_settings"), - is( - Map.of( - "num_allocations", - updatedNumAllocations, - "num_threads", - 1, - "model_id", - "attach_with_model_id", - "deployment_id", - "existing_deployment_with_model_id" - ) - ) - ); - forceStopMlNodeDeployment(deploymentId); } @@ -180,6 +144,30 @@ public void testDeploymentDoesNotExist() { assertThat(e.getMessage(), containsString("Cannot find deployment [missing_deployment]")); } + public void testCreateInferenceUsingSameDeploymentId() throws IOException { + var modelId = "conflicting_ids"; + var deploymentId = modelId; + var inferenceId = modelId; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + var response = startMlNodeDeploymemnt(modelId, deploymentId); + assertStatusOkOrCreated(response); + + var responseException = assertThrows( + ResponseException.class, + () -> putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING) + ); + assertThat( + responseException.getMessage(), + containsString( + "Inference endpoint IDs must be unique. " + + "Requested inference endpoint ID [conflicting_ids] matches existing trained model ID(s) but must not." 
+ ) + ); + + forceStopMlNodeDeployment(deploymentId); + } + public void testNumAllocationsIsUpdated() throws IOException { var modelId = "update_num_allocations"; var deploymentId = modelId; @@ -208,7 +196,16 @@ public void testNumAllocationsIsUpdated() throws IOException { ) ); - assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2)); + var responseException = assertThrows(ResponseException.class, () -> updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2)); + assertThat( + responseException.getMessage(), + containsString( + "Cannot update inference endpoint [test_num_allocations_updated] using model deployment [update_num_allocations]. " + + "The model deployment must be updated through the trained models API." + ) + ); + + updateMlNodeDeploymemnt(deploymentId, 2); var updatedServiceSettings = getModel(inferenceId).get("service_settings"); assertThat( @@ -227,6 +224,92 @@ public void testNumAllocationsIsUpdated() throws IOException { ) ) ); + + forceStopMlNodeDeployment(deploymentId); + } + + public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOException { + var modelId = "update_num_allocations_from_created_endpoint"; + var inferenceId = "test_created_endpoint_from_model"; + var deploymentId = inferenceId; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + + var putModel = putModel(inferenceId, Strings.format(""" + { + "service": "elasticsearch", + "service_settings": { + "num_allocations": %s, + "num_threads": %s, + "model_id": "%s" + } + } + """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING); + var serviceSettings = putModel.get("service_settings"); + assertThat(putModel.toString(), serviceSettings, is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId))); + + updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2); + + var responseException = assertThrows(ResponseException.class, () -> updateMlNodeDeploymemnt(deploymentId, 2)); + assertThat( + responseException.getMessage(), + 
containsString( + "Cannot update deployment [test_created_endpoint_from_model] as it was created by inference endpoint " + + "[test_created_endpoint_from_model]. This model deployment must be updated through the inference API." + ) + ); + + var updatedServiceSettings = getModel(inferenceId).get("service_settings"); + assertThat( + updatedServiceSettings.toString(), + updatedServiceSettings, + is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId)) + ); + + forceStopMlNodeDeployment(deploymentId); + } + + public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws IOException { + var modelId = "model_deployment_for_endpoint"; + var inferenceId = "first_endpoint_for_model_deployment"; + var deploymentId = inferenceId; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + + putModel(inferenceId, Strings.format(""" + { + "service": "elasticsearch", + "service_settings": { + "num_allocations": %s, + "num_threads": %s, + "model_id": "%s" + } + } + """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING); + + var secondInferenceId = "second_endpoint_for_model_deployment"; + var putModel = putModel(secondInferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING); + var serviceSettings = putModel.get("service_settings"); + assertThat( + putModel.toString(), + serviceSettings, + is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId)) + ); + + var responseException = assertThrows( + ResponseException.class, + () -> updateInference(secondInferenceId, TaskType.SPARSE_EMBEDDING, 2) + ); + assertThat( + responseException.getMessage(), + containsString( + "Cannot update inference endpoint [second_endpoint_for_model_deployment] for model deployment " + + "[first_endpoint_for_model_deployment] as it was created by another inference endpoint. " + + "The model can only be updated using inference endpoint id [first_endpoint_for_model_deployment]." 
+ ) + ); + + forceStopMlNodeDeployment(deploymentId); } public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException { @@ -300,6 +383,22 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr return client().performRequest(request); } + private Response updateInference(String deploymentId, TaskType taskType, int numAllocations) throws IOException { + String endPoint = Strings.format("/_inference/%s/%s/_update", taskType, deploymentId); + + var body = Strings.format(""" + { + "service_settings": { + "num_allocations": %d + } + } + """, numAllocations); + + Request request = new Request("PUT", endPoint); + request.setJsonEntity(body); + return client().performRequest(request); + } + private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException { String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; @@ -314,6 +413,16 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations return client().performRequest(request); } + private Map updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException { + String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; + + Request request = new Request("POST", endPoint); + request.setJsonEntity(body); + var response = client().performRequest(request); + assertStatusOkOrCreated(response); + return entityAsMap(response); + } + protected void stopMlNodeDeployment(String deploymentId) throws IOException { String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index c216535649d47..9729f26c4334c 100644 --- 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -111,6 +111,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { containsInAnyOrder( List.of( "alibabacloud-ai-search", + "azureaistudio", "cohere", "elasticsearch", "googlevertexai", diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index 9782d4881ac61..6191e83a7dca1 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase { // TODO: replace with proper test features private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0"; private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0"; + private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0"; private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2"; private static MockWebServer cohereEmbeddingsServer; private static MockWebServer cohereRerankServer; + private static MockWebServer cohereCompletionsServer; private enum ApiVersion { V1, @@ -60,12 +62,16 @@ public static void startWebServer() throws IOException { cohereRerankServer = new MockWebServer(); cohereRerankServer.start(); + + cohereCompletionsServer = new MockWebServer(); + cohereCompletionsServer.start(); } @AfterClass public static void 
shutdown() { cohereEmbeddingsServer.close(); cohereRerankServer.close(); + cohereCompletionsServer.close(); } @SuppressWarnings("unchecked") @@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException { assertThat(inferenceMap.entrySet(), not(empty())); } + @SuppressWarnings("unchecked") + public void testCohereCompletions() throws IOException { + var completionsSupported = oldClusterHasFeature(COHERE_COMPLETIONS_ADDED_TEST_FEATURE); + assumeTrue("Cohere completions not supported", completionsSupported); + + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1; + + final String oldClusterId = "old-cluster-completions"; + + if (isOldCluster()) { + // queue a response as PUT will call the service + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion))); + put(oldClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION); + + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + } else if (isMixedCluster()) { + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + } else if (isUpgradedCluster()) { + // check old cluster model + var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "command")); + + final String newClusterId = 
"new-cluster-completions"; + { + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion))); + var inferenceMap = inference(oldClusterId, TaskType.COMPLETION, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", oldClusterApiVersion); + } + { + // new cluster uses the V2 API + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2))); + put(newClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION); + + cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2))); + var inferenceMap = inference(newClusterId, TaskType.COMPLETION, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", ApiVersion.V2); + } + + { + // new endpoints use the V2 API which require the model to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, TaskType.COMPLETION)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } + + delete(oldClusterId); + delete(newClusterId); + } + } + private String embeddingConfigByte(String url) { return embeddingConfigTemplate(url, "byte"); } @@ -451,4 +531,86 @@ private String rerankResponse() { """; } + private String completionsConfig(String url) { + return Strings.format(""" + { + "service": "cohere", + "service_settings": { + "api_key": "XXXX", + "model_id": "command", + "url": "%s" + } 
+ } + """, url); + } + + private String completionsResponse(ApiVersion version) { + return switch (version) { + case V1 -> v1CompletionsResponse(); + case V2 -> v2CompletionsResponse(); + }; + } + + private String v1CompletionsResponse() { + return """ + { + "response_id": "some id", + "text": "result", + "generation_id": "some id", + "chat_history": [ + { + "role": "USER", + "message": "some input" + }, + { + "role": "CHATBOT", + "message": "v1 response from the llm" + } + ], + "finish_reason": "COMPLETE", + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 4, + "output_tokens": 191 + }, + "tokens": { + "input_tokens": 70, + "output_tokens": 191 + } + } + } + """; + } + + private String v2CompletionsResponse() { + return """ + { + "id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4", + "finish_reason": "COMPLETE", + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "v2 response from the LLM" + } + ] + }, + "usage": { + "billed_units": { + "input_tokens": 1, + "output_tokens": 2 + }, + "tokens": { + "input_tokens": 3, + "output_tokens": 4 + } + } + } + """; + } + } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 8e40bba8b32f7..1eb530ac1bb9e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() { createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(gatewayUrl), modelRegistry, - new 
ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index e56782bd00ef5..22aebee72df0c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -129,7 +130,8 @@ public void testGetModel() throws Exception { mock(Client.class), mock(ThreadPool.class), mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + InferenceStatsTests.mockInferenceStats() ) ); ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 3d05600709b23..00f40e903d1ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -37,8 +37,12 @@ public class InferenceFeatures implements FeatureSpecification { private static 
final NodeFeature TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS = new NodeFeature( "test_rule_retriever.with_indices_that_dont_return_rank_docs" ); + private static final NodeFeature SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX = new NodeFeature( + "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + ); private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter"); private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2"); + public static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTING_FLAT = new NodeFeature("semantic_text.highlighter.flat_index_options"); @Override public Set getTestFeatures() { @@ -68,7 +72,9 @@ public Set getTestFeatures() { SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS, SEMANTIC_TEXT_INDEX_OPTIONS, COHERE_V2_API, - SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS + SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS, + SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, + SEMANTIC_TEXT_HIGHLIGHTING_FLAT ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index c347fa1dca4ce..6fd07cd4c2831 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -50,6 +50,8 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings; +import 
org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings; @@ -104,6 +106,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; @@ -173,6 +177,7 @@ public static List getNamedWriteables() { addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); addCustomNamedWriteables(namedWriteables); + addLlamaNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -272,8 +277,25 @@ private static void addMistralNamedWriteables(List MistralChatCompletionServiceSettings::new ) ); + // no task settings for Mistral + } - // note - no task settings for Mistral embeddings... 
+ private static void addLlamaNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaEmbeddingsServiceSettings.NAME, + LlamaEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaChatCompletionServiceSettings.NAME, + LlamaChatCompletionServiceSettings::new + ) + ); + // no task settings for Llama } private static void addAzureAiStudioNamedWriteables(List namedWriteables) { @@ -306,6 +328,17 @@ private static void addAzureAiStudioNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..bbb1bd1a2fec2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -30,6 +30,7 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; @@ -132,6 +133,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import 
org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; @@ -140,7 +142,6 @@ import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; import java.util.Collection; @@ -311,7 +312,8 @@ public Collection createComponents(PluginServices services) { serviceComponents.get(), inferenceServiceSettings, modelRegistry.get(), - authorizationHandler + authorizationHandler, + context ), context -> new SageMakerService( new SageMakerModelBuilder(sageMakerSchemas), @@ -321,16 +323,22 @@ public Collection createComponents(PluginServices services) { ), sageMakerSchemas, services.threadPool(), - sageMakerConfigurations::getOrCompute + sageMakerConfigurations::getOrCompute, + context ) ) ); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var inferenceStats = InferenceStats.create(meterRegistry); + var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext( services.client(), services.threadPool(), services.clusterService(), - settings + settings, + inferenceStats ); // This must be done after the HttpRequestSenderFactory is created so that the services can get the @@ -342,10 +350,6 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); - var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var inferenceStats = InferenceStats.create(meterRegistry); - var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); - var actionFilter = new ShardBulkInferenceActionFilter( services.clusterService(), serviceRegistry, @@ -383,24 +387,25 @@ 
public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), - context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), - context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new CohereService(httpFactory.get(), serviceComponents.get()), - context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), - context -> new MistralService(httpFactory.get(), serviceComponents.get()), - context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), - context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), - context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), - context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), - context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), - context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), - context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context), + context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new CohereService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new 
GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new MistralService(httpFactory.get(), serviceComponents.get(), context), + context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context), + context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context), + context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context), + context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), + context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, - context -> new CustomService(httpFactory.get(), serviceComponents.get()) + context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index dec6d0d928b97..269e0f27fd461 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import 
org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -42,7 +43,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; @@ -57,10 +57,11 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes; /** * Base class for transport actions that handle inference requests. 
@@ -274,15 +275,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException { } private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } + Map metricAttributes = new HashMap<>(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(unwrapCause(t))); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnServiceWithMetrics( @@ -369,7 +366,7 @@ protected Flow.Publisher streamErrorHandler(Flow.Publisher upstream) { private void recordRequestCountMetrics(Model model, Request request, String localNodeId) { Map requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(routingAttributes(request, localNodeId)); + requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); inferenceStats.requestCount().incrementBy(1, requestCountAttributes); } @@ -381,16 +378,11 @@ private void recordRequestDurationMetrics( String localNodeId, @Nullable Throwable t ) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(routingAttributes(request, localNodeId)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + Map metricAttributes = new 
HashMap<>(); + metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t))); + metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index d213111d82d9f..c100c9926b451 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; @@ -128,10 +129,38 @@ private void doExecuteForked( } var service = serviceRegistry.getService(unparsedModel.service()); + Model model; if (service.isPresent()) { - var model = service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); - service.get().stop(model, listener); + try { + model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + } catch (Exception e) { + if (request.isForceDelete()) { + listener.onResponse(true); + return; + } else { + listener.onFailure( + new ElasticsearchStatusException( + Strings.format( + "Failed to parse model 
configuration for inference endpoint [%s]", + request.getInferenceEndpointId() + ), + RestStatus.INTERNAL_SERVER_ERROR, + e + ) + ); + return; + } + } + service.get().stop(model, listener.delegateResponse((l, e) -> { + if (request.isForceDelete()) { + l.onResponse(true); + } else { + l.onFailure(e); + } + })); + } else if (request.isForceDelete()) { + listener.onResponse(true); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7d24b7766baa3..f14d679ba7d26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.threadpool.ThreadPool; @@ -24,7 +25,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; public class TransportInferenceAction extends BaseTransportInferenceAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 
bfa8141d312cf..d0eef677ca1d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -29,7 +30,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.concurrent.Flow; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java index 48fefaa42fe13..77a370b2ef3dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java @@ -58,7 +58,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; @@ -224,13 +223,10 @@ private Model combineExistingModelWithNewSettings( if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) { newSecretSettings = 
existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings()); } - if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) { - // In cluster services can only have their num_allocations updated, so this is a special case + if (settingsToUpdate.serviceSettings() != null) { + // In cluster services can have their deployment settings updated, so this is a special case if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) { - newServiceSettings = new ElasticsearchInternalServiceSettings( - elasticServiceSettings, - (Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS) - ); + newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings()); } } if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) { @@ -257,26 +253,59 @@ private void updateInClusterEndpoint( Model newModel, Model existingParsedModel, ActionListener listener - ) throws IOException { + ) { // The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel); - throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId); + var inferenceEntityId = request.getInferenceEntityId(); + throwIfTrainedModelDoesntExist(inferenceEntityId, deploymentId); - Map serviceSettings = request.getContentAsSettings().serviceSettings(); - if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) { + if (inferenceEntityId.equals(deploymentId) == false) { + modelRegistry.getModel(deploymentId, ActionListener.wrap(unparsedModel -> { + // if this deployment was created by another inference endpoint, then it must be updated using that inference endpoint + listener.onFailure( + new ElasticsearchStatusException( + 
Messages.INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT, + RestStatus.CONFLICT, + inferenceEntityId, + deploymentId, + unparsedModel.inferenceEntityId() + ) + ); + }, e -> { + if (e instanceof ResourceNotFoundException) { + // if this deployment was created by the trained models API, then it must be updated by the trained models API + listener.onFailure( + new ElasticsearchStatusException( + Messages.INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED, + RestStatus.CONFLICT, + inferenceEntityId, + deploymentId + ) + ); + return; + } + listener.onFailure(e); + })); + return; + } + + if (newModel.getServiceSettings() instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) { - UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId); - updateRequest.setNumberOfAllocations(numAllocations); + var updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId); + updateRequest.setNumberOfAllocations(elasticServiceSettings.getNumAllocations()); + updateRequest.setAdaptiveAllocationsSettings(elasticServiceSettings.getAdaptiveAllocationsSettings()); + updateRequest.setIsInternal(true); var delegate = listener.delegateFailure((l2, response) -> { modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2); }); logger.info( - "Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations", + "Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations and adaptive allocations [{}]", deploymentId, request.getInferenceEntityId(), - numAllocations + elasticServiceSettings.getNumAllocations(), + elasticServiceSettings.getAdaptiveAllocationsSettings() ); client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate); @@ -317,7 +346,6 @@ private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String dep throw ExceptionsHelper.entityNotFoundException( 
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE, inferenceEntityId - ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3127361de6d11..ecf73ed004194 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -46,6 +46,7 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -63,7 +64,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.io.IOException; import java.util.ArrayList; @@ -76,11 +76,10 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static 
org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -459,8 +458,7 @@ public void onFailure(Exception exc) { private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) { Map requestCountAttributes = new HashMap<>(); - requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(responseAttributes(throwable)); + requestCountAttributes.putAll(modelAndResponseAttributes(model, throwable)); requestCountAttributes.put("inference_source", "semantic_text_bulk"); inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); } @@ -637,7 +635,9 @@ private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure index addInferenceResponseFailure( itemIndex, new InferenceException( - "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]", + "Unable to insert inference results into document [" + + indexRequest.getIndexRequest().id() + + "] due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes.", e ) ); @@ -749,7 +749,9 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons item.abort( item.index(), new InferenceException( - "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]", + "Unable to insert inference results into document [" + + indexRequest.id() + + "] due to memory pressure. 
Please retry the bulk request with fewer documents or smaller document sizes.", e ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java index 5d04df5d2e1d5..c5e4abd3648c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java @@ -12,7 +12,7 @@ public enum ChunkingSettingsOptions { MAX_CHUNK_SIZE("max_chunk_size"), OVERLAP("overlap"), SENTENCE_OVERLAP("sentence_overlap"), - SEPARATOR_SET("separator_set"), + SEPARATOR_GROUP("separator_group"), SEPARATORS("separators"); private final String chunkingSettingsOption; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java index 690a3d8ff0efe..c68dc3b216744 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java @@ -60,7 +60,10 @@ private List chunk(String input, ChunkOffset offset, List s return chunkWithBackupChunker(input, offset, maxChunkSize); } - var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)); + var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize( + splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)), + maxChunkSize + ); var actualChunks = new ArrayList(); for (var potentialChunk : potentialChunks) { if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) { @@ -104,6 +107,33 @@ private List splitTextBySeparatorRegex(String input, ChunkO return 
chunkOffsets; } + private List mergeChunkOffsetsUpToMaxChunkSize(List chunkOffsets, int maxChunkSize) { + if (chunkOffsets.size() < 2) { + return chunkOffsets; + } + + List mergedOffsetsAndCounts = new ArrayList<>(); + var mergedChunk = chunkOffsets.getFirst(); + for (int i = 1; i < chunkOffsets.size(); i++) { + var chunkOffsetAndCountToMerge = chunkOffsets.get(i); + var potentialMergedChunk = new ChunkOffsetAndCount( + new ChunkOffset(mergedChunk.chunkOffset.start(), chunkOffsetAndCountToMerge.chunkOffset.end()), + mergedChunk.wordCount + chunkOffsetAndCountToMerge.wordCount + ); + if (isChunkWithinMaxSize(potentialMergedChunk, maxChunkSize)) { + mergedChunk = potentialMergedChunk; + } else { + mergedOffsetsAndCounts.add(mergedChunk); + mergedChunk = chunkOffsets.get(i); + } + + if (i == chunkOffsets.size() - 1) { + mergedOffsetsAndCounts.add(mergedChunk); + } + } + return mergedOffsetsAndCounts; + } + private List chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) { var chunks = new SentenceBoundaryChunker().chunk( input.substring(offset.start(), offset.end()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java index c368e1bb0c255..611736ceb4213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java @@ -36,7 +36,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings { private static final Set VALID_KEYS = Set.of( ChunkingSettingsOptions.STRATEGY.toString(), ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), - ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), 
ChunkingSettingsOptions.SEPARATORS.toString() ); @@ -45,7 +45,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings { public RecursiveChunkingSettings(int maxChunkSize, List separators) { this.maxChunkSize = maxChunkSize; - this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators; + this.separators = separators == null ? SeparatorGroup.PLAINTEXT.getSeparators() : separators; } public RecursiveChunkingSettings(StreamInput in) throws IOException { @@ -72,12 +72,12 @@ public static RecursiveChunkingSettings fromMap(Map map) { validationException ); - SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum( + SeparatorGroup separatorGroup = ServiceUtils.extractOptionalEnum( map, - ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), ModelConfigurations.CHUNKING_SETTINGS, - SeparatorSet::fromString, - EnumSet.allOf(SeparatorSet.class), + SeparatorGroup::fromString, + EnumSet.allOf(SeparatorGroup.class), validationException ); @@ -88,12 +88,12 @@ public static RecursiveChunkingSettings fromMap(Map map) { validationException ); - if (separators != null && separatorSet != null) { + if (separators != null && separatorGroup != null) { validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set"); } - if (separatorSet != null) { - separators = separatorSet.getSeparators(); + if (separatorGroup != null) { + separators = separatorGroup.getSeparators(); } else if (separators != null && separators.isEmpty()) { validationException.addValidationError("Recursive chunking settings can not have an empty list of separators"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java similarity index 89% rename from 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java index 61b997b8d17a9..cafd3b08ccf9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java @@ -10,17 +10,17 @@ import java.util.List; import java.util.Locale; -public enum SeparatorSet { +public enum SeparatorGroup { PLAINTEXT("plaintext"), MARKDOWN("markdown"); private final String name; - SeparatorSet(String name) { + SeparatorGroup(String name) { this.name = name; } - public static SeparatorSet fromString(String name) { + public static SeparatorGroup fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 56e994be86eb4..e2dff96c6ecb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -95,7 +95,7 @@ public void validateResponse( protected abstract void checkForFailureStatusCode(Request request, HttpResult result); - private void checkForErrorObject(Request request, HttpResult result) { + protected void checkForErrorObject(Request request, HttpResult result) { var errorEntity = errorParseFunction.apply(result); if (errorEntity.errorStructureFound()) { @@ -116,12 +116,12 @@ protected Exception buildError(String message, Request request, HttpResult resul protected Exception buildError(String message, Request request, 
HttpResult result, ErrorResponse errorResponse) { var responseStatusCode = result.response().getStatusLine().getStatusCode(); return new ElasticsearchStatusException( - errorMessage(message, request, result, errorResponse, responseStatusCode), + constructErrorMessage(message, request, errorResponse, responseStatusCode), toRestStatus(responseStatusCode) ); } - protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) { + public static String constructErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java new file mode 100644 index 0000000000000..89617478f01c0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.util.Locale; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT; +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.toRestStatus; + +public class ChatCompletionErrorResponseHandler { + private static final String STREAM_ERROR = "stream_error"; + + private final UnifiedChatCompletionErrorParser unifiedChatCompletionErrorParser; + + public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser errorParser) { + this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser); + } + + public void checkForErrorObject(Request request, HttpResult result) { + var errorEntity = unifiedChatCompletionErrorParser.parse(result); + + if (errorEntity.errorStructureFound()) { + // We don't really know what happened because the status code was 200 so we'll return a failure and let the + // client retry if necessary + // If we did want to retry here, we'll need to determine if this was a streaming request, if it was + // we shouldn't retry because that would replay the entire streaming request and the client would get + // duplicate chunks back + throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity)); + } + } + + public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) { + var errorResponse = unifiedChatCompletionErrorParser.parse(result); + return buildChatCompletionErrorInternal(message, request, result, 
errorResponse); + } + + private UnifiedChatCompletionException buildChatCompletionErrorInternal( + String message, + Request request, + HttpResult result, + UnifiedChatCompletionErrorResponse errorResponse + ) { + assert request.isStreaming() : "Only streaming requests support this format"; + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = BaseResponseHandler.constructErrorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); + + if (errorResponse.errorStructureFound()) { + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + errorResponse.type(), + errorResponse.code(), + errorResponse.param() + ); + } else { + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); + } + } + + /** + * Builds a default {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received we were unable to parse it in the format we were expecting. + * Only streaming requests should use this method. + * + * @param errorResponse the error response extracted from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + private static UnifiedChatCompletionException buildDefaultChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } + + /** + * Builds a mid-stream error for a streaming request. + * This method is used when an error occurs while processing a streaming response. + * Only streaming requests should use this method. 
+ * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + var error = unifiedChatCompletionErrorParser.parse(message); + + if (error.errorStructureFound()) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + error.getErrorMessage() + ), + error.type(), + error.code(), + error.param() + ); + } else if (e != null) { + // If the error response does not match, we can still return an exception based on the original throwable + return UnifiedChatCompletionException.fromThrowable(e); + } else { + // If no specific error response is found, we return a default mid-stream error + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error); + } + } + + /** + * Builds a default mid-stream error for a streaming request. + * This method is used when no specific error response is found in the message. + * Only streaming requests should use this method. 
+ * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response extracted from the message + * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error + */ + private static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + + /** + * Creates a string representation of the error type based on the provided ErrorResponse. + * This method is used to generate a human-readable error type for logging or exception messages. + * + * @param errorResponse the ErrorResponse object + * @return a string representing the error type + */ + private static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? 
errorResponse.getClass().getSimpleName() : "unknown"; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java index 7fc272931e7fb..a692f4de4fca2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java @@ -22,7 +22,7 @@ public ErrorResponse(String errorMessage) { this.errorStructureFound = true; } - private ErrorResponse(boolean errorStructureFound) { + protected ErrorResponse(boolean errorStructureFound) { this.errorMessage = ""; this.errorStructureFound = errorStructureFound; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java new file mode 100644 index 0000000000000..60f1c44919ca9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +public interface UnifiedChatCompletionErrorParser { + UnifiedChatCompletionErrorResponse parse(HttpResult result); + + UnifiedChatCompletionErrorResponse parse(String result); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java new file mode 100644 index 0000000000000..3a70842455f1d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.core.Nullable; + +import java.util.Objects; + +public class UnifiedChatCompletionErrorResponse extends ErrorResponse { + public static final UnifiedChatCompletionErrorResponse UNDEFINED_ERROR = new UnifiedChatCompletionErrorResponse(); + + @Nullable + private final String code; + @Nullable + private final String param; + private final String type; + + public UnifiedChatCompletionErrorResponse(String errorMessage, String type, @Nullable String code, @Nullable String param) { + super(errorMessage); + this.code = code; + this.param = param; + this.type = Objects.requireNonNull(type); + } + + private UnifiedChatCompletionErrorResponse() { + super(false); + this.code = null; + this.param = null; + this.type = "unknown"; + } + + @Nullable + public String code() { + return code; + } + + @Nullable + public String param() { + return param; + } + + public String type() { + return type; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + UnifiedChatCompletionErrorResponse that = (UnifiedChatCompletionErrorResponse) o; + return Objects.equals(code, that.code) && Objects.equals(param, that.param) && Objects.equals(type, that.type); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), code, param, type); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java index 92333a10c4d08..8e55cc9c222b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java @@ -32,6 +32,7 @@ import 
org.elasticsearch.search.fetch.subphase.highlight.HighlightField; import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; +import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.SparseVectorQueryWrapper; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.Text; @@ -273,6 +274,8 @@ public void visitLeaf(Query query) { queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null)); } else if (query instanceof MatchAllDocsQuery) { queries.add(new MatchAllDocsQuery()); + } else if (query instanceof DenseVectorQuery.Floats floatsQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(floatsQuery.getQuery()), null)); } } }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index 9e513a1ed9226..b1f5c240371f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -52,16 +52,20 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI assert (queryBuilder instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder; Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); + QueryBuilder finalQueryBuilder; if (inferenceIdsIndices.size() == 1) { // Simple case, everything uses the same inference ID Map.Entry> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next(); String searchInferenceId = inferenceIdIndex.getKey(); List 
indices = inferenceIdIndex.getValue(); - return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId); + finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId); } else { // Multiple inference IDs, construct a boolean query - return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices); + finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices); } + finalQueryBuilder.boost(queryBuilder.boost()); + finalQueryBuilder.queryName(queryBuilder.queryName()); + return finalQueryBuilder; } private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( @@ -102,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); } + boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index fd1d65d00faf5..a6599afc66c3f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -36,7 +36,10 @@ protected String getQuery(QueryBuilder queryBuilder) { @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { - return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); + SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); + semanticQueryBuilder.boost(queryBuilder.boost()); + 
semanticQueryBuilder.queryName(queryBuilder.queryName()); + return semanticQueryBuilder; } @Override @@ -45,7 +48,10 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( InferenceIndexInformationForField indexInformation ) { assert (queryBuilder instanceof MatchQueryBuilder); - MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; + MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder; + // Create a copy for non-inference fields without boost and _name + MatchQueryBuilder matchQueryBuilder = copyMatchQueryBuilder(originalMatchQueryBuilder); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.should( createSemanticSubQuery( @@ -55,6 +61,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder)); + boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; } @@ -62,4 +70,24 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( public String getQueryName() { return MatchQueryBuilder.NAME; } + + private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) { + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value()); + matchQueryBuilder.operator(queryBuilder.operator()); + matchQueryBuilder.prefixLength(queryBuilder.prefixLength()); + matchQueryBuilder.maxExpansions(queryBuilder.maxExpansions()); + matchQueryBuilder.fuzzyTranspositions(queryBuilder.fuzzyTranspositions()); + matchQueryBuilder.lenient(queryBuilder.lenient()); + matchQueryBuilder.zeroTermsQuery(queryBuilder.zeroTermsQuery()); + matchQueryBuilder.analyzer(queryBuilder.analyzer()); + matchQueryBuilder.minimumShouldMatch(queryBuilder.minimumShouldMatch()); + matchQueryBuilder.fuzzyRewrite(queryBuilder.fuzzyRewrite()); + + if 
(queryBuilder.fuzziness() != null) { + matchQueryBuilder.fuzziness(queryBuilder.fuzziness()); + } + + matchQueryBuilder.autoGenerateSynonymsPhraseQuery(queryBuilder.autoGenerateSynonymsPhraseQuery()); + return matchQueryBuilder; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index 21feb21fbc2e5..c85a21f10301d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -43,14 +43,18 @@ protected String getQuery(QueryBuilder queryBuilder) { @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); + QueryBuilder finalQueryBuilder; if (inferenceIdsIndices.size() == 1) { // Simple case, everything uses the same inference ID String searchInferenceId = inferenceIdsIndices.keySet().iterator().next(); - return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId); + finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId); } else { // Multiple inference IDs, construct a boolean query - return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices); + finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices); } + finalQueryBuilder.queryName(queryBuilder.queryName()); + finalQueryBuilder.boost(queryBuilder.boost()); + return finalQueryBuilder; } private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( @@ -79,7 +83,19 @@ protected QueryBuilder 
buildCombinedInferenceAndNonInferenceQuery( Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder)); + boolQueryBuilder.should( + createSubQueryForIndices( + indexInformation.nonInferenceIndices(), + new SparseVectorQueryBuilder( + sparseVectorQueryBuilder.getFieldName(), + sparseVectorQueryBuilder.getQueryVectors(), + sparseVectorQueryBuilder.getInferenceId(), + sparseVectorQueryBuilder.getQuery(), + sparseVectorQueryBuilder.shouldPruneTokens(), + sparseVectorQueryBuilder.getTokenPruningConfig() + ) + ) + ); // We always perform nested subqueries on semantic_text fields, to support // sparse_vector queries using query vectors. for (String inferenceId : inferenceIdsIndices.keySet()) { @@ -90,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) ); } + boolQueryBuilder.boost(queryBuilder.boost()); + boolQueryBuilder.queryName(queryBuilder.queryName()); return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..5074749c1cd9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -42,11 +43,13 @@ public abstract class SenderService implements InferenceService { protected static final 
Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; + private final ClusterService clusterService; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(clusterService); } public Sender getSender() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index adbec49328804..f5f1074bfbb86 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -304,6 +305,12 @@ public static String invalidSettingError(String settingName, String scope) { return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); } + public static URI extractUri(Map map, String fieldName, ValidationException validationException) { + String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + + return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + } + public static URI convertToUri(@Nullable 
String url, String settingName, String settingScope, ValidationException validationException) { try { return createOptionalUri(url); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..da608779fee0a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -85,8 +87,20 @@ public class AlibabaCloudSearchService extends SenderService { InputType.INTERNAL_SEARCH ); - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + 
public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..c2b0ae8e69c37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -93,9 +95,19 @@ public class AmazonBedrockService extends SenderService { public AmazonBedrockService( HttpRequestSender.Factory httpSenderFactory, AmazonBedrockRequestSender.Factory amazonBedrockFactory, - ServiceComponents serviceComponents + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(httpSenderFactory, serviceComponents); + this(httpSenderFactory, amazonBedrockFactory, serviceComponents, 
context.clusterService()); + } + + public AmazonBedrockService( + HttpRequestSender.Factory httpSenderFactory, + AmazonBedrockRequestSender.Factory amazonBedrockFactory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(httpSenderFactory, serviceComponents, clusterService); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index bec8908ab73f9..8cf5446f8b6d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AnthropicService( + HttpRequestSender.Factory factory, + 
ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override @@ -263,6 +273,19 @@ public static InferenceServiceConfiguration get() { .build() ); + configurationMap.put( + AnthropicServiceFields.MAX_TOKENS, + new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION)).setDescription( + "The maximum number of tokens to generate before stopping." + ) + .setLabel("Max Tokens") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.INTEGER) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); configurationMap.putAll( RateLimitSettings.toSettingsConfigurationWithDescription( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java index 296b8cf09f8c0..50dd768efab1f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java @@ -10,6 +10,7 @@ public class AzureAiStudioConstants { public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings"; public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions"; + public static final String RERANK_URI_PATH = "/v1/rerank"; // common service settings fields public static final String TARGET_FIELD = "target"; @@ -22,6 +23,10 @@ public class AzureAiStudioConstants { public static final String 
DIMENSIONS_FIELD = "dimensions"; public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + // rerank task settings fields + public static final String DOCUMENTS_FIELD = "documents"; + public static final String QUERY_FIELD = "query"; + // embeddings task settings fields public static final String USER_FIELD = "user"; @@ -35,5 +40,9 @@ public class AzureAiStudioConstants { public static final Double MIN_TEMPERATURE_TOP_P = 0.0; public static final Double MAX_TEMPERATURE_TOP_P = 2.0; + // rerank task settings fields + public static final String RETURN_DOCUMENTS_FIELD = "return_documents"; + public static final String TOP_N_FIELD = "top_n"; + private AzureAiStudioConstants() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java index af064707536eb..a0a723d9edd27 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java @@ -22,6 +22,9 @@ public final class AzureAiStudioProviderCapabilities { // these providers have chat completion inference (all providers at the moment) public static final List chatCompletionProviders = List.of(AzureAiStudioProvider.values()); + // these providers have rerank inference + public static final List rerankProviders = List.of(AzureAiStudioProvider.COHERE); + // these providers allow token ("pay as you go") embeddings endpoints public static final List tokenEmbeddingsProviders = List.of( AzureAiStudioProvider.OPENAI, @@ -31,6 +34,9 @@ public final class AzureAiStudioProviderCapabilities { // these providers allow realtime embeddings endpoints (none at the moment) 
public static final List realtimeEmbeddingsProviders = List.of(); + // these providers allow realtime rerank endpoints (none at the moment) + public static final List realtimeRerankProviders = List.of(); + // these providers allow token ("pay as you go") chat completion endpoints public static final List tokenChatCompletionProviders = List.of( AzureAiStudioProvider.OPENAI, @@ -54,6 +60,9 @@ public static boolean providerAllowsTaskType(AzureAiStudioProvider provider, Tas case TEXT_EMBEDDING -> { return embeddingProviders.contains(provider); } + case RERANK -> { + return rerankProviders.contains(provider); + } default -> { return false; } @@ -76,6 +85,11 @@ public static boolean providerAllowsEndpointTypeForTask( ? tokenEmbeddingsProviders.contains(provider) : realtimeEmbeddingsProviders.contains(provider); } + case RERANK -> { + return (endpointType == AzureAiStudioEndpointType.TOKEN) + ? rerankProviders.contains(provider) + : realtimeRerankProviders.contains(provider); + } default -> { return false; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioRerankRequestManager.java new file mode 100644 index 0000000000000..da73cbb4ea69c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioRerankRequestManager.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureaistudio; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; +import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; +import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity; +import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler; + +import java.util.function.Supplier; + +public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager { + private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class); + + private static final ResponseHandler HANDLER = createRerankHandler(); + + private final AzureAiStudioRerankModel model; + + public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestRerankFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + 
AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest( + model, + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN() + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener)); + } + + private static ResponseHandler createRerankHandler() { + // This currently covers response handling for Azure AI Studio + return new AzureMistralOpenAiExternalResponseHandler( + "azure ai studio rerank", + new AzureAiStudioRerankResponseEntity(), + ErrorMessageResponseEntity::fromResponse, + true + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 04883f23b947f..4a5a8be8b6633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,6 +46,7 @@ import 
org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -71,10 +74,10 @@ public class AzureAiStudioService extends SenderService { - static final String NAME = "azureaistudio"; + public static final String NAME = "azureaistudio"; private static final String SERVICE_NAME = "Azure AI Studio"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK); private static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -83,8 +86,16 @@ public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override @@ -270,8 +281,9 @@ private static AzureAiStudioModel createModel( ConfigurationParseContext context ) { - if (taskType == TaskType.TEXT_EMBEDDING) { - var 
embeddingsModel = new AzureAiStudioEmbeddingsModel( + AzureAiStudioModel model; + switch (taskType) { + case TEXT_EMBEDDING -> model = new AzureAiStudioEmbeddingsModel( inferenceEntityId, taskType, NAME, @@ -281,16 +293,7 @@ private static AzureAiStudioModel createModel( secretSettings, context ); - checkProviderAndEndpointTypeForTask( - TaskType.TEXT_EMBEDDING, - embeddingsModel.getServiceSettings().provider(), - embeddingsModel.getServiceSettings().endpointType() - ); - return embeddingsModel; - } - - if (taskType == TaskType.COMPLETION) { - var completionModel = new AzureAiStudioChatCompletionModel( + case COMPLETION -> model = new AzureAiStudioChatCompletionModel( inferenceEntityId, taskType, NAME, @@ -299,15 +302,12 @@ private static AzureAiStudioModel createModel( secretSettings, context ); - checkProviderAndEndpointTypeForTask( - TaskType.COMPLETION, - completionModel.getServiceSettings().provider(), - completionModel.getServiceSettings().endpointType() - ); - return completionModel; + case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); } - - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings(); + checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType()); + return model; } private AzureAiStudioModel createModelFromPersistent( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java index 8bc93eebcbaea..269e1d7571152 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java @@ -13,8 +13,10 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager; import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; import java.util.Map; import java.util.Objects; @@ -49,4 +51,12 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings"); return new SenderExecutableAction(sender, requestManager, errorMessage); } + + @Override + public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map taskSettings) { + var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings); + var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool()); + var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank"); + return new SenderExecutableAction(sender, requestManager, errorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java index 1c73d48b72307..64b8bd16b6d88 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; import java.util.Map; @@ -17,4 +18,6 @@ public interface AzureAiStudioActionVisitor { ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map taskSettings); ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map taskSettings); + + ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequest.java new file mode 100644 index 0000000000000..d42637b4e9d02 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureaistudio.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class AzureAiStudioRerankRequest extends AzureAiStudioRequest { + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final AzureAiStudioRerankModel rerankModel; + + public AzureAiStudioRerankRequest( + AzureAiStudioRerankModel model, + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN + ) { + super(model); + this.rerankModel = Objects.requireNonNull(model); + this.query = query; + this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(this.uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + setAuthHeader(httpPost, rerankModel); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in 
text embedding requests + return null; + } + + private AzureAiStudioRerankRequestEntity createRequestEntity() { + return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, rerankModel.getTaskSettings()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntity.java new file mode 100644 index 0000000000000..7ba0099460d79 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntity.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
package org.elasticsearch.xpack.inference.services.azureaistudio.request;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.QUERY_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;

/**
 * The JSON body of an Azure AI Studio rerank request. Request-level values for
 * {@code return_documents} and {@code top_n} take precedence over the persisted task settings;
 * a field is omitted entirely when neither source provides a value.
 *
 * @param query           the search query the documents are ranked against
 * @param input           the documents to rerank
 * @param returnDocuments request-level override, may be null
 * @param topN            request-level override, may be null
 * @param taskSettings    persisted task settings supplying defaults for the optional fields
 */
public record AzureAiStudioRerankRequestEntity(
    String query,
    List<String> input,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    AzureAiStudioRerankTaskSettings taskSettings
) implements ToXContentObject {

    public AzureAiStudioRerankRequestEntity {
        Objects.requireNonNull(query);
        Objects.requireNonNull(input);
        Objects.requireNonNull(taskSettings);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.field(DOCUMENTS_FIELD, input);
        builder.field(QUERY_FIELD, query);

        // Request value wins over the stored task setting; omit when both are absent.
        if (returnDocuments != null) {
            builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
        } else if (taskSettings.returnDocuments() != null) {
            builder.field(RETURN_DOCUMENTS_FIELD, taskSettings.returnDocuments());
        }

        if (topN != null) {
            builder.field(TOP_N_FIELD, topN);
        } else if (taskSettings.topN() != null) {
            builder.field(TOP_N_FIELD, taskSettings.topN());
        }

        builder.endObject();
        return builder;
    }
}
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureaistudio.action.AzureAiStudioActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RERANK_URI_PATH;

/**
 * An Azure AI Studio model configured for the {@link TaskType#RERANK} task type.
 */
public class AzureAiStudioRerankModel extends AzureAiStudioModel {

    /**
     * Returns {@code model} unchanged when no request-level task settings were supplied,
     * otherwise a copy whose task settings are overridden by the request values.
     *
     * @param model        the persisted model
     * @param taskSettings raw task settings from the inference request, may be null or empty
     * @return the model to use for this request
     */
    public static AzureAiStudioRerankModel of(AzureAiStudioRerankModel model, Map<String, Object> taskSettings) {
        if (taskSettings == null || taskSettings.isEmpty()) {
            return model;
        }

        final var requestTaskSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(taskSettings);
        final var taskSettingToUse = AzureAiStudioRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings);

        return new AzureAiStudioRerankModel(model, taskSettingToUse);
    }

    public AzureAiStudioRerankModel(
        String inferenceEntityId,
        AzureAiStudioRerankServiceSettings serviceSettings,
        AzureAiStudioRerankTaskSettings taskSettings,
        DefaultSecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, TaskType.RERANK, AzureAiStudioService.NAME, serviceSettings, taskSettings),
            new ModelSecrets(secrets)
        );
    }

    /** Parses the raw maps coming from storage or a PUT request into typed settings. */
    public AzureAiStudioRerankModel(
        String inferenceEntityId,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        @Nullable Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            AzureAiStudioRerankServiceSettings.fromMap(serviceSettings, context),
            AzureAiStudioRerankTaskSettings.fromMap(taskSettings),
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    /** Copy constructor that swaps in different task settings. */
    public AzureAiStudioRerankModel(AzureAiStudioRerankModel model, AzureAiStudioRerankTaskSettings taskSettings) {
        super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
    }

    @Override
    public AzureAiStudioRerankServiceSettings getServiceSettings() {
        return (AzureAiStudioRerankServiceSettings) super.getServiceSettings();
    }

    @Override
    public AzureAiStudioRerankTaskSettings getTaskSettings() {
        return (AzureAiStudioRerankTaskSettings) super.getTaskSettings();
    }

    @Override
    public DefaultSecretSettings getSecretSettings() {
        // Delegates to the base class; kept to match the sibling model classes' pattern.
        return super.getSecretSettings();
    }

    @Override
    protected URI getEndpointUri() throws URISyntaxException {
        return new URI(this.target + RERANK_URI_PATH);
    }

    @Override
    public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings) {
        return creator.create(this, taskSettings);
    }
}
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;

import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;

/**
 * Request-level overrides for the Azure AI Studio rerank task settings. Both fields are
 * optional; a null field means "use the persisted task setting".
 *
 * @param returnDocuments whether to echo document text in the response, may be null
 * @param topN            maximum number of ranked results to return, may be null
 */
public record AzureAiStudioRerankRequestTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {

    public static final AzureAiStudioRerankRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioRerankRequestTaskSettings(null, null);

    /**
     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
     * does not throw an error.
     *
     * @param map the settings received from a request
     * @return a {@link AzureAiStudioRerankRequestTaskSettings}
     * @throws ValidationException if a present value fails validation (e.g. non-positive top_n)
     */
    public static AzureAiStudioRerankRequestTaskSettings fromMap(Map<String, Object> map) {
        if (map.isEmpty()) {
            return AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS;
        }

        final var validationException = new ValidationException();

        final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
        final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new AzureAiStudioRerankRequestTaskSettings(returnDocuments, topN);
    }
}
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

/**
 * Service settings for the Azure AI Studio rerank task. Rerank has no settings beyond the
 * common Azure AI Studio fields (target, provider, endpoint type, rate limit), so this class
 * only supplies the writeable name and transport version.
 */
public class AzureAiStudioRerankServiceSettings extends AzureAiStudioServiceSettings {
    public static final String NAME = "azure_ai_studio_rerank_service_settings";

    /**
     * Parses the raw service-settings map, throwing a {@link ValidationException} if any of the
     * common fields fail validation.
     */
    public static AzureAiStudioRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
        final var validationException = new ValidationException();

        // Rerank adds no fields of its own, so parsing is entirely delegated to the base class.
        final var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new AzureAiStudioRerankServiceSettings(
            baseSettings.target(),
            baseSettings.provider(),
            baseSettings.endpointType(),
            baseSettings.rateLimitSettings()
        );
    }

    public AzureAiStudioRerankServiceSettings(
        String target,
        AzureAiStudioProvider provider,
        AzureAiStudioEndpointType endpointType,
        @Nullable RateLimitSettings rateLimitSettings
    ) {
        super(target, provider, endpointType, rateLimitSettings);
    }

    public AzureAiStudioRerankServiceSettings(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // All serialized state lives in the base class.
        super.writeTo(out);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        super.addXContentFields(builder, params);

        builder.endObject();
        return builder;
    }

    @Override
    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
        super.addExposedXContentFields(builder, params);
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        AzureAiStudioRerankServiceSettings that = (AzureAiStudioRerankServiceSettings) o;

        return Objects.equals(target, that.target)
            && Objects.equals(provider, that.provider)
            && Objects.equals(endpointType, that.endpointType)
            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
    }

    @Override
    public int hashCode() {
        return Objects.hash(target, provider, endpointType, rateLimitSettings);
    }
}
service. + */ +public class AzureAiStudioRerankTaskSettings implements TaskSettings { + public static final String NAME = "azure_ai_studio_rerank_task_settings"; + + public static AzureAiStudioRerankTaskSettings fromMap(Map map) { + final var validationException = new ValidationException(); + + final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException); + final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureAiStudioRerankTaskSettings(returnDocuments, topN); + } + + /** + * Creates a new {@link AzureAiStudioRerankTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if the fields are not null. + * @param originalSettings the original {@link AzureAiStudioRerankTaskSettings} from the inference entity configuration from storage + * @param requestSettings the {@link AzureAiStudioRerankTaskSettings} from the request + * @return a new {@link AzureAiStudioRerankTaskSettings} + */ + public static AzureAiStudioRerankTaskSettings of( + AzureAiStudioRerankTaskSettings originalSettings, + AzureAiStudioRerankRequestTaskSettings requestSettings + ) { + + final var returnDocuments = requestSettings.returnDocuments() == null + ? originalSettings.returnDocuments() + : requestSettings.returnDocuments(); + final var topN = requestSettings.topN() == null ? 
originalSettings.topN() : requestSettings.topN(); + + return new AzureAiStudioRerankTaskSettings(returnDocuments, topN); + } + + public AzureAiStudioRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) { + this.returnDocuments = returnDocuments; + this.topN = topN; + } + + public AzureAiStudioRerankTaskSettings(StreamInput in) throws IOException { + this.returnDocuments = in.readOptionalBoolean(); + this.topN = in.readOptionalVInt(); + } + + private final Boolean returnDocuments; + private final Integer topN; + + public Boolean returnDocuments() { + return returnDocuments; + } + + public Integer topN() { + return topN; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED; + } + + @Override + public boolean isEmpty() { + return returnDocuments == null && topN == null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalVInt(topN); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments); + } + if (topN != null) { + builder.field(TOP_N_FIELD, topN); + } + + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return "AzureAiStudioRerankTaskSettings{" + ", returnDocuments=" + returnDocuments + ", topN=" + topN + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureAiStudioRerankTaskSettings that = (AzureAiStudioRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN); + } + + @Override + public int hashCode() { + return 
Objects.hash(returnDocuments, topN); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + AzureAiStudioRerankRequestTaskSettings requestSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, requestSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntity.java new file mode 100644 index 0000000000000..c3bd369c8fb42 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntity.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
package org.elasticsearch.xpack.inference.services.azureaistudio.response;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class AzureAiStudioRerankResponseEntity extends BaseResponseEntity {
    /**
     * Parses the AzureAiStudio rerank json response.
     * For a request like:
     *
     * <pre>
     * <code>
     * {
     *     "model": "rerank-v3.5",
     *     "query": "What is the capital of the United States?",
     *     "top_n": 2,
     *     "documents": ["Carson City is the capital city of the American state of Nevada.",
     *                   "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."]
     * }
     * </code>
     * </pre>
     *
     * The response would look like:
     *
     * <pre>
     * <code>
     * {
     *     "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
     *     "results": [
     *         {
     *             "document": {
     *                 "text": "Carson City is the capital city of the American state of Nevada."
     *             },
     *             "index": 0,
     *             "relevance_score": 0.1728413
     *         },
     *         {
     *             "document": {
     *                 "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."
     *             },
     *             "index": 1,
     *             "relevance_score": 0.031005697
     *         }
     *     ],
     *     "meta": {
     *         "api_version": {
     *             "version": "1"
     *         },
     *         "billed_units": {
     *             "search_units": 1
     *         }
     *     }
     * }
     * </code>
     * </pre>
     */
    @Override
    protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
        final var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
            final var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
            return new RankedDocsResults(rerankResult.entries().stream().map(RerankResultEntry::toRankedDoc).toList());
        }
    }

    /** Top-level response: only the "results" array is consumed; "id"/"meta" are ignored (lenient parser). */
    record RerankResult(List<RerankResultEntry> entries) {
        @SuppressWarnings("unchecked")
        public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
            RerankResult.class.getSimpleName(),
            true,
            args -> new RerankResult((List<RerankResultEntry>) args[0])
        );
        static {
            PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
        }
    }

    /** One ranked entry; "document" is only present when return_documents was requested. */
    record RerankResultEntry(Float relevanceScore, Integer index, @Nullable ObjectParser document) {

        // NOTE(review): unlike RerankResult.PARSER this one is not lenient — an unknown field
        // inside a result entry would fail parsing. Confirm this is intentional.
        public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
            RerankResultEntry.class.getSimpleName(),
            args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (ObjectParser) args[2])
        );
        static {
            PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
            PARSER.declareInt(constructorArg(), new ParseField("index"));
            PARSER.declareObject(optionalConstructorArg(), ObjectParser.PARSER::apply, new ParseField("document"));
        }

        public RankedDocsResults.RankedDoc toRankedDoc() {
            return new RankedDocsResults.RankedDoc(index, relevanceScore, document == null ? null : document.text());
        }
    }

    /** The nested "document" object, carrying just the echoed text. */
    record ObjectParser(String text) {
        public static final ConstructingObjectParser<ObjectParser, Void> PARSER = new ConstructingObjectParser<>(
            ObjectParser.class.getSimpleName(),
            args -> new AzureAiStudioRerankResponseEntity.ObjectParser((String) args[0])
        );
        static {
            PARSER.declareString(optionalConstructorArg(), new ParseField("text"));
        }
    }
}
serviceComponents) { - super(factory, serviceComponents); + public AzureOpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..fb6c630bd60c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public 
CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CohereService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java index f512444c6d6a4..2d52a8a9dadbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java @@ -28,13 +28,16 @@ public class CohereUtils { public static final String DOCUMENTS_FIELD = "documents"; public static final String EMBEDDING_TYPES_FIELD = "embedding_types"; public static final String INPUT_TYPE_FIELD = "input_type"; - public static final String MESSAGE_FIELD = "message"; + public static final String V1_MESSAGE_FIELD = "message"; + public static final String V2_MESSAGES_FIELD = "messages"; public static final String MODEL_FIELD = "model"; public static final String QUERY_FIELD = "query"; + public static final String V2_ROLE_FIELD = "role"; public static final String SEARCH_DOCUMENT = "search_document"; public static final String SEARCH_QUERY = "search_query"; - public static final String TEXTS_FIELD = "texts"; public static final String STREAM_FIELD = "stream"; + public static final String TEXTS_FIELD = "texts"; + public static final String USER_FIELD = "user"; public static Header 
createRequestSourceHeader() { return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java index 4fa4552dcd94d..0be1ba8d25f29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java @@ -30,7 +30,7 @@ public CohereV1CompletionRequest(List input, CohereCompletionModel model public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); // we only allow one input for completion, so always get the first one - builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + builder.field(CohereUtils.V1_MESSAGE_FIELD, input.getFirst()); if (getModelId() != null) { builder.field(CohereUtils.MODEL_FIELD, getModelId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java index 028c4a0d486c0..1a8eae321ac77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java @@ -29,8 +29,13 @@ public CohereV2CompletionRequest(List input, CohereCompletionModel model @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + 
builder.startArray(CohereUtils.V2_MESSAGES_FIELD); + builder.startObject(); + builder.field(CohereUtils.V2_ROLE_FIELD, CohereUtils.USER_FIELD); // we only allow one input for completion, so always get the first one - builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + builder.field("content", input.getFirst()); + builder.endObject(); + builder.endArray(); builder.field(CohereUtils.MODEL_FIELD, getModelId()); builder.field(CohereUtils.STREAM_FIELD, isStreaming()); builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4e81d37ead3ad..5f5078affa9d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,8 +76,16 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public 
CustomService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..8a77efbd604d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -10,12 +10,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, 
serviceComponents); + public DeepSeekService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 36712ed922e95..58e964bb5c25f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -22,6 +23,7 @@ import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -139,9 +141,28 @@ public ElasticInferenceService( ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - 
ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents); + this( + factory, + serviceComponents, + elasticInferenceServiceSettings, + modelRegistry, + authorizationRequestHandler, + context.clusterService() + ); + } + + public ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index ad40d43b3af3b..f45d390b404d4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -49,7 +49,7 @@ protected Exception buildError(String message, Request request, HttpResult resul var restStatus = toRestStatus(responseStatusCode); return new UnifiedChatCompletionException( restStatus, - errorMessage(message, request, result, errorResponse, responseStatusCode), + constructErrorMessage(message, request, errorResponse, 
responseStatusCode), "error", restStatus.name().toLowerCase(Locale.ROOT) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java index b45d4449251f4..007dc820c629f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -28,7 +28,7 @@ public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInpu @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params)); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 53e859b7f7a4d..4aaf3c2db2e61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import 
org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; @@ -23,6 +22,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -38,13 +38,16 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.function.Consumer; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -55,6 +58,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi protected final ExecutorService inferenceExecutor; protected final Consumer> preferredModelVariantFn; private final ClusterService clusterService; + private final InferenceStats inferenceStats; public enum PreferredModelVariant { LINUX_X86_OPTIMIZED, @@ -69,10 +73,11 @@ public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServi this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); 
this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } // For testing. - // platformArchFn enables similating different architectures + // platformArchFn enables simulating different architectures // without extensive mocking on the client to simulate the nodes info response. // TODO make package private once the elser service is moved to the Elasticsearch // service package. @@ -85,6 +90,7 @@ public BaseElasticsearchInternalService( this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); this.preferredModelVariantFn = preferredModelVariantFn; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } @Override @@ -103,6 +109,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL return; } + var timer = InferenceTimer.start(); // instead of a subscribably listener, use some wait to wait for the first one. 
var subscribableListener = SubscribableListener.newForked( forkedListener -> { isBuiltinModelPut(model, forkedListener); } @@ -114,25 +121,29 @@ public void start(Model model, TimeValue timeout, ActionListener finalL } }).andThen((l2, modelDidPut) -> { var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout); - var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2); + var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2); client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); }); subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor); - subscribableListener.addListener(finalListener.delegateResponse((l, e) -> { + subscribableListener.addListener(ActionListener.wrap(started -> { + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, null)); + finalListener.onResponse(started); + }, e -> { if (e instanceof ElasticsearchTimeoutException) { - l.onFailure( - new ModelDeploymentTimeoutException( - format( - "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " - + "The inference endpoint can not be used to perform inference until the deployment has started. " - + "Use the trained model stats API to track the state of the deployment.", - timeout, - model.getInferenceEntityId() - ) + var timeoutException = new ModelDeploymentTimeoutException( + format( + "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " + + "The inference endpoint can not be used to perform inference until the deployment has started. 
" + + "Use the trained model stats API to track the state of the deployment.", + timeout, + model.getInferenceEntityId() ) ); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, timeoutException)); + finalListener.onFailure(timeoutException); } else { - l.onFailure(e); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, unwrapCause(e))); + finalListener.onFailure(e); } })); @@ -323,7 +334,7 @@ protected void maybeStartDeployment( InferModelAction.Request request, ActionListener listener ) { - if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + if (isDefaultId(model.getInferenceEntityId()) && unwrapCause(e) instanceof ResourceNotFoundException) { this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java index ce6c6258d0393..5a81eb6b04bcd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -43,7 +42,7 @@ protected String 
modelNotFoundErrorMessage(String modelId) { @Override public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { throw new IllegalStateException("cannot start model that uses an existing deployment"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java index 276bce6dbe8f8..2c8bf8270fabc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java @@ -9,7 +9,6 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -33,7 +32,7 @@ public ElasticRerankerServiceSettings getServiceSettings() { @Override public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index f1011efd3b12c..6a553480e68cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -21,6 +21,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; @@ -85,12 +87,13 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA } public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { return new ActionListener<>() { @Override public void onResponse(CreateTrainedModelAssignmentAction.Response response) { + esModel.updateServiceSettings(response.getTrainedModelAssignment()); listener.onResponse(Boolean.TRUE); } @@ -98,7 +101,7 @@ public void onResponse(CreateTrainedModelAssignmentAction.Response response) { public void onFailure(Exception e) { var cause = ExceptionsHelper.unwrapCause(e); if (cause instanceof ResourceNotFoundException) { - listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId()))); + listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId()))); return; } else if (cause instanceof ElasticsearchStatusException statusException) { if (statusException.status() == RestStatus.CONFLICT @@ -128,8 +131,18 @@ public ElasticsearchInternalServiceSettings getServiceSettings() { return (ElasticsearchInternalServiceSettings) super.getServiceSettings(); } - public void updateNumAllocations(Integer numAllocations) { - 
this.internalServiceSettings.setNumAllocations(numAllocations); + public void updateServiceSettings(AssignmentStats assignmentStats) { + this.internalServiceSettings.setAllocations( + assignmentStats.getNumberOfAllocations(), + assignmentStats.getAdaptiveAllocationsSettings() + ); + } + + private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) { + this.internalServiceSettings.setAllocations( + this.internalServiceSettings.getNumAllocations(), + trainedModelAssignment.getAdaptiveAllocationsSettings() + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 4f2674179be67..b17392311629f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -890,7 +890,7 @@ public void updateModelsWithDynamicFields(List models, ActionListener { for (var deploymentStats : stats.getStats().results()) { var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId()); - modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations())); + modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats)); } var updatedModels = new ArrayList(); modelsByDeploymentIds.values().forEach(updatedModels::addAll); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 98730f33d10f9..67a537b5ac2c0 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContentObject; @@ -21,6 +22,7 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -43,7 +45,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings { private Integer numAllocations; private final int numThreads; private final String modelId; - private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; private final String deploymentId; public static ElasticsearchInternalServiceSettings fromPersistedMap(Map map) { @@ -158,8 +160,9 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? 
in.readOptionalString() : null; } - public void setNumAllocations(Integer numAllocations) { + public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) { this.numAllocations = numAllocations; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } @Override @@ -239,6 +242,48 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_13_0; } + public ElasticsearchInternalServiceSettings updateServiceSettings(Map serviceSettings) { + var validationException = new ValidationException(); + var mutableServiceSettings = new HashMap<>(serviceSettings); + + var numAllocations = extractOptionalPositiveInteger( + mutableServiceSettings, + NUM_ALLOCATIONS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + mutableServiceSettings, + ADAPTIVE_ALLOCATIONS, + validationException + ); + + if (numAllocations == null && adaptiveAllocationsSettings == null) { + validationException.addValidationError( + ServiceUtils.missingOneOfSettingsErrorMsg( + List.of(NUM_ALLOCATIONS, ADAPTIVE_ALLOCATIONS), + ModelConfigurations.SERVICE_SETTINGS + ) + ); + } + if (numAllocations != null && adaptiveAllocationsSettings != null) { + validationException.addValidationError( + Strings.format("[%s] cannot be set if [%s] is set", NUM_ALLOCATIONS, ADAPTIVE_ALLOCATIONS) + ); + } + validationException.throwIfValidationErrorsExist(); + + return toBuilder().setNumAllocations(numAllocations).setAdaptiveAllocationsSettings(adaptiveAllocationsSettings).build(); + } + + public Builder toBuilder() { + return new Builder().setAdaptiveAllocationsSettings(adaptiveAllocationsSettings) + .setDeploymentId(deploymentId) + .setModelId(modelId) + .setNumThreads(numThreads) + .setNumAllocations(numAllocations); + } + public static class Builder { private Integer numAllocations; private int numThreads; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..4c8997f35555b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -82,8 +84,16 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); 
} @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..2c2c667cd6eee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,8 +99,16 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleVertexAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + 
super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 9e6fdb6eb8bb5..56969c3390268 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -10,9 +10,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; @@ -22,29 +19,31 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ChatCompletionErrorResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParser; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponse; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import 
org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; import java.nio.charset.StandardCharsets; -import java.util.Locale; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; -import static org.elasticsearch.core.Strings.format; - public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { private static final String ERROR_FIELD = "error"; private static final String ERROR_CODE_FIELD = "code"; private static final String ERROR_MESSAGE_FIELD = "message"; private static final String ERROR_STATUS_FIELD = "status"; + private static final GoogleVertexAiErrorParser ERROR_PARSER = new GoogleVertexAiErrorParser(); + + private final ChatCompletionErrorResponseHandler chatCompletionErrorResponseHandler; public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) { super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true); + this.chatCompletionErrorResponseHandler = new ChatCompletionErrorResponseHandler(ERROR_PARSER); } @Override @@ -52,7 +51,9 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher buildMidStreamError(request, m, e)); + var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor( + (m, e) -> chatCompletionErrorResponseHandler.buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(googleVertexAiProcessor); @@ -60,63 +61,35 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher, Void> ERROR_PARSER = new ConstructingObjectParser<>( - "google_vertex_ai_error_wrapper", - true, - args -> Optional.ofNullable((GoogleVertexAiErrorResponse) args[0]) - ); + public static class GoogleVertexAiErrorResponse 
extends UnifiedChatCompletionErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = + new ConstructingObjectParser<>( + "google_vertex_ai_error_wrapper", + true, + args -> Optional.ofNullable((GoogleVertexAiErrorResponse) args[0]) + ); private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_body", @@ -137,46 +110,39 @@ public static class GoogleVertexAiErrorResponse extends ErrorResponse { ); } - public static ErrorResponse fromResponse(HttpResult response) { + public static UnifiedChatCompletionErrorResponse fromResponse(HttpResult response) { try ( XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(XContentParserConfiguration.EMPTY, response.body()) ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + return ERROR_PARSER.apply(parser, null).orElse(UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR); } catch (Exception e) { var resultAsString = new String(response.body(), StandardCharsets.UTF_8); - return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString)); + return new GoogleVertexAiErrorResponse( + Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString) + ); } } - static ErrorResponse fromString(String response) { + static UnifiedChatCompletionErrorResponse fromString(String response) { try ( XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(XContentParserConfiguration.EMPTY, response) ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + return ERROR_PARSER.apply(parser, null).orElse(UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR); } catch (Exception e) { - return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", response)); + return new GoogleVertexAiErrorResponse( + 
Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", response) + ); } } - private final int code; - @Nullable - private final String status; - - GoogleVertexAiErrorResponse(Integer code, String errorMessage, @Nullable String status) { - super(Objects.requireNonNull(errorMessage)); - this.code = code == null ? 0 : code; - this.status = status; - } - - public int code() { - return code; + GoogleVertexAiErrorResponse(@Nullable Integer code, String errorMessage, @Nullable String status) { + super(errorMessage, status != null ? status : "google_vertex_ai_error", code == null ? "0" : String.valueOf(code), null); } - @Nullable - public String status() { - return status != null ? status : "google_vertex_ai_error"; + GoogleVertexAiErrorResponse(String errorMessage) { + super(errorMessage, "google_vertex_ai_error", null, null); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index b0d40b41914d5..325f88c8904a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,8 +46,16 @@ public abstract 
class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceBaseService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 8dffd612db5c8..8b8deaef3279d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -45,7 +45,7 @@ protected Exception buildError(String message, Request request, HttpResult resul assert request.isStreaming() : "Only streaming requests support this format"; var responseStatusCode = result.response().getStatusLine().getStatusCode(); if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); return errorResponse instanceof HuggingFaceErrorResponseEntity ? 
new UnifiedChatCompletionException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..bc64e832d182a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -11,10 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -71,8 +73,16 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index 7429153835ee3..91735d39f3973 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -31,11 +31,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings { public static final String NAME = "hugging_face_service_settings"; @@ -70,12 +69,6 @@ public static HuggingFaceServiceSettings fromMap(Map map, Config return new HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings); } - public static URI extractUri(Map map, String fieldName, ValidationException validationException) { - String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - - return 
convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - } - private final URI uri; private final SimilarityMeasure similarity; private final Integer dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index cdc2529428bed..64da6e32bc1f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -31,7 +31,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Settings for the Hugging Face chat completion service. 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..5f9288bb99c24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -57,8 +59,16 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceElserService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public 
HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java index b1d3297fc6328..ad771e72b6b35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceElserServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java index b0b21b26395af..57c103bbbf3b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -27,7 +27,7 @@ import java.util.Objects; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceRerankServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 41b82bbf2cd02..570dbd1e709ee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -34,7 +34,7 @@ protected Exception buildError(String message, Request request, HttpResult resul assert request.isStreaming() : "Only streaming requests support this format"; var responseStatusCode = result.response().getStatusLine().getStatusCode(); if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); return errorResponse instanceof IbmWatsonxErrorResponseEntity ? 
new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 9bc63be1f9e7e..9617bff0d3f3d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +85,16 @@ public class IbmWatsonxService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public IbmWatsonxService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents 
serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..00e1aede95a2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -76,8 +78,16 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public JinaAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, 
serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java new file mode 100644 index 0000000000000..3e24d058d8540 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract class representing a Llama model for inference. + * This class extends RateLimitGroupingModel and provides common functionality for Llama models. + */ +public abstract class LlamaModel extends RateLimitGroupingModel { + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + /** + * Constructor for creating a LlamaModel with specified configurations and secrets. 
+ * + * @param configurations the model configurations + * @param secrets the secret settings for the model + */ + protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + /** + * Constructor for creating a LlamaModel with specified model and service settings. + * @param model the model configurations + * @param serviceSettings the settings for the inference service + */ + protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public URI uri() { + return this.uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().modelId(), uri, getSecretSettings()); + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + /** + * Retrieves the secret settings from the provided map of secrets. + * If the map is null or empty, it returns an instance of EmptySecretSettings. + * Caused by the fact that Llama model doesn't have out of the box security settings and can be used without authentication. + * + * @param secrets the map containing secret settings + * @return an instance of SecretSettings + */ + protected static SecretSettings retrieveSecretSettings(Map secrets) { + return (secrets != null && secrets.isEmpty()) ?
EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets); + } + + protected abstract ExecutableAction accept(LlamaActionVisitor creator); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java new file mode 100644 index 0000000000000..bd6b3c91fc9e9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -0,0 +1,423 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import 
org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import 
java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * LlamaService is an inference service for Llama models, supporting text embedding and chat completion tasks. + * It extends SenderService to handle HTTP requests and responses for Llama models. + */ +public class LlamaService extends SenderService { + public static final String NAME = "llama"; + private static final String SERVICE_NAME = "Llama"; + /** + * The optimal batch size depends on the hardware the model is deployed on. 
+ * For Llama use a conservatively small max batch size as it is + * unknown how the model is deployed + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler( + "llama chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + /** + * Constructor for creating a LlamaService with specified HTTP request sender factory and service components. + * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + * @param context the context for the inference service factory + */ + public LlamaService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + if (model instanceof LlamaModel llamaModel) { + llamaModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + /** + * Creates a LlamaModel based on the provided parameters. 
+ * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param failureMessage the message to use in case of failure + * @param context the context for parsing configuration settings + * @return a new instance of LlamaModel based on the provided parameters + */ + protected LlamaModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING: + return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION: + return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof LlamaEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? 
SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new LlamaEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings() + ); + + return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof LlamaEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaModel = (LlamaEmbeddingsModel) model; + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + llamaModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = llamaModel.accept(actionCreator); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof LlamaChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaChatCompletionModel = (LlamaChatCompletionModel) model; + var overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new 
LlamaChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = LlamaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(COMPLETION, CHAT_COMPLETION); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + LlamaModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + 
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + private LlamaModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public boolean hideFromConfigurationApi() { + // The Llama service is very configurable so we're going to hide it from being exposed in the service API. + return true; + } + + /** + * Configuration class for the Llama inference service. 
+ * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private Configuration() {} + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the Llama models documentation for the list of available models." + ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java new file mode 100644 index 0000000000000..52e284ba7ccca --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; + +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import 
static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +/** + * Creates actions for Llama inference requests, handling both embeddings and completions. + * This class implements the {@link LlamaActionVisitor} interface to provide specific action creation methods. + */ +public class LlamaActionCreator implements LlamaActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Llama %s request from inference entity id [%s]"; + private static final String COMPLETION_ERROR_PREFIX = "Llama completions"; + private static final String USER_ROLE = "user"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler( + "llama text embedding", + HuggingFaceEmbeddingsResponseEntity::fromResponse + ); + private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler( + "llama completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new LlamaActionCreator with the specified sender and service components. 
+ * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(LlamaEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new LlamaEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(LlamaChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + /** + * Builds an error message for failed requests. 
+ * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java new file mode 100644 index 0000000000000..1521b83b668c7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; + +/** + * Visitor interface for creating executable actions for Llama inference models. + * This interface defines methods to create actions for both embeddings and chat completion models. + */ +public interface LlamaActionVisitor { + /** + * Creates an executable action for the given Llama embeddings model. + * + * @param model the Llama embeddings model + * @return an executable action for the embeddings model + */ + ExecutableAction create(LlamaEmbeddingsModel model); + + /** + * Creates an executable action for the given Llama chat completion model. 
+ * + * @param model the Llama chat completion model + * @return an executable action for the chat completion model + */ + ExecutableAction create(LlamaChatCompletionModel model); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java new file mode 100644 index 0000000000000..a1a38f1eae326 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama chat completion model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for chat completion tasks. 
+ */ +public class LlamaChatCompletionModel extends LlamaModel { + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaChatCompletionServiceSettings.fromMap(serviceSettings, context), + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaChatCompletionServiceSettings serviceSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Factory method to create a LlamaChatCompletionModel with overridden model settings based on the request. + * If the request does not specify a model, the original model is returned. 
+ * + * @param model the original LlamaChatCompletionModel + * @param request the UnifiedCompletionRequest containing potential overrides + * @return a new LlamaChatCompletionModel with overridden settings or the original model if no overrides are specified + */ + public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model id is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new LlamaChatCompletionServiceSettings( + request.model(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new LlamaChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + private void setPropertiesFromServiceSettings(LlamaChatCompletionServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Returns the service settings specific to Llama chat completion. + * + * @return the LlamaChatCompletionServiceSettings associated with this model + */ + @Override + public LlamaChatCompletionServiceSettings getServiceSettings() { + return (LlamaChatCompletionServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor that creates an executable action for this Llama chat completion model. 
+ * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing this model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..85d60308d77d3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; +import java.util.Optional; + +import static org.elasticsearch.core.Strings.format; + +/** + * Handles streaming chat completion responses and error parsing for Llama inference endpoints. + * This handler is designed to work with the unified Llama chat completion API. + */ +public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String LLAMA_ERROR = "llama_error"; + private static final String STREAM_ERROR = "stream_error"; + + /** + * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser. 
+ * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } + + /** + * Builds an exception from an error HTTP response for a Llama chat completion request. + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the response + * @param errorResponse the error response parsed from the HTTP result + * @return an exception representing the error, specific to Llama chat completion + */ + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof LlamaErrorResponse + ? new UnifiedChatCompletionException(restStatus, errorMessage, LLAMA_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } + + /** + * Builds an exception for mid-stream errors encountered during Llama chat completion requests. 
+ * + * @param request the request that caused the error + * @param message the error message + * @param e the exception that occurred, if any + * @return a UnifiedChatCompletionException representing the error + */ + @Override + protected Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message); + if (errorResponse instanceof StreamingLlamaErrorResponseEntity) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + LLAMA_ERROR, + STREAM_ERROR + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + } + + /** + * StreamingLlamaErrorResponseEntity allows creation of {@link ErrorResponse} from a JSON string. + * This entity is used to parse error responses from streaming Llama requests. + * For non-streaming requests {@link LlamaErrorResponse} should be used. + * Example error response for Bad Request error would look like: + *
    
    +     * <pre><code>
    +     *  {
    +     *      "error": {
    +     *          "message": "400: Invalid value: Model 'llama3.12:3b' not found"
    +     *      }
    +     *  }
    +     * </code></pre>
    + */ + private static class StreamingLlamaErrorResponseEntity extends ErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0]) + ); + private static final ConstructingObjectParser< + LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity, + Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity( + args[0] != null ? (String) args[0] : "unknown" + ) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + /** + * Parses a streaming Llama error response from a JSON string. + * + * @param response the raw JSON string representing an error + * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails + */ + private static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + /** + * Constructs a StreamingLlamaErrorResponseEntity with the specified error message. 
+ * + * @param errorMessage the error message to include in the response entity + */ + StreamingLlamaErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..7917a8cba5b48 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import 
static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for a Llama chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the Llama chat completion service. + */ +public class LlamaChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_completion_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaChatCompletionServiceSettings from a map of settings. 
+ * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static LlamaChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + LlamaService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new LlamaChatCompletionServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID, URI, and rate limit settings. 
+ * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public LlamaChatCompletionServiceSettings(String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID and URL. + * The rate limit settings will be set to the default value. + * + * @param modelId the ID of the model + * @param url the URL of the service + */ + public LlamaChatCompletionServiceSettings(String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + /** + * Returns the URI of the Llama chat completion service. + * + * @return the URI of the service + */ + public URI uri() { + return this.uri; + } + + /** + * Returns the rate limit settings for the Llama chat completion service. 
+ * + * @return the rate limit settings + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaChatCompletionServiceSettings that = (LlamaChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java new file mode 100644 index 0000000000000..8e3b5b10df900 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for Llama models, extending the OpenAI completion response handler. + * This class is specifically designed to handle Llama's error response format. + */ +public class LlamaCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a LlamaCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "llama completions"). + * @param parseFunction The function to parse the response. + */ + public LlamaCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java new file mode 100644 index 0000000000000..ebf0b7e8132c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama embeddings model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for embeddings tasks. + */ +public class LlamaEmbeddingsModel extends LlamaModel { + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaEmbeddingsServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. 
+ * + * @param model the base LlamaEmbeddingsModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to embeddings + */ + public LlamaEmbeddingsModel(LlamaEmbeddingsModel model, LlamaEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Sets properties from the provided LlamaEmbeddingsServiceSettings. + * + * @param serviceSettings the service settings to extract properties from + */ + private void setPropertiesFromServiceSettings(LlamaEmbeddingsServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaEmbeddingsServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + @Override + public LlamaEmbeddingsServiceSettings getServiceSettings() { + return (LlamaEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Llama embeddings model. 
+ * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Llama embeddings model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..240ccf46c7482 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +/** + * Handles responses for Llama embeddings requests, parsing the response and handling errors. + * This class extends OpenAiResponseHandler to provide specific functionality for Llama embeddings. + */ +public class LlamaEmbeddingsResponseHandler extends OpenAiResponseHandler { + + /** + * Constructs a new LlamaEmbeddingsResponseHandler with the specified request type and response parser. 
+ * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..a14146070247a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Settings for the Llama embeddings 
service. + * This class encapsulates the configuration settings required to use Llama for generating embeddings. + */ +public class LlamaEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_embeddings_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final Integer dimensions; + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static LlamaEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + var maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, LlamaService.NAME, context); + + if (validationException.validationErrors().isEmpty() 
== false) { + throw validationException; + } + + return new LlamaEmbeddingsServiceSettings(model, uri, dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.dimensions = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param uri the URI of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + URI uri, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.uri = uri; + this.dimensions = dimensions; + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. 
+ * + * @param modelId the identifier for the model + * @param url the URL of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + String url, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + public URI uri() { + return this.uri; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the maximum number of input tokens allowed for this service. + * + * @return the maximum input tokens, or null if not specified + */ + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + /** + * Returns the rate limit settings for this service. 
+ * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaEmbeddingsServiceSettings that = (LlamaEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java new file mode 100644 index 0000000000000..3bb01f215087e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Llama Chat Completion Request + * This class is responsible for creating a request to the Llama chat completion model. + * It constructs an HTTP POST request with the necessary headers and body content. 
+ */ +public class LlamaChatCompletionRequest implements Request { + + private final LlamaChatCompletionModel model; + private final UnifiedChatInput chatInput; + + /** + * Constructs a new LlamaChatCompletionRequest with the specified chat input and model. + * + * @param chatInput the chat input containing the messages and parameters for the completion request + * @param model the Llama chat completion model to be used for the request + */ + public LlamaChatCompletionRequest(UnifiedChatInput chatInput, LlamaChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + /** + * Returns the chat input for this request. + * + * @return the chat input containing the messages and parameters + */ + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) { + httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey())); + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation for Llama chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Llama chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git 
/**
 * LlamaChatCompletionRequestEntity is responsible for creating the request entity for Llama chat completion.
 * It implements ToXContentObject to allow serialization to XContent format. Serialization delegates to the
 * shared {@link UnifiedChatCompletionRequestEntity}, passing the model ID through
 * {@code UnifiedCompletionRequest.withMaxTokens} so it is included in the rendered request parameters.
 */
public class LlamaChatCompletionRequestEntity implements ToXContentObject {

    private final LlamaChatCompletionModel model;
    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

    /**
     * Constructs a LlamaChatCompletionRequestEntity with the specified unified chat input and model.
     *
     * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request
     * @param model the Llama chat completion model to be used for the request; must not be null
     */
    public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, LlamaChatCompletionModel model) {
        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Serializes the chat completion request to XContent (JSON), wrapping the unified request fields
     * in a single object.
     *
     * @param builder the builder to write into
     * @param params the serialization parameters
     * @return the builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
        builder.endObject();
        return builder;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.llama.request.embeddings;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.net.URI;
import java.nio.charset.StandardCharsets;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * Llama Embeddings Request
 * This class is responsible for creating a request to the Llama embeddings model.
 * It constructs an HTTP POST request with the necessary headers and body content.
 */
public class LlamaEmbeddingsRequest implements Request {
    // Target endpoint, resolved once from the model at construction time
    private final URI uri;
    private final LlamaEmbeddingsModel model;
    private final String inferenceEntityId;
    // Input text after (possible) truncation, plus the truncator so we can truncate further on retry
    private final Truncator.TruncationResult truncationResult;
    private final Truncator truncator;

    /**
     * Constructs a new LlamaEmbeddingsRequest with the specified truncator, input, and model.
     *
     * @param truncator the truncator to handle input truncation
     * @param input the input to be truncated
     * @param model the Llama embeddings model to be used for the request
     */
    public LlamaEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, LlamaEmbeddingsModel model) {
        this.uri = model.uri();
        this.model = model;
        this.inferenceEntityId = model.getInferenceEntityId();
        this.truncator = truncator;
        this.truncationResult = input;
    }

    /**
     * Builds the HTTP POST request for the Llama embeddings endpoint: a JSON body containing the
     * model id and the (truncated) input, plus a bearer-auth header when secret settings are present.
     *
     * @return the wrapped {@link HttpRequest} tagged with this request's inference entity id
     */
    @Override
    public HttpRequest createHttpRequest() {
        HttpPost httpPost = new HttpPost(this.uri);

        ByteArrayEntity byteEntity = new ByteArrayEntity(
            Strings.toString(new LlamaEmbeddingsRequestEntity(model.getServiceSettings().modelId(), truncationResult.input()))
                .getBytes(StandardCharsets.UTF_8)
        );
        httpPost.setEntity(byteEntity);

        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        // Auth is optional: only attach the bearer header if the model was configured with default secret settings
        if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) {
            httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey()));
        }

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    /**
     * @return the URI of the Llama embeddings endpoint this request targets
     */
    @Override
    public URI getURI() {
        return uri;
    }

    /**
     * Returns a copy of this request whose input has been truncated one step further,
     * used when the service rejects the current input as too long.
     */
    @Override
    public Request truncate() {
        var truncatedInput = truncator.truncate(truncationResult.input());
        return new LlamaEmbeddingsRequest(truncator, truncatedInput, model);
    }

    /**
     * @return per-input flags indicating which inputs were truncated (defensive copy)
     */
    @Override
    public boolean[] getTruncationInfo() {
        return truncationResult.truncated().clone();
    }

    @Override
    public String getInferenceEntityId() {
        return inferenceEntityId;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.llama.request.embeddings;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * LlamaEmbeddingsRequestEntity is responsible for creating the request entity for Llama embeddings.
 * It implements ToXContentObject to allow serialization to XContent format.
 *
 * @param modelId  the ID of the model to use for embeddings
 * @param contents the list of contents to generate embeddings for
 */
public record LlamaEmbeddingsRequestEntity(String modelId, List<String> contents) implements ToXContentObject {

    public static final String CONTENTS_FIELD = "contents";
    public static final String MODEL_ID_FIELD = "model_id";

    /**
     * Compact constructor validating that both the model ID and the contents are non-null.
     */
    public LlamaEmbeddingsRequestEntity {
        Objects.requireNonNull(modelId);
        Objects.requireNonNull(contents);
    }

    /**
     * Serializes this entity as a JSON object of the form
     * {@code {"model_id": <modelId>, "contents": [<content>, ...]}}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.field(MODEL_ID_FIELD, modelId);
        builder.field(CONTENTS_FIELD, contents);

        builder.endObject();

        return builder;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.llama.response;

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

import java.nio.charset.StandardCharsets;

/**
 * LlamaErrorResponse is responsible for handling error responses from Llama inference services.
 * It extends ErrorResponse to provide specific functionality for Llama errors.
 * An example error response for Not Found error would look like:
 * <pre><code>
 *  {
 *      "detail": "Not Found"
 *  }
 * </code></pre>
 * An example error response for Bad Request error would look like:
 * <pre><code>
 *  {
 *     "error": {
 *         "detail": {
 *             "errors": [
 *                 {
 *                     "loc": [
 *                         "body",
 *                         "model"
 *                     ],
 *                     "msg": "Field required",
 *                     "type": "missing"
 *                 }
 *             ]
 *         }
 *     }
 *  }
 * </code></pre>
 */
public class LlamaErrorResponse extends ErrorResponse {

    public LlamaErrorResponse(String message) {
        super(message);
    }

    /**
     * Best-effort extraction of an error message from an HTTP response: the raw body is used
     * verbatim as the message rather than being parsed as structured JSON.
     *
     * @param response the HTTP result whose body holds the service's error payload
     * @return a {@link LlamaErrorResponse} wrapping the body text, or
     *         {@link ErrorResponse#UNDEFINED_ERROR} if the body cannot be read
     */
    public static ErrorResponse fromResponse(HttpResult response) {
        try {
            String errorMessage = new String(response.body(), StandardCharsets.UTF_8);
            return new LlamaErrorResponse(errorMessage);
        } catch (Exception e) {
            // swallow the error: error-message extraction is best effort and must never
            // itself fail the request handling path
        }
        return ErrorResponse.UNDEFINED_ERROR;
    }
}
*/ public abstract class MistralModel extends RateLimitGroupingModel { - protected String model; protected URI uri; protected RateLimitSettings rateLimitSettings; @@ -34,10 +36,6 @@ protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSett super(model, serviceSettings); } - public String model() { - return this.model; - } - public URI uri() { return this.uri; } @@ -49,7 +47,7 @@ public RateLimitSettings rateLimitSettings() { @Override public int rateLimitGroupingHash() { - return 0; + return Objects.hash(getServiceSettings().modelId(), getSecretSettings().apiKey()); } // Needed for testing only @@ -65,4 +63,6 @@ public void setURI(String newUri) { public DefaultSecretSettings getSecretSettings() { return (DefaultSecretSettings) super.getSecretSettings(); } + + public abstract ExecutableAction accept(MistralActionVisitor creator); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..c1eee5eb27338 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import 
org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public MistralService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override @@ -98,16 +108,10 @@ protected void doInfer( ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); - switch (model) { - case MistralEmbeddingsModel mistralEmbeddingsModel: - mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); - break; - case MistralChatCompletionModel mistralChatCompletionModel: - mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener); - break; - default: - listener.onFailure(createInvalidModelException(model)); - break; + if (model instanceof MistralModel mistralModel) { + mistralModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); } } @@ -162,7 +166,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); + var action = mistralEmbeddingsModel.accept(actionCreator); action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { @@ -207,7 +211,6 @@ public 
void parseRequestConfig( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), @@ -232,7 +235,7 @@ public MistralModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; @@ -244,7 +247,6 @@ public MistralModel parsePersistedConfigWithSecrets( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(modelId, NAME) @@ -254,7 +256,7 @@ public MistralModel parsePersistedConfigWithSecrets( @Override public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -265,7 +267,6 @@ public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map< modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, null, parsePersistedConfigErrorMsg(modelId, NAME) @@ -286,7 +287,6 @@ private static MistralModel createModel( String modelId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, @@ -294,16 +294,7 @@ private static MistralModel createModel( ) { switch (taskType) { case TEXT_EMBEDDING: - return new 
MistralEmbeddingsModel( - modelId, - taskType, - NAME, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - context - ); + return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); case CHAT_COMPLETION, COMPLETION: return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: @@ -315,7 +306,6 @@ private MistralModel createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage @@ -324,7 +314,6 @@ private MistralModel createModelFromPersistent( inferenceEntityId, taskType, serviceSettings, - taskSettings, chunkingSettings, secretSettings, failureMessage, @@ -359,10 +348,10 @@ public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int e */ public static class Configuration { public static InferenceServiceConfiguration get() { - return configuration.getOrCompute(); + return CONFIGURATION.getOrCompute(); } - private static final LazyInitializable configuration = new LazyInitializable<>( + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( () -> { var configurationMap = new HashMap(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index a9d6df687fe99..cbff01c6227eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -34,7 +34,7 @@ protected Exception buildError(String message, 
Request request, HttpResult resul assert request.isStreaming() : "Only streaming requests support this format"; var responseStatusCode = result.response().getStatusLine().getStatusCode(); if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); return errorResponse instanceof MistralErrorResponse ? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java index fbf842f4fb789..ba7377c3209e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java @@ -24,7 +24,6 @@ import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; -import java.util.Map; import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -51,7 +50,7 @@ public MistralActionCreator(Sender sender, ServiceComponents serviceComponents) } @Override - public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings) { + public ExecutableAction create(MistralEmbeddingsModel embeddingsModel) { var requestManager = new MistralEmbeddingsRequestManager( embeddingsModel, serviceComponents.truncator(), diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java index 5f494e4d65477..e1c4b12883c56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; -import java.util.Map; - /** * Interface for creating {@link ExecutableAction} instances for Mistral models. *

    @@ -25,10 +23,9 @@ public interface MistralActionVisitor { * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}. * * @param embeddingsModel The model to create the action for. - * @param taskSettings The task settings to use. * @return An {@link ExecutableAction} for the given model. */ - ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings); + ExecutableAction create(MistralEmbeddingsModel embeddingsModel); /** * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java index 03fe502a82807..876c46edcb70d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java @@ -22,7 +22,6 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH; @@ -95,23 +94,17 @@ public MistralChatCompletionModel( DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); } private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); 
setEndpointUrl(); } - @Override - public int rateLimitGroupingHash() { - return Objects.hash(model, getSecretSettings().apiKey()); - } - private void setEndpointUrl() { try { this.uri = new URI(API_COMPLETIONS_PATH); @@ -131,6 +124,7 @@ public MistralChatCompletionServiceSettings getServiceSettings() { * @param creator The visitor that creates the executable action. * @return An ExecutableAction that can be executed. */ + @Override public ExecutableAction accept(MistralActionVisitor creator) { return creator.create(this); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index 48d2fecc5ce13..8ac186ac9d642 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -12,7 +12,6 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -37,7 +36,6 @@ public MistralEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context @@ -47,7 +45,6 @@ public MistralEmbeddingsModel( taskType, service, MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), - EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings chunkingSettings, 
DefaultSecretSettings.fromMap(secrets) ); @@ -59,7 +56,6 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer } private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); setEndpointUrl(); } @@ -77,12 +73,11 @@ public MistralEmbeddingsModel( TaskType taskType, String service, MistralEmbeddingsServiceSettings serviceSettings, - TaskSettings taskSettings, ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); @@ -93,7 +88,8 @@ public MistralEmbeddingsServiceSettings getServiceSettings() { return (MistralEmbeddingsServiceSettings) super.getServiceSettings(); } - public ExecutableAction accept(MistralActionVisitor creator, Map taskSettings) { - return creator.create(this, taskSettings); + @Override + public ExecutableAction accept(MistralActionVisitor creator) { + return creator.create(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index 6b1c7d36a9fe6..4cf1fef3c92c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -178,12 +178,13 @@ public boolean 
equals(Object o) { return Objects.equals(model, that.model) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override public int hashCode() { - return Objects.hash(model, dimensions, maxInputTokens, similarity); + return Objects.hash(model, dimensions, maxInputTokens, similarity, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java index 8b772d4b8f2ed..b7d3866bcebfd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java @@ -42,7 +42,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(this.uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.model(), truncationResult.input())) + Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.getServiceSettings().modelId(), truncationResult.input())) .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..b9e9e34c44736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -91,8 +93,16 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public OpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index e1a0117c7bcca..8a70f4428799b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -56,7 +56,7 @@ protected Exception buildError(String message, Request request, HttpResult resul assert request.isStreaming() : "Only streaming requests support this format"; var responseStatusCode = result.response().getStatusLine().getStatusCode(); if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); return errorResponse instanceof StreamingErrorResponse oer ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 3120f1ff92e48..957203b5ee802 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -198,7 +198,7 @@ private static class DeltaParser { PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); - PARSER.declareObjectArray( + PARSER.declareObjectArrayOrNull( 
ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), new ParseField(TOOL_CALLS_FIELD) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java index 928ed3ff444e6..2ae70cb52b565 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java @@ -34,7 +34,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent( builder, - UnifiedCompletionRequest.withMaxCompletionTokensTokens(model.getServiceSettings().modelId(), params) + UnifiedCompletionRequest.withMaxCompletionTokens(model.getServiceSettings().modelId(), params) ); if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index aafd6c46857fc..653c4288263f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.service.ClusterService; import 
org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -37,6 +39,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -55,13 +58,26 @@ public class SageMakerService implements InferenceService { private final SageMakerSchemas schemas; private final ThreadPool threadPool; private final LazyInitializable configuration; + private final ClusterService clusterService; public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, - CheckedSupplier, RuntimeException> configurationMap + CheckedSupplier, RuntimeException> configurationMap, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(modelBuilder, client, schemas, threadPool, configurationMap, context.clusterService()); + } + + public SageMakerService( + SageMakerModelBuilder modelBuilder, + SageMakerClient client, + SageMakerSchemas schemas, + ThreadPool threadPool, + CheckedSupplier, RuntimeException> configurationMap, + ClusterService clusterService ) { this.modelBuilder = modelBuilder; this.client = client; @@ -74,6 +90,7 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); + this.clusterService = Objects.requireNonNull(clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java index 48e32c741a601..0975f8616da03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java @@ -116,7 +116,7 @@ public SageMakerModel override(Map taskSettingsOverride) { getConfigurations(), getSecrets(), serviceSettings, - taskSettings.updatedTaskSettings(taskSettingsOverride), + taskSettings.override(taskSettingsOverride), awsSecretSettings ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java index fd9eb2d20c5d3..a36944c51f104 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java @@ -71,11 +71,21 @@ public boolean isEmpty() { @Override public SageMakerTaskSettings updatedTaskSettings(Map newSettings) { var validationException = new ValidationException(); - var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException); + validationException.throwIfValidationErrorsExist(); + + return override(updateTaskSettings); + } + public SageMakerTaskSettings override(Map newSettings) { + var validationException = new ValidationException(); + var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException); validationException.throwIfValidationErrorsExist(); + return override(updateTaskSettings); + } + + private SageMakerTaskSettings 
override(SageMakerTaskSettings updateTaskSettings) { var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP) ? apiTaskSettings : updateTaskSettings.apiTaskSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java index 09a73f0f42ea4..a3ff632f466c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java @@ -68,4 +68,8 @@ default boolean isFragment() { @Override SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings); + + default SageMakerStoredTaskSchema override(Map newSettings) { + return updatedTaskSettings(newSettings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java index 46c5a9eb30a9a..781b1e906a17f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java @@ -88,12 +88,6 @@ default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest re @Override default SageMakerElasticTaskSettings apiTaskSettings(Map taskSettings, ValidationException validationException) { - if (taskSettings != null && (taskSettings.isEmpty() == false)) { - validationException.addValidationError( - 
InferenceAction.Request.TASK_SETTINGS.getPreferredName() - + " is only supported during the inference request and cannot be stored in the inference endpoint." - ); - } return SageMakerElasticTaskSettings.empty(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java index 088de2068741c..dc0bc91fccd75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java @@ -9,10 +9,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; import java.io.IOException; @@ -40,6 +42,16 @@ public boolean isEmpty() { @Override public SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings) { + var validationException = new ValidationException(); + validationException.addValidationError( + InferenceAction.Request.TASK_SETTINGS.getPreferredName() + + " is only supported during the inference request and cannot be stored in the inference endpoint." 
+ ); + throw validationException; + } + + @Override + public SageMakerStoredTaskSchema override(Map newSettings) { return new SageMakerElasticTaskSettings(newSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..9698ee4c0d4bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,8 +98,16 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public VoyageAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public VoyageAIService(HttpRequestSender.Factory factory, 
ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequestSemanticIndexFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequestSemanticIndexFilterTests.java new file mode 100644 index 0000000000000..9fd1c3921fcf6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequestSemanticIndexFilterTests.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.action.fieldcaps; + +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.BoostingQueryBuilder; +import org.elasticsearch.index.query.ConstantScoreQueryBuilder; +import org.elasticsearch.index.query.DisMaxQueryBuilder; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.NestedQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.notNullValue; + +public class FieldCapabilitiesRequestSemanticIndexFilterTests extends ESTestCase { + private static final String EXPECTED_ERROR_MESSAGE = "index filter cannot contain semantic 
queries. Use an exists query instead."; + + public void testValidateWithoutIndexFilter() { + FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); + request.fields("field1", "field2"); + + ActionRequestValidationException validationException = request.validate(); + assertNull(validationException); + } + + public void testValidateWithNonSemanticIndexFilter() { + FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); + request.fields("field1", "field2"); + request.indexFilter(randomNonSemanticQuery()); + + ActionRequestValidationException validationException = request.validate(); + assertNull(validationException); + } + + public void testValidateWithDirectSemanticQuery() { + FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); + request.fields("field1", "field2"); + request.indexFilter(randomSemanticQuery()); + + ActionRequestValidationException validationException = request.validate(); + assertThat(validationException, notNullValue()); + assertThat(validationException.getMessage(), containsString(EXPECTED_ERROR_MESSAGE)); + } + + public void testValidateWithRandomCompoundQueryContainingSemantic() { + for (int i = 0; i < 100; i++) { + FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); + request.fields("field1", "field2"); + + // Create a randomly structured compound query containing semantic query + QueryBuilder randomCompoundQuery = randomCompoundQueryWithSemantic(randomIntBetween(1, 3)); + request.indexFilter(randomCompoundQuery); + + ActionRequestValidationException validationException = request.validate(); + assertThat(validationException, notNullValue()); + assertThat(validationException.getMessage(), containsString(EXPECTED_ERROR_MESSAGE)); + } + } + + private static SemanticQueryBuilder randomSemanticQuery() { + return new SemanticQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(10)); + } + + private static QueryBuilder randomNonSemanticQuery() { + return switch (randomIntBetween(0, 2)) { + case 0 
-> new TermQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5)); + case 1 -> new MatchAllQueryBuilder(); + case 2 -> { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + boolQuery.must(new TermQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5))); + yield boolQuery; + } + default -> throw new IllegalStateException("Unexpected value"); + }; + } + + private static QueryBuilder randomCompoundQueryWithSemantic(int depth) { + if (depth <= 0) { + return randomSemanticQuery(); + } + + return switch (randomIntBetween(0, 5)) { + case 0 -> { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + QueryBuilder clauseQuery = randomCompoundQueryWithSemantic(depth - 1); + switch (randomIntBetween(0, 3)) { + case 0 -> boolQuery.must(clauseQuery); + case 1 -> boolQuery.mustNot(clauseQuery); + case 2 -> boolQuery.should(clauseQuery); + case 3 -> boolQuery.filter(clauseQuery); + default -> throw new IllegalStateException("Unexpected value"); + } + + if (randomBoolean()) { + boolQuery.should(randomNonSemanticQuery()); + } + + yield boolQuery; + } + case 1 -> { + DisMaxQueryBuilder disMax = new DisMaxQueryBuilder(); + disMax.add(randomCompoundQueryWithSemantic(depth - 1)); + if (randomBoolean()) { + disMax.add(randomNonSemanticQuery()); + } + yield disMax; + } + case 2 -> new NestedQueryBuilder(randomAlphaOfLength(5), randomCompoundQueryWithSemantic(depth - 1), ScoreMode.Max); + case 3 -> { + boolean positiveSemanticQuery = randomBoolean(); + QueryBuilder semanticQuery = randomCompoundQueryWithSemantic(depth - 1); + QueryBuilder nonSemanticQuery = randomNonSemanticQuery(); + + yield new BoostingQueryBuilder( + positiveSemanticQuery ? semanticQuery : nonSemanticQuery, + positiveSemanticQuery ? 
nonSemanticQuery : semanticQuery + ); + } + case 4 -> new ConstantScoreQueryBuilder(randomCompoundQueryWithSemantic(depth - 1)); + case 5 -> new FunctionScoreQueryBuilder(randomCompoundQueryWithSemantic(depth - 1)); + default -> throw new IllegalStateException("Unexpected value"); + }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 270cdba6d3469..1f0b56e3d6848 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -61,6 +61,14 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); + if (randomBoolean()) { + float boost = randomFloatBetween(1, 10, randomBoolean()); + original.boost(boost); + } + if (randomBoolean()) { + String queryName = randomAlphaOfLength(5); + original.queryName(queryName); + } testRewrittenInferenceQuery(context, original); } @@ -72,6 +80,14 @@ public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); + if (randomBoolean()) { + float boost = randomFloatBetween(1, 10, randomBoolean()); + original.boost(boost); + } + if (randomBoolean()) { 
+ String queryName = randomAlphaOfLength(5); + original.queryName(queryName); + } testRewrittenInferenceQuery(context, original); } @@ -82,14 +98,23 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertEquals(original.boost(), intercepted.boost(), 0.0f); + assertEquals(original.queryName(), intercepted.queryName()); assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f); + assertEquals(original.queryName(), nestedQueryBuilder.queryName()); assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); + QueryBuilder innerQuery = nestedQueryBuilder.query(); assertTrue(innerQuery instanceof KnnVectorQueryBuilder); KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery; + assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f); + assertNull(knnVectorQueryBuilder.queryName()); assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName()); assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder); + TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder .queryVectorBuilder(); assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 6987ef33ed63d..b58547e1a92c7 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -36,6 +36,8 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase { private static final String FIELD_NAME = "fieldName"; private static final String VALUE = "value"; + private static final String QUERY_NAME = "match_query"; + private static final float BOOST = 5.0f; @Before public void setup() { @@ -79,6 +81,29 @@ public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOExcept assertEquals(original, rewritten); } + public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException { + Map inferenceFields = Map.of( + FIELD_NAME, + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) + ); + QueryRewriteContext context = createQueryRewriteContext(inferenceFields); + QueryBuilder original = createTestQueryBuilder(); + original.boost(BOOST); + original.queryName(QUERY_NAME); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof InterceptedQueryBuilderWrapper + ); + InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertEquals(BOOST, intercepted.boost(), 0.0f); + assertEquals(QUERY_NAME, intercepted.queryName()); + assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; + assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); + assertEquals(VALUE, semanticQueryBuilder.getQuery()); + } + private MatchQueryBuilder createTestQueryBuilder() { return new MatchQueryBuilder(FIELD_NAME, VALUE); } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 075955766a0a9..401b7085e2cb5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -58,21 +58,15 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); - QueryBuilder rewritten = original.rewrite(context); - assertTrue( - "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", - rewritten instanceof InterceptedQueryBuilderWrapper - ); - InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; - assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); - QueryBuilder innerQuery = nestedQueryBuilder.query(); - assertTrue(innerQuery instanceof SparseVectorQueryBuilder); - SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; - assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); - assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); - assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); + if (randomBoolean()) { + float boost = randomFloatBetween(1, 10, randomBoolean()); + original.boost(boost); + } + if (randomBoolean()) { + String queryName = 
randomAlphaOfLength(5); + original.queryName(queryName); + } + testRewrittenInferenceQuery(context, original); } public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { @@ -82,32 +76,52 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); + if (randomBoolean()) { + float boost = randomFloatBetween(1, 10, randomBoolean()); + original.boost(boost); + } + if (randomBoolean()) { + String queryName = randomAlphaOfLength(5); + original.queryName(queryName); + } + testRewrittenInferenceQuery(context, original); + } + + public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { + QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields + QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof SparseVectorQueryBuilder + ); + assertEquals(original, rewritten); + } + + private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException { QueryBuilder rewritten = original.rewrite(context); assertTrue( "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertEquals(original.boost(), intercepted.boost(), 0.0f); + assertEquals(original.queryName(), intercepted.queryName()); + assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; 
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); + assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f); + assertEquals(original.queryName(), nestedQueryBuilder.queryName()); + QueryBuilder innerQuery = nestedQueryBuilder.query(); assertTrue(innerQuery instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); - } - - public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { - QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields - QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); - QueryBuilder rewritten = original.rewrite(context); - assertTrue( - "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]", - rewritten instanceof SparseVectorQueryBuilder - ); - assertEquals(original, rewritten); + assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f); + assertNull(sparseVectorQueryBuilder.queryName()); } private QueryRewriteContext createQueryRewriteContext(Map inferenceFields) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 70499c7987965..812cd1e3c6d7f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -21,6 +21,8 @@ import 
org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -33,7 +35,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -88,7 +89,7 @@ public void setUp() throws Exception { licenseState = mock(); modelRegistry = mock(); serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); + inferenceStats = InferenceStatsTests.mockInferenceStats(); streamingTaskManager = mock(); action = createAction( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java index f57b62ac1c8bf..e210649edb845 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java @@ -17,8 +17,10 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import 
org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -32,11 +34,17 @@ import java.util.Optional; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class TransportDeleteInferenceEndpointActionTests extends ESTestCase { @@ -130,4 +138,213 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() { assertTrue(response.isAcknowledged()); } + + public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + var mockService = mock(InferenceService.class); + mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService); + when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + + var listener = new PlainActionFuture(); + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false), + ClusterState.EMPTY_STATE, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint")); + + 
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); + } + + public void testDeletesUnparsableEndpoint_WhenForceIsTrue() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + var mockService = mock(InferenceService.class); + mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(true); + return Void.TYPE; + }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); + + var listener = new PlainActionFuture(); + + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false), + ClusterState.EMPTY_STATE, + listener + ); + + var response = listener.actionGet(TIMEOUT); + assertTrue(response.isAcknowledged()); + + verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); + } + + private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) { + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, 
Map.of(), Map.of())); + return Void.TYPE; + }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService) + .parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService)); + } + + public void testDeletesEndpointWithNoService_WhenForceIsTrue() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + mockNoService(inferenceEndpointId, serviceName, taskType); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(true); + return Void.TYPE; + }).when(mockModelRegistry).deleteModel(anyString(), any()); + + var listener = new PlainActionFuture(); + + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false), + ClusterState.EMPTY_STATE, + listener + ); + + var response = listener.actionGet(TIMEOUT); + assertTrue(response.isAcknowledged()); + verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry); + } + + public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + mockNoService(inferenceEndpointId, serviceName, taskType); + when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + + var listener = new PlainActionFuture(); + + action.masterOperation( + 
mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false), + ClusterState.EMPTY_STATE, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("No service found for this inference endpoint")); + verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry); + } + + private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) { + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of())); + return Void.TYPE; + }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty()); + } + + public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + var mockService = mock(InferenceService.class); + var mockModel = mock(Model.class); + mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel); + when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + + var listener = new PlainActionFuture(); + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false), + ClusterState.EMPTY_STATE, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + 
assertThat(exception.getMessage(), containsString("Failed to stop model deployment")); + verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verify(mockService).stop(eq(mockModel), any()); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); + } + + public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() { + var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); + var serviceName = randomAlphanumericOfLength(10); + var taskType = randomFrom(TaskType.values()); + var mockService = mock(InferenceService.class); + var mockModel = mock(Model.class); + mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(true); + return Void.TYPE; + }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); + + var listener = new PlainActionFuture(); + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false), + ClusterState.EMPTY_STATE, + listener + ); + + var response = listener.actionGet(TIMEOUT); + assertTrue(response.isAcknowledged()); + verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + verify(mockInferenceServiceRegistry).getService(eq(serviceName)); + verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verify(mockService).stop(eq(mockModel), any()); + verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); + verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); + } + + private void 
mockStopDeploymentFails( + String inferenceEndpointId, + String serviceName, + TaskType taskType, + InferenceService mockService, + Model mockModel + ) { + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of())); + return Void.TYPE; + }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); + when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService)); + doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR)); + return Void.TYPE; + }).when(mockService).stop(eq(mockModel), any()); + } + } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 4d986cf0a837f..547078d93acc4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportException; @@ -22,7 +23,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import 
org.elasticsearch.xpack.inference.common.RateLimitAssignment; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index f26d0675487a5..9e6f4a6260936 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; @@ -20,7 +21,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.Optional; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 5b4925d8fb0a3..e96fda569aa12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -49,6 +49,8 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -66,7 +68,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -148,7 +149,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(), @@ -181,7 +182,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -227,7 +228,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { - final InferenceStats 
inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -275,7 +276,7 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -364,7 +365,7 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -440,7 +441,7 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -495,7 +496,7 @@ public void testHandleEmptyInput() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { - final InferenceStats 
inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -559,7 +560,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -677,7 +678,7 @@ public void testIndexingPressure() throws Exception { @SuppressWarnings("unchecked") public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() ); @@ -709,7 +710,10 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); assertThat( doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to update source on document [doc_1]") + containsString( + "Unable to insert inference results into document [doc_1]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." 
+ ) ); assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -762,7 +766,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (length(doc1Source) + 1) + "b").build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -791,7 +795,10 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); assertThat( doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_1]") + containsString( + "Unable to insert inference results into document [doc_1]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." 
+ ) ); assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -875,7 +882,7 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), @@ -902,7 +909,10 @@ public void testIndexingPressurePartialFailure() throws Exception { BulkItemResponse.Failure doc2Failure = doc2Response.getFailure(); assertThat( doc2Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_2]") + containsString( + "Unable to insert inference results into document [doc_2]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." 
+ ) ); assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java index baa8429ae3c78..1cb90b11995fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java @@ -46,7 +46,7 @@ public void testChunkInputShorterThanMaxChunkSize() { assertExpectedChunksGenerated(input, settings, List.of(new Chunker.ChunkOffset(0, input.length()))); } - public void testChunkInputRequiresOneSplit() { + public void testChunkInputRequiresOneSplitWithoutMerges() { List separators = generateRandomSeparators(); RecursiveChunkingSettings settings = generateChunkingSettings(10, separators); String input = generateTestText(2, List.of(separators.getFirst())); @@ -58,7 +58,23 @@ public void testChunkInputRequiresOneSplit() { ); } - public void testChunkInputRequiresMultipleSplits() { + public void testChunkInputRequiresOneSplitWithMerges() { + List separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(20, separators); + String input = generateTestText(3, List.of(separators.getFirst(), separators.getFirst())); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.getFirst().length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length()) + ) + ); + } + + public void testChunkInputRequiresMultipleSplitsWithoutMerges() { var separators = generateRandomSeparators(); 
RecursiveChunkingSettings settings = generateChunkingSettings(15, separators); String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1))); @@ -78,6 +94,22 @@ public void testChunkInputRequiresMultipleSplits() { ); } + public void testChunkInputRequiresMultipleSplitsWithMerges() { + var separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(25, separators); + String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1))); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length()) + ) + ); + } + public void testChunkInputDoesNotSplitWhenNoLongerExceedingMaxChunkSize() { var separators = randomSubsetOf(3, TEST_SEPARATORS); RecursiveChunkingSettings settings = generateChunkingSettings(25, separators); @@ -165,7 +197,7 @@ public void testChunkLongDocument() { public void testMarkdownChunking() { int numSentences = randomIntBetween(10, 50); - List separators = SeparatorSet.MARKDOWN.getSeparators(); + List separators = SeparatorGroup.MARKDOWN.getSeparators(); List validHeaders = List.of( "# Header\n", "## Header\n", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java index 40f14e88d2558..f833aa09b1aee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java @@ -32,15 +32,15 @@ public void 
testFromMapValidSettingsWithSeparators() { assertEquals(separators, settings.getSeparators()); } - public void testFromMapValidSettingsWithSeparatorSet() { + public void testFromMapValidSettingsWithSeparatorGroup() { var maxChunkSize = randomIntBetween(10, 300); - var separatorSet = randomFrom(SeparatorSet.values()); - Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorSet.name()), Optional.empty()); + var separatorGroup = randomFrom(SeparatorGroup.values()); + Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorGroup.name()), Optional.empty()); RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings); assertEquals(maxChunkSize, settings.getMaxChunkSize()); - assertEquals(separatorSet.getSeparators(), settings.getSeparators()); + assertEquals(separatorGroup.getSeparators(), settings.getSeparators()); } public void testFromMapMaxChunkSizeTooSmall() { @@ -55,7 +55,7 @@ public void testFromMapMaxChunkSizeTooLarge() { assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); } - public void testFromMapInvalidSeparatorSet() { + public void testFromMapInvalidSeparatorGroup() { Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty()); assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); @@ -68,7 +68,7 @@ public void testFromMapInvalidSettingKey() { assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); } - public void testFromMapBothSeparatorsAndSeparatorSet() { + public void testFromMapBothSeparatorsAndSeparatorGroup() { Map invalidSettings = buildChunkingSettingsMap( randomIntBetween(10, 300), Optional.of("default"), @@ -86,13 +86,13 @@ public void testFromMapEmptySeparators() { private Map buildChunkingSettingsMap( int maxChunkSize, - Optional separatorSet, + Optional separatorGroup, Optional> 
separators ) { Map settingsMap = new HashMap<>(); settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.RECURSIVE.toString()); settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); - separatorSet.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_SET.toString(), s)); + separatorGroup.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), s)); separators.ifPresent(strings -> settingsMap.put(ChunkingSettingsOptions.SEPARATORS.toString(), strings)); return settingsMap; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index aeb09af03ebab..4a4c59f091abf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.custom.CustomModel; import org.junit.After; import org.junit.Assume; import org.junit.Before; @@ -141,7 +140,7 @@ public boolean isEnabled() { return true; } - protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); } private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { @@ -151,7 +150,7 @@ public boolean isEnabled() { } @Override - protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + 
protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { throw new UnsupportedOperationException("Update model tests are disabled"); } }; @@ -351,11 +350,17 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr assertThat( exception.getMessage(), - containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType)) + containsString( + Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType) + ) ); } } + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "service does not support task type [%s]"; + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var parseConfigTestConfig = testConfiguration.commonConfig; @@ -374,7 +379,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -396,7 +401,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServ persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -413,7 +418,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -430,7 +435,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, 
config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -468,7 +473,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) ); - assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:")); + assertThat(exception.getMessage(), containsString("Can't update embedding details for model")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..7457859a64603 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; @@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -64,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { var factory = mock(HttpRequestSender.Factory.class); 
when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); @@ -84,7 +86,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); @@ -102,8 +104,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } private static final class TestSenderService extends SenderService { - TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 8fbbd33d569e4..f0258e9f66ed5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -91,7 +91,13 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -116,7 +122,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -143,7 +155,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { 
assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -169,7 +187,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -190,7 +214,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -210,7 +240,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), 
AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -235,7 +271,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -262,7 +304,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -279,7 +321,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), 
mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -316,7 +358,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -360,7 +402,7 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -404,7 +446,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -452,7 +494,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new 
AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -482,7 +524,13 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { String content = XContentHelper.stripWhitespace( """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index a014f27e7f0cc..c3b1cab4b4e0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -959,7 +959,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1007,7 +1014,12 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); try ( - 
var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); @@ -1042,7 +1054,14 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var results = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1088,7 +1107,14 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); requestSender.enqueue(mockResults); @@ -1132,7 +1158,14 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, 
createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = AmazonBedrockChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1166,7 +1199,14 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var embeddingSize = randomNonNegativeInt(); var provider = randomFrom(AmazonBedrockProvider.values()); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -1205,7 +1245,12 @@ public void testInfer_UnauthorizedResponse() throws IOException { ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { requestSender.enqueue( @@ -1240,7 +1285,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { } public void testSupportsStreaming() throws IOException { - try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { + try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1284,7 +1329,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep mockClusterServiceEmpty() ); - 
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { var mockResults1 = new TextEmbeddingFloatResults( @@ -1345,7 +1397,12 @@ private AmazonBedrockService createAmazonBedrockService() { ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + return new AmazonBedrockService( + mock(HttpRequestSender.Factory.class), + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 75ce59b16a763..9111866d29c88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -453,7 +453,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -486,7 +486,7 @@ public void 
testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", @@ -579,7 +579,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AnthropicChatCompletionModelTests.createChatCompletionModel( getUrl(webServer), "secret", @@ -650,6 +650,15 @@ public void testGetConfiguration() throws Exception { "updatable": false, "type": "str", "supported_task_types": ["completion"] + }, + "max_tokens": { + "description": "The maximum number of tokens to generate before stopping.", + "label": "Max Tokens", + "required": true, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["completion"] } } } @@ -670,13 +679,13 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AnthropicService 
createServiceWithMockSender() { - return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 3d7ba7f7436fb..3383762a9f332 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -54,6 +55,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests; +import 
org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -219,6 +224,33 @@ public void testParseRequestConfig_CreatesAnAzureAiStudioChatCompletionModel() t } } + public void testParseRequestConfig_CreatesAnAzureAiStudioRerankModel() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + + var rerankModel = (AzureAiStudioRerankModel) model; + assertThat(rerankModel.getServiceSettings().target(), is("http://target.local")); + assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE)); + assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(rerankModel.getSecretSettings().apiKey().toString(), is("secret")); + assertNull(rerankModel.getTaskSettings().returnDocuments()); + assertNull(rerankModel.getTaskSettings().topN()); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.RERANK, + getRequestConfigMap( + getRerankServiceSettingsMap("http://target.local", "cohere", "token"), + getRerankTaskSettingsMap(null, null), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -441,6 +473,80 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSec } } + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() 
throws IOException { + try (var service = createService()) { + var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token"); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret")); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException { + try (var service = createService()) { + var taskSettings = getRerankTaskSettingsMap(null, null); + taskSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + getRerankServiceSettingsMap("http://target.local", "cohere", "token"), + taskSettings, + getSecretSettingsMap("secret") + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + 
getRerankServiceSettingsMap("http://target.local", "cohere", "token"), + getRerankTaskSettingsMap(null, null), + secretSettings + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener); + } + } + public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException { try (var service = createService()) { var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null); @@ -505,6 +611,45 @@ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForChatComple } } + public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForRerank() throws IOException { + try (var service = createService()) { + var serviceSettings = getRerankServiceSettingsMap("http://target.local", "databricks", "token"); + + var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret")); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [rerank] task type for provider [databricks] is not available")); + } + ); + + service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider() throws IOException { + try (var service = createService()) { + var serviceSettings = 
getRerankServiceSettingsMap("http://target.local", "cohere", "realtime"); + + var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret")); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available") + ); + } + ); + + service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener); + } + } + public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -603,6 +748,27 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() } } + public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap( + getRerankServiceSettingsMap("http://target.local", "cohere", "token"), + getRerankTaskSettingsMap(true, 2), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + + var chatCompletionModel = (AzureAiStudioRerankModel) model; + assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local")); + assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE)); + assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(chatCompletionModel.getTaskSettings().returnDocuments(), is(true)); + assertThat(chatCompletionModel.getTaskSettings().topN(), is(2)); + } + } + public void 
testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -747,6 +913,48 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompl } } + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token"); + serviceSettings.put("extra_key", "value"); + var taskSettings = getRerankTaskSettingsMap(true, 2); + var secretSettings = getSecretSettingsMap("secret"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token"); + var taskSettings = getRerankTaskSettingsMap(true, 2); + taskSettings.put("extra_key", "value"); + var secretSettings = getSecretSettingsMap("secret"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token"); + var taskSettings = getRerankTaskSettingsMap(true, 2); + var secretSettings = 
getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + } + } + public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -842,9 +1050,30 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() } } + public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap( + getRerankServiceSettingsMap("http://target.local", "cohere", "token"), + getRerankTaskSettingsMap(true, 2), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.RERANK, config.config()); + + assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); + + var rerankModel = (AzureAiStudioRerankModel) model; + assertThat(rerankModel.getServiceSettings().target(), is("http://target.local")); + assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE)); + assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(rerankModel.getTaskSettings().returnDocuments(), is(true)); + assertThat(rerankModel.getTaskSettings().topN(), is(2)); + } + } + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = 
AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -869,7 +1098,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -895,7 +1124,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -923,7 +1152,7 @@ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginal private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( 
randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -956,7 +1185,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -994,7 +1223,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1064,7 +1293,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1150,7 +1379,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep public void testInfer_WithChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, 
createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); var model = AzureAiStudioChatCompletionModelTests.createModel( @@ -1184,10 +1413,51 @@ public void testInfer_WithChatCompletionModel() throws IOException { } } + public void testInfer_WithRerankModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); + + var model = AzureAiStudioRerankModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + "apikey" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + false, + 2, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + assertThat(result, CoreMatchers.instanceOf(RankedDocsResults.class)); + + var rankedDocsResults = (RankedDocsResults) result; + var rankedDocs = rankedDocsResults.getRankedDocs(); + assertThat(rankedDocs.size(), is(2)); + assertThat(rankedDocs.get(0).relevanceScore(), is(0.1111111F)); + assertThat(rankedDocs.get(0).index(), is(0)); + assertThat(rankedDocs.get(1).relevanceScore(), is(0.2222222F)); + assertThat(rankedDocs.get(1).index(), is(1)); + } + } + public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), 
mockClusterServiceEmpty())) { String responseJson = """ { @@ -1264,7 +1534,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( "id", getUrl(webServer), @@ -1320,7 +1590,7 @@ public void testGetConfiguration() throws Exception { { "service": "azureaistudio", "name": "Azure AI Studio", - "task_types": ["text_embedding", "completion"], + "task_types": ["text_embedding", "rerank", "completion"], "configurations": { "dimensions": { "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-embeddings.", @@ -1338,7 +1608,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion"] + "supported_task_types": ["text_embedding", "rerank", "completion"] }, "provider": { "description": "The model provider for your deployment.", @@ -1347,7 +1617,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion"] + "supported_task_types": ["text_embedding", "rerank", "completion"] }, "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1356,7 +1626,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "completion"] + "supported_task_types": 
["text_embedding", "rerank", "completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1365,7 +1635,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "completion"] + "supported_task_types": ["text_embedding", "rerank", "completion"] }, "target": { "description": "The target URL of your Azure AI Studio model deployment.", @@ -1374,7 +1644,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion"] + "supported_task_types": ["text_embedding", "rerank", "completion"] } } } @@ -1396,7 +1666,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1405,7 +1675,11 @@ public void testSupportsStreaming() throws IOException { // ---------------------------------------------------------------- private AzureAiStudioService createService() { - return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( @@ -1462,6 +1736,10 @@ private static HashMap getChatCompletionServiceSettingsMap(Strin return AzureAiStudioChatCompletionServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType); } + private static HashMap getRerankServiceSettingsMap(String target, String 
provider, String endpointType) { + return AzureAiStudioRerankServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType); + } + public static Map getChatCompletionTaskSettingsMap( @Nullable Double temperature, @Nullable Double topP, @@ -1471,6 +1749,10 @@ public static Map getChatCompletionTaskSettingsMap( return AzureAiStudioChatCompletionTaskSettingsTests.getTaskSettingsMap(temperature, topP, doSample, maxNewTokens); } + public static Map getRerankTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) { + return AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(returnDocuments, topN); + } + private static Map getSecretSettingsMap(String apiKey) { return new HashMap<>(Map.of(API_KEY_FIELD, apiKey)); } @@ -1520,4 +1802,28 @@ private static Map getSecretSettingsMap(String apiKey) { } } """; + + private static final String testRerankTokenResponseJson = """ + { + "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b", + "results": [ + { + "index": 0, + "relevance_score": 0.1111111 + }, + { + "index": 1, + "relevance_score": 0.2222222 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index 9896286f503f3..12b04f909225a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import 
org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -27,6 +28,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; @@ -34,6 +36,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests; import org.junit.After; import org.junit.Before; @@ -78,31 +81,20 @@ public void shutdown() throws IOException { } public void testEmbeddingsRequestAction() throws IOException { - var senderFactory = new HttpRequestSender.Factory( + final var senderFactory = new HttpRequestSender.Factory( ServiceComponentsTests.createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty() ); - var timeoutSettings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - - var serviceComponents = new ServiceComponents( - threadPool, - mock(ThrottlerManager.class), - timeoutSettings, - 
TruncatorTests.createTruncator() - ); + final var serviceComponents = getServiceComponents(); try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson)); - var model = AzureAiStudioEmbeddingsModelTests.createModel( + final var model = AzureAiStudioEmbeddingsModelTests.createModel( "id", "http://will-be-replaced.local", AzureAiStudioProvider.OPENAI, @@ -111,21 +103,18 @@ public void testEmbeddingsRequestAction() throws IOException { ); model.setURI(getUrl(webServer)); - var creator = new AzureAiStudioActionCreator(sender, serviceComponents); - var action = creator.create(model, Map.of()); - PlainActionFuture listener = new PlainActionFuture<>(); - var inputType = InputTypeTests.randomSearchAndIngestWithNull(); + final var creator = new AzureAiStudioActionCreator(sender, serviceComponents); + final var action = creator.create(model, Map.of()); + final PlainActionFuture listener = new PlainActionFuture<>(); + final var inputType = InputTypeTests.randomSearchAndIngestWithNull(); action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - var result = listener.actionGet(TIMEOUT); + final var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey")); + assertWebServerRequest(API_KEY_HEADER, "apikey"); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + final var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(InputType.isSpecified(inputType) ? 
2 : 1)); assertThat(requestMap.get("input"), is(List.of("abc"))); if (InputType.isSpecified(inputType)) { @@ -136,27 +125,15 @@ public void testEmbeddingsRequestAction() throws IOException { } public void testChatCompletionRequestAction() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - var timeoutSettings = buildSettingsWithRetryFields( - TimeValue.timeValueMillis(1), - TimeValue.timeValueMinutes(1), - TimeValue.timeValueSeconds(0) - ); - - var serviceComponents = new ServiceComponents( - threadPool, - mock(ThrottlerManager.class), - timeoutSettings, - TruncatorTests.createTruncator() - ); + final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + final var serviceComponents = getServiceComponents(); try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson)); - var webserverUrl = getUrl(webServer); - var model = AzureAiStudioChatCompletionModelTests.createModel( + final var webserverUrl = getUrl(webServer); + final var model = AzureAiStudioChatCompletionModelTests.createModel( "id", "http://will-be-replaced.local", AzureAiStudioProvider.COHERE, @@ -165,30 +142,101 @@ public void testChatCompletionRequestAction() throws IOException { ); model.setURI(webserverUrl); - var creator = new AzureAiStudioActionCreator(sender, serviceComponents); - var action = creator.create(model, Map.of()); + final var creator = new AzureAiStudioActionCreator(sender, serviceComponents); + final var action = creator.create(model, Map.of()); - PlainActionFuture listener = new PlainActionFuture<>(); + final PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - var result = listener.actionGet(TIMEOUT); + final var result = listener.actionGet(TIMEOUT); 
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); - assertThat(webServer.requests(), hasSize(1)); - - MockRequest request = webServer.requests().get(0); - assertNull(request.getUri().getQuery()); - assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("apikey")); + assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey"); - var requestMap = entityAsMap(request.getBody()); + final MockRequest request = webServer.requests().get(0); + final var requestMap = entityAsMap(request.getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); } } - private static String testEmbeddingsTokenResponseJson = """ + public void testRerankRequestAction() throws IOException { + final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + final var serviceComponents = getServiceComponents(); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); + final var webserverUrl = getUrl(webServer); + final var model = AzureAiStudioRerankModelTests.createModel( + "id", + "http://will-be-replaced.local", + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + "apikey" + ); + model.setURI(webserverUrl); + + final var topN = 2; + final var returnDocuments = false; + final var query = "query"; + final var documents = List.of("document 1", "document 2", "document 3"); + + final var creator = new AzureAiStudioActionCreator(sender, serviceComponents); + final var action = creator.create(model, Map.of()); + + final PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs(query, documents, returnDocuments, topN, false), + InferenceAction.Request.DEFAULT_TIMEOUT, 
+ listener + ); + + final var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + equalTo( + RankedDocsResultsTests.buildExpectationRerank( + List.of( + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.1111111f)), + new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.2222222f)) + ) + ) + ) + ); + + assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey"); + + final var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("documents"), is(documents)); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("return_documents"), is(returnDocuments)); + } + } + + private void assertWebServerRequest(String authorization, String authorizationHeaderValue) { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(authorization), equalTo(authorizationHeaderValue)); + } + + private ServiceComponents getServiceComponents() { + final var timeoutSettings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + return new ServiceComponents(threadPool, mock(ThrottlerManager.class), timeoutSettings, TruncatorTests.createTruncator()); + } + + private final String testEmbeddingsTokenResponseJson = """ { "object": "list", "data": [ @@ -209,7 +257,7 @@ public void testChatCompletionRequestAction() throws IOException { } """; - private static String testCompletionTokenResponseJson = """ + private final String testCompletionTokenResponseJson = """ { "choices": [ { @@ -233,4 +281,27 @@ public void testChatCompletionRequestAction() throws 
IOException { } }"""; + private final String testRerankTokenResponseJson = """ + { + "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b", + "results": [ + { + "index": 0, + "relevance_score": 0.1111111 + }, + { + "index": 1, + "relevance_score": 0.2222222 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntityTests.java new file mode 100644 index 0000000000000..a2a13c2d3d389 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntityTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureaistudio.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; + +public class AzureAiStudioRerankRequestEntityTests extends ESTestCase { + private static final String INPUT = "texts"; + private static final String QUERY = "query"; + private static final Boolean RETURN_DOCUMENTS = false; + private static final Integer TOP_N = 8; + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + final var entity = new AzureAiStudioRerankRequestEntity( + QUERY, + List.of(INPUT), + Boolean.TRUE, + TOP_N, + new AzureAiStudioRerankTaskSettings(RETURN_DOCUMENTS, TOP_N) + ); + + final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + final String xContentResult = Strings.toString(builder); + final String expected = """ + {"documents":["texts"], + "query":"query", + "return_documents":true, + "top_n":8}"""; + assertEquals(stripWhitespace(expected), xContentResult); + } + + public void testXContent_WritesMinimalFields() throws IOException { + final var entity = new AzureAiStudioRerankRequestEntity( + QUERY, + List.of(INPUT), + null, + null, + new AzureAiStudioRerankTaskSettings(null, null) + ); + + final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + final String xContentResult = Strings.toString(builder); + final String expected = """ + 
{"documents":["texts"],"query":"query"}"""; + assertEquals(stripWhitespace(expected), xContentResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestTests.java new file mode 100644 index 0000000000000..2df4433bcfdda --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestTests.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureaistudio.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType; +import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider; +import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD; +import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD; +import static 
org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AzureAiStudioRerankRequestTests extends ESTestCase { + private static final String TARGET_URI = "http://testtarget.local"; + private static final String INPUT = "documents"; + private static final String QUERY = "query"; + private static final Integer TOP_N = 2; + + public void testCreateRequest_WithCohereProviderTokenEndpoint_NoParams() throws IOException { + final var input = randomAlphaOfLength(3); + final var query = randomAlphaOfLength(3); + final var apikey = randomAlphaOfLength(3); + final var request = createRequest(TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey, query, input); + final var httpPost = getHttpPost(request, apikey); + final var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get(QUERY), is(query)); + assertThat(requestMap.get(INPUT), is(List.of(input))); + } + + public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException { + final var input = randomAlphaOfLength(3); + final var query = randomAlphaOfLength(3); + final var apikey = randomAlphaOfLength(3); + final var request = createRequest( + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + apikey, + null, + TOP_N, + query, + input + ); + final var httpPost = getHttpPost(request, apikey); + final var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get(QUERY), is(query)); + assertThat(requestMap.get(INPUT), is(List.of(input))); + assertThat(requestMap.get(TOP_N_FIELD), is(TOP_N)); + } + + public void testCreateRequest_WithCohereProviderTokenEndpoint_WithReturnDocumentsParam() throws 
IOException { + final var input = randomAlphaOfLength(3); + final var query = randomAlphaOfLength(3); + final var apikey = randomAlphaOfLength(3); + final var request = createRequest( + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + apikey, + true, + null, + query, + input + ); + final var httpPost = getHttpPost(request, apikey); + final var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get(QUERY), is(query)); + assertThat(requestMap.get(INPUT), is(List.of(input))); + assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD), is(true)); + } + + private HttpPost getHttpPost(AzureAiStudioRerankRequest request, String apikey) { + final var httpRequest = request.createHttpRequest(); + + final var httpPost = validateRequestUrlAndContentType(httpRequest, TARGET_URI + "/v1/rerank"); + validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey); + return httpPost; + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + final var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is(expectedUrl)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + return httpPost; + } + + private void validateRequestApiKey( + HttpPost httpPost, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + String apiKey + ) { + if (endpointType == AzureAiStudioEndpointType.TOKEN) { + if (provider == AzureAiStudioProvider.OPENAI) { + assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey)); + } else { + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey)); + } + } else { + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + 
apiKey)); + } + } + + public static AzureAiStudioRerankRequest createRequest( + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + String apiKey, + String query, + String input + ) { + return createRequest(target, provider, endpointType, apiKey, null, null, query, input); + } + + public static AzureAiStudioRerankRequest createRequest( + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + String apiKey, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + String query, + String input + ) { + final var model = AzureAiStudioRerankModelTests.createModel( + "id", + target, + provider, + endpointType, + apiKey, + returnDocuments, + topN, + null + ); + return new AzureAiStudioRerankRequest(model, query, List.of(input), returnDocuments, topN); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModelTests.java new file mode 100644 index 0000000000000..0cf5500e189b7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModelTests.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureaistudio.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType; +import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URISyntaxException; + +import static org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class AzureAiStudioRerankModelTests extends ESTestCase { + private static final String MODEL_ID = "id"; + private static final String TARGET_URI = "http://testtarget.local"; + private static final String API_KEY = "apikey"; + private static final Integer TOP_N = 1; + private static final Integer TOP_N_OVERRIDE = 2; + + public void testOverrideWith_OverridesWithoutValues() { + final var model = createModel( + MODEL_ID, + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + API_KEY, + true, + TOP_N, + null + ); + final var requestTaskSettingsMap = getTaskSettingsMap(null, null); + final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettingsMap); + + assertThat(overriddenModel, sameInstance(overriddenModel)); + } + + public void testOverrideWith_returnDocuments() { + final var model = createModel( + MODEL_ID, + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + API_KEY, + true, + null, + null + ); + final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(false, null); + final var overriddenModel = 
AzureAiStudioRerankModel.of(model, requestTaskSettings); + + assertThat( + overriddenModel, + is(createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY, false, null, null)) + ); + } + + public void testOverrideWith_topN() { + final var model = createModel( + MODEL_ID, + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + API_KEY, + null, + TOP_N, + null + ); + final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(null, TOP_N_OVERRIDE); + final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + MODEL_ID, + TARGET_URI, + AzureAiStudioProvider.COHERE, + AzureAiStudioEndpointType.TOKEN, + API_KEY, + null, + TOP_N_OVERRIDE, + null + ) + ) + ); + } + + public void testSetsProperUrlForCohereTokenModel() throws URISyntaxException { + final var model = createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY); + assertThat(model.getEndpointUri().toString(), is(TARGET_URI + "/v1/rerank")); + } + + public static AzureAiStudioRerankModel createModel( + String id, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + String apiKey + ) { + return createModel(id, target, provider, endpointType, apiKey, null, null, null); + } + + public static AzureAiStudioRerankModel createModel( + String id, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + String apiKey, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + @Nullable RateLimitSettings rateLimitSettings + ) { + return new AzureAiStudioRerankModel( + id, + new AzureAiStudioRerankServiceSettings(target, provider, endpointType, rateLimitSettings), + new AzureAiStudioRerankTaskSettings(returnDocuments, topN), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.MatcherAssert;

import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

/**
 * Unit tests for {@link AzureAiStudioRerankRequestTaskSettings}, the per-request
 * (override) task settings parsed from an inference request body.
 */
public class AzureAiStudioRerankRequestTaskSettingsTests extends ESTestCase {

    private static final String INVALID_FIELD_TYPE_STRING = "invalid";
    private static final boolean RETURN_DOCUMENTS = true;
    private static final int TOP_N = 2;

    public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
        final var parsed = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of()));
        assertThat(parsed, is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS));
    }

    public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
        // An unrelated key must be ignored and yield the empty-settings singleton.
        final var parsed = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model")));
        assertThat(parsed, is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS));
    }

    public void testFromMap_ReturnsReturnDocuments() {
        final var parsed = AzureAiStudioRerankRequestTaskSettings.fromMap(
            new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, RETURN_DOCUMENTS))
        );
        assertThat(parsed, is(new AzureAiStudioRerankRequestTaskSettings(RETURN_DOCUMENTS, null)));
    }

    public void testFromMap_ReturnsTopN() {
        final var parsed = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_N_FIELD, TOP_N)));
        assertThat(parsed, is(new AzureAiStudioRerankRequestTaskSettings(null, TOP_N)));
    }

    public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
        assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
    }

    public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
        assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
    }

    /**
     * Asserts that parsing a map whose {@code field} entry is a string (the wrong type)
     * raises a {@link ValidationException} mentioning the field and the offending value.
     */
    private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
        final var thrownException = expectThrows(
            ValidationException.class,
            () -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
        );

        MatcherAssert.assertThat(
            thrownException.getMessage(),
            containsString(
                Strings.format(
                    "field ["
                        + field
                        + "] is not of the expected type. The value ["
                        + INVALID_FIELD_TYPE_STRING
                        + "] cannot be converted to a "
                )
            )
        );
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType.TOKEN;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider.COHERE;
import static org.hamcrest.Matchers.is;

/**
 * Unit and wire-serialization tests for {@link AzureAiStudioRerankServiceSettings}.
 */
public class AzureAiStudioRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankServiceSettings> {

    private static final String TARGET_URI = "http://testtarget.local";

    // Expected JSON for settings populated with target/provider/endpoint_type and a
    // rate limit of 3 requests per minute; both full and filtered XContent emit it.
    private static final String ALL_FIELDS_JSON = "{\"target\":\"http://testtarget.local\",\"provider\":\"cohere\","
        + "\"endpoint_type\":\"token\",\"rate_limit\":{\"requests_per_minute\":3}}";

    public void testFromMap_Request_CreatesSettingsCorrectly() {
        final var parsed = AzureAiStudioRerankServiceSettings.fromMap(
            createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
            ConfigurationParseContext.REQUEST
        );

        assertThat(parsed, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
    }

    public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() {
        final var settingsMap = createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name());
        settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)));

        final var parsed = AzureAiStudioRerankServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST);

        assertThat(parsed, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3))));
    }

    public void testFromMap_Persistent_CreatesSettingsCorrectly() {
        final var parsed = AzureAiStudioRerankServiceSettings.fromMap(
            createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
            ConfigurationParseContext.PERSISTENT
        );

        assertThat(parsed, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
    }

    public void testToXContent_WritesAllValues() throws IOException {
        final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));

        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        settings.toXContent(builder, null);

        assertThat(Strings.toString(builder), CoreMatchers.is(ALL_FIELDS_JSON));
    }

    public void testToFilteredXContent_WritesAllValues() throws IOException {
        // The filtered view hides nothing for rerank service settings.
        final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));

        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        settings.getFilteredXContentObject().toXContent(builder, null);

        assertThat(Strings.toString(builder), CoreMatchers.is(ALL_FIELDS_JSON));
    }

    /** Builds the minimal request-style settings map shared by the parsing tests. */
    public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
        return new HashMap<>(Map.of(TARGET_FIELD, target, PROVIDER_FIELD, provider, ENDPOINT_TYPE_FIELD, endpointType));
    }

    @Override
    protected Writeable.Reader<AzureAiStudioRerankServiceSettings> instanceReader() {
        return AzureAiStudioRerankServiceSettings::new;
    }

    @Override
    protected AzureAiStudioRerankServiceSettings createTestInstance() {
        return createRandom();
    }

    @Override
    protected AzureAiStudioRerankServiceSettings mutateInstance(AzureAiStudioRerankServiceSettings instance) throws IOException {
        return randomValueOtherThan(instance, AzureAiStudioRerankServiceSettingsTests::createRandom);
    }

    @Override
    protected AzureAiStudioRerankServiceSettings mutateInstanceForVersion(
        AzureAiStudioRerankServiceSettings instance,
        TransportVersion version
    ) {
        // No BWC transformations are required for this settings class.
        return instance;
    }

    private static AzureAiStudioRerankServiceSettings createRandom() {
        return new AzureAiStudioRerankServiceSettings(
            randomAlphaOfLength(10),
            randomFrom(AzureAiStudioProvider.values()),
            randomFrom(AzureAiStudioEndpointType.values()),
            RateLimitSettingsTests.createRandom()
        );
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.hamcrest.MatcherAssert;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

/**
 * Unit and wire-serialization tests for {@link AzureAiStudioRerankTaskSettings}
 * (persistent task settings: {@code return_documents} and {@code top_n}).
 */
public class AzureAiStudioRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankTaskSettings> {
    private static final String INVALID_FIELD_TYPE_STRING = "invalid";

    public void testIsEmpty() {
        final var randomSettings = createRandom();
        final var stringRep = Strings.toString(randomSettings);
        // Settings are empty exactly when their XContent rendering is an empty object.
        assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
    }

    public void testUpdatedTaskSettings_WithAllValues() {
        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
        final AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
        assertUpdateSettings(newSettings, initialSettings);
    }

    public void testUpdatedTaskSettings_WithReturnDocumentsValue() {
        // Only return_documents is provided in the update; top_n must be retained.
        // (Previously this test was identical to the all-values case.)
        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
        final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(randomBoolean(), null);
        assertUpdateSettings(newSettings, initialSettings);
    }

    public void testUpdatedTaskSettings_WithTopNValue() {
        // Only top_n is provided in the update; return_documents must be retained.
        // (Previously this test was identical to the all-values case.)
        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
        final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(null, randomNonNegativeInt());
        assertUpdateSettings(newSettings, initialSettings);
    }

    public void testUpdatedTaskSettings_WithNoValues() {
        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
        final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(null, null);
        assertUpdateSettings(newSettings, initialSettings);
    }

    /**
     * Applies {@code newSettings} (rendered as an update map) on top of {@code initialSettings}
     * and checks that present fields override while absent fields fall back to the initial value.
     */
    private void assertUpdateSettings(AzureAiStudioRerankTaskSettings newSettings, AzureAiStudioRerankTaskSettings initialSettings) {
        final var settingsMap = new HashMap<String, Object>();
        if (newSettings.returnDocuments() != null) {
            settingsMap.put(RETURN_DOCUMENTS_FIELD, newSettings.returnDocuments());
        }
        if (newSettings.topN() != null) {
            settingsMap.put(TOP_N_FIELD, newSettings.topN());
        }

        final AzureAiStudioRerankTaskSettings updatedSettings = (AzureAiStudioRerankTaskSettings) initialSettings.updatedTaskSettings(
            Collections.unmodifiableMap(settingsMap)
        );
        assertEquals(
            newSettings.returnDocuments() == null ? initialSettings.returnDocuments() : newSettings.returnDocuments(),
            updatedSettings.returnDocuments()
        );
        assertEquals(newSettings.topN() == null ? initialSettings.topN() : newSettings.topN(), updatedSettings.topN());
    }

    public void testFromMap_AllValues() {
        assertEquals(new AzureAiStudioRerankTaskSettings(true, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2)));
    }

    public void testFromMap_ReturnDocuments() {
        assertEquals(
            new AzureAiStudioRerankTaskSettings(true, null),
            AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null))
        );
    }

    public void testFromMap_TopN() {
        assertEquals(new AzureAiStudioRerankTaskSettings(null, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2)));
    }

    public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
        // Previously built (and discarded) an unused settings map here - dead code removed.
        assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
    }

    public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
        assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
    }

    public void testFromMap_WithNoValues_DoesNotThrowException() {
        final var taskMap = AzureAiStudioRerankTaskSettings.fromMap(new HashMap<>(Map.of()));
        assertNull(taskMap.returnDocuments());
        assertNull(taskMap.topN());
    }

    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2));
        final var overrideSettings = AzureAiStudioRerankTaskSettings.of(settings, AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS);
        MatcherAssert.assertThat(overrideSettings, is(settings));
    }

    public void testOverrideWith_UsesReturnDocumentsOverride() {
        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null));
        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(false, null));
        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(false, null)));
    }

    public void testOverrideWith_UsesTopNOverride() {
        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2));
        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(null, 1));
        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(null, 1)));
    }

    public void testOverrideWith_UsesAllParametersOverride() {
        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(false, 2));
        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(true, 1));
        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(true, 1)));
    }

    public void testToXContent_WithoutParameters() throws IOException {
        assertThat(getXContentResult(null, null), is("{}"));
    }

    public void testToXContent_WithReturnDocumentsParameter() throws IOException {
        assertThat(getXContentResult(true, null), is("""
            {"return_documents":true}"""));
    }

    public void testToXContent_WithTopNParameter() throws IOException {
        assertThat(getXContentResult(null, 2), is("""
            {"top_n":2}"""));
    }

    public void testToXContent_WithParameters() throws IOException {
        assertThat(getXContentResult(true, 2), is("""
            {"return_documents":true,"top_n":2}"""));
    }

    /** Renders settings built from the given values to a JSON string. */
    private String getXContentResult(Boolean returnDocuments, Integer topN) throws IOException {
        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(returnDocuments, topN));
        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        settings.toXContent(builder, null);
        return Strings.toString(builder);
    }

    /** Builds a task-settings map containing only the non-null values. Shared with sibling test classes. */
    public static Map<String, Object> getTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
        final var map = new HashMap<String, Object>();

        if (returnDocuments != null) {
            map.put(RETURN_DOCUMENTS_FIELD, returnDocuments);
        }

        if (topN != null) {
            map.put(TOP_N_FIELD, topN);
        }

        return map;
    }

    @Override
    protected Writeable.Reader<AzureAiStudioRerankTaskSettings> instanceReader() {
        return AzureAiStudioRerankTaskSettings::new;
    }

    @Override
    protected AzureAiStudioRerankTaskSettings createTestInstance() {
        return createRandom();
    }

    @Override
    protected AzureAiStudioRerankTaskSettings mutateInstance(AzureAiStudioRerankTaskSettings instance) throws IOException {
        return randomValueOtherThan(instance, AzureAiStudioRerankTaskSettingsTests::createRandom);
    }

    @Override
    protected AzureAiStudioRerankTaskSettings mutateInstanceForVersion(AzureAiStudioRerankTaskSettings instance, TransportVersion version) {
        // No BWC transformations are required for this settings class.
        return instance;
    }

    private static AzureAiStudioRerankTaskSettings createRandom() {
        return new AzureAiStudioRerankTaskSettings(
            randomFrom(new Boolean[] { null, randomBoolean() }),
            randomFrom(new Integer[] { null, randomNonNegativeInt() })
        );
    }

    /** Creates random settings guaranteed to differ from {@code settings} in both fields. */
    private static AzureAiStudioRerankTaskSettings createRandom(AzureAiStudioRerankTaskSettings settings) {
        return new AzureAiStudioRerankTaskSettings(
            randomValueOtherThan(settings.returnDocuments(), () -> randomFrom(new Boolean[] { null, randomBoolean() })),
            randomValueOtherThan(settings.topN(), () -> randomFrom(new Integer[] { null, randomNonNegativeInt() }))
        );
    }

    /**
     * Asserts that a string value for {@code field} is rejected by the class under test.
     * Fixed: previously delegated to {@code AzureAiStudioRerankRequestTaskSettings.fromMap},
     * duplicating AzureAiStudioRerankRequestTaskSettingsTests instead of exercising
     * {@link AzureAiStudioRerankTaskSettings#fromMap}.
     */
    private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
        final var thrownException = expectThrows(
            ValidationException.class,
            () -> AzureAiStudioRerankTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
        );

        MatcherAssert.assertThat(
            thrownException.getMessage(),
            containsString(
                Strings.format(
                    "field ["
                        + field
                        + "] is not of the expected type. The value ["
                        + INVALID_FIELD_TYPE_STRING
                        + "] cannot be converted to a "
                )
            )
        );
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureaistudio.response;

import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;

/**
 * Tests parsing of Azure AI Studio rerank responses into {@link RankedDocsResults},
 * both with and without the echoed document texts.
 */
public class AzureAiStudioRerankResponseEntityTests extends ESTestCase {

    public void testResponse_WithDocuments() throws IOException {
        final var rankedDocs = getParsedResults(getResponseJsonWithDocuments()).getRankedDocs();

        // Document text is carried through when the service echoes it back.
        final var expected = List.of(
            new RankedDocsResults.RankedDoc(0, 0.1111111F, "test text one"),
            new RankedDocsResults.RankedDoc(1, 0.2222222F, "test text two")
        );
        assertThat(rankedDocs, is(expected));
    }

    public void testResponse_NoDocuments() throws IOException {
        final var rankedDocs = getParsedResults(getResponseJsonNoDocuments()).getRankedDocs();

        // Without a "document" object the text field of each ranked doc is null.
        final var expected = List.of(
            new RankedDocsResults.RankedDoc(0, 0.1111111F, null),
            new RankedDocsResults.RankedDoc(1, 0.2222222F, null)
        );
        assertThat(rankedDocs, is(expected));
    }

    /** Runs the response entity over the raw JSON bytes and casts to the concrete result type. */
    private RankedDocsResults getParsedResults(String responseJson) throws IOException {
        final var httpResult = new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8));
        return (RankedDocsResults) new AzureAiStudioRerankResponseEntity().apply(mock(Request.class), httpResult);
    }

    private String getResponseJsonWithDocuments() {
        return """
            {
                "id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
                "results": [
                    {
                        "document": {
                            "text": "test text one"
                        },
                        "index": 0,
                        "relevance_score": 0.1111111
                    },
                    {
                        "document": {
                            "text": "test text two"
                        },
                        "index": 1,
                        "relevance_score": 0.2222222
                    }
                ],
                "meta": {
                    "api_version": {
                        "version": "1"
                    },
                    "billed_units": {
                        "search_units": 1
                    }
                }
            }
            """;
    }

    private String getResponseJsonNoDocuments() {
        return """
            {
                "id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
                "results": [
                    {
                        "index": 0,
                        "relevance_score": 0.1111111
                    },
                    {
                        "index": 1,
                        "relevance_score": 0.2222222
                    }
                ],
                "meta": {
                    "api_version": {
                        "version": "1"
                    },
                    "billed_units": {
                        "search_units": 1
                    }
                }
            }
            """;
    }
}
"relevance_score": 0.1111111 + }, + { + "document": { + "text": "test text two" + }, + "index": 1, + "relevance_score": 0.2222222 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + } + + private String getResponseJsonNoDocuments() { + return """ + { + "id": "222e59de-c712-40cb-ae87-ecd402d0d2f1", + "results": [ + { + "index": 0, + "relevance_score": 0.1111111 + }, + { + "index": 1, + "relevance_score": 0.2222222 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index de2e9ae9a21b8..f3d65c5589169 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -752,7 +752,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -785,7 +785,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep public void testInfer_SendsRequest() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { 
+ try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -844,7 +844,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues(); assertThrows( ElasticsearchStatusException.class, @@ -864,7 +864,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureOpenAiEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -891,7 +891,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -952,7 +952,7 @@ 
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createCompletionModel( "resource", "deployment", @@ -1209,14 +1209,18 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AzureOpenAiService createAzureOpenAiService() { - return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureOpenAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 52e4f904a4de0..8f189baa33b20 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -779,7 +779,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new CohereService(factory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -812,7 +812,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -886,7 +886,7 @@ public void testInfer_SendsRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), 
randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -906,7 +906,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(CohereEmbeddingType.values()); var model = CohereEmbeddingsModelTests.createModel( @@ -933,7 +933,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -975,7 +975,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1051,7 +1051,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1125,7 +1125,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1200,7 +1200,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1297,7 +1297,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String 
responseJson = """ @@ -1387,7 +1387,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1507,7 +1507,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1591,7 +1591,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { + try (var service = new CohereService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1632,7 +1632,7 @@ private Map getRequestConfigMap(Map serviceSetti } private CohereService createCohereService() { - return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), 
mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 88d26d5d7eef1..6438a328f9fcf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -209,7 +209,10 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false)) + ); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 78b8b7bdeaf3e..6c5128956fc9b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -132,7 +132,10 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); + assertThat( + requestMap, + 
is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "model", "stream", false)) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java index 2fb51ca8ca457..6003a58bf0340 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -46,7 +46,10 @@ public void testCreateRequest() throws IOException { assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false))); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "abc")), "model", "required model id", "stream", false)) + ); } public void testDefaultUrl() { @@ -88,6 +91,6 @@ public void testXContents() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, CoreMatchers.is(""" - {"message":"some input","model":"model","stream":false}""")); + {"messages":[{"role":"user","content":"some input"}],"model":"model","stream":false}""")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java index 4a60dc5033e22..5d7a76a26e597 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereCompletionResponseEntityTests.java @@ -64,6 +64,42 @@ public void testFromResponse_CreatesResponseEntityForText() throws IOException { assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); } + public void testFromResponseV2() throws IOException { + String responseJson = """ + { + "id": "abc123", + "finish_reason": "COMPLETE", + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Response from the llm" + } + ] + }, + "usage": { + "billed_units": { + "input_tokens": 1, + "output_tokens": 4 + }, + "tokens": { + "input_tokens": 2, + "output_tokens": 5 + } + } + } + """; + + ChatCompletionResults chatCompletionResults = CohereCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("Response from the llm")); + } + public void testFromResponse_FailsWhenTextIsNotPresent() { String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index cc1bb4471c0a9..a707030a34189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -53,6 +53,7 @@ import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static 
org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -148,7 +149,7 @@ private static void assertCompletionModel(Model model) { public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + return new CustomService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static Map createServiceSettingsMap(TaskType taskType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index af38ee38e1eff..908451b8e681f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -360,7 +360,8 @@ public void testDoChunkedInferAlwaysFails() throws IOException { private DeepSeekService createService() { return new DeepSeekService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6ce484954d3ce..94d1e064648ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1427,7 +1427,8 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1456,7 +1457,8 @@ private ElasticInferenceService createService( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1469,7 +1471,8 @@ private ElasticInferenceService createServiceWithAuthHandler( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java index 4ec575420613f..269d64893fe34 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java @@ -7,16 +7,24 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; public class ElasticsearchInternalServiceSettingsTests extends AbstractWireSerializingTestCase { @@ -138,4 +146,80 @@ public void testFromMapInvalidSettings() { assertThat(e.getMessage(), containsString("Invalid value [0]. [num_allocations] must be a positive integer")); assertThat(e.getMessage(), containsString("Invalid value [-1]. [num_threads] must be a positive integer")); } + + public void testUpdateNumAllocations() { + var testInstance = createTestInstance(); + var expectedNumAllocations = testInstance.getNumAllocations() != null ? 
testInstance.getNumAllocations() + 1 : 1; + var updatedInstance = testInstance.updateServiceSettings( + Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, expectedNumAllocations) + ); + + assertThat("update should create a new instance", updatedInstance, not(equalTo(testInstance))); + assertThat(updatedInstance.getNumAllocations(), equalTo(expectedNumAllocations)); + assertThat(updatedInstance.getAdaptiveAllocationsSettings(), nullValue()); + assertThat(updatedInstance.getNumThreads(), equalTo(testInstance.getNumThreads())); + assertThat(updatedInstance.getDeploymentId(), equalTo(testInstance.getDeploymentId())); + assertThat(updatedInstance.modelId(), equalTo(testInstance.modelId())); + + } + + public void testUpdateAdaptiveAllocations() throws IOException { + var testInstance = createTestInstance(); + var expectedAdaptiveAllocations = adaptiveAllocationSettings(testInstance.getAdaptiveAllocationsSettings()); + var updatedInstance = testInstance.updateServiceSettings( + Map.of(ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS, toMap(expectedAdaptiveAllocations)) + ); + + assertThat("update should create a new instance", updatedInstance, not(equalTo(testInstance))); + assertThat(updatedInstance.getNumAllocations(), nullValue()); + assertThat(updatedInstance.getAdaptiveAllocationsSettings(), equalTo(expectedAdaptiveAllocations)); + assertThat(updatedInstance.getNumThreads(), equalTo(testInstance.getNumThreads())); + assertThat(updatedInstance.getDeploymentId(), equalTo(testInstance.getDeploymentId())); + assertThat(updatedInstance.modelId(), equalTo(testInstance.modelId())); + } + + private static AdaptiveAllocationsSettings adaptiveAllocationSettings(AdaptiveAllocationsSettings base) { + if (base == null) { + base = new AdaptiveAllocationsSettings(true, 0, 1); + } else { + base = new AdaptiveAllocationsSettings(true, base.getMinNumberOfAllocations() + 1, base.getMaxNumberOfAllocations() + 1); + } + return base; + } + + private static Map 
toMap(AdaptiveAllocationsSettings adaptiveAllocationsSettings) throws IOException { + try (var builder = JsonXContent.contentBuilder()) { + adaptiveAllocationsSettings.toXContent(builder, ToXContent.EMPTY_PARAMS); + var bytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8); + try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, bytes)) { + return parser.map(); + } + } + } + + public void testUpdateNumAllocationsAndAdaptiveAllocations() { + var validationException = assertThrows(ValidationException.class, () -> { + createTestInstance().updateServiceSettings( + Map.ofEntries( + Map.entry(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1), + Map.entry(ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS, toMap(adaptiveAllocationSettings(null))) + ) + ); + }); + assertThat( + validationException.getMessage(), + equalTo("Validation Failed: 1: [num_allocations] cannot be set if [adaptive_allocations] is set;") + ); + } + + public void testUpdateWithNoNumAllocationsAndAdaptiveAllocations() { + var validationException = assertThrows(ValidationException.class, () -> createTestInstance().updateServiceSettings(Map.of())); + assertThat( + validationException.getMessage(), + equalTo( + "Validation Failed: 1: [service_settings] does not contain one of the required settings " + + "[num_allocations, adaptive_allocations];" + ) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 27709c2067a26..6c01145701d92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.support.ActionTestUtils; @@ -37,6 +38,8 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -57,10 +60,12 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; @@ -98,6 +103,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; 
import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -108,10 +114,13 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -122,12 +131,16 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { - String randomInferenceEntityId = randomAlphaOfLength(10); + private String randomInferenceEntityId; + private InferenceStats inferenceStats; private static ThreadPool threadPool; @Before - public void setUpThreadPool() { + public void setUp() throws Exception { + super.setUp(); + randomInferenceEntityId = randomAlphaOfLength(10); + inferenceStats = InferenceStatsTests.mockInferenceStats(); threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); } @@ -1767,7 +1780,9 @@ private void testUpdateModelsWithDynamicFields(Map> modelsBy modelsByDeploymentId.forEach((deploymentId, models) -> { var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId); models.forEach(model -> { - verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations); + verify((ElasticsearchInternalModel) 
model).updateServiceSettings(assertArg(assignmentStats -> { + assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations)); + })); verify((ElasticsearchInternalModel) model).mlNodeDeploymentId(); verifyNoMoreInteractions(model); }); @@ -1809,7 +1824,8 @@ public void testUpdateWithoutMlEnabled() throws IOException, InterruptedExceptio mock(), threadPool, cs, - Settings.builder().put("xpack.ml.enabled", false).build() + Settings.builder().put("xpack.ml.enabled", false).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { var models = List.of(mock(Model.class)); @@ -1851,19 +1867,97 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { client, threadPool, cs, - Settings.builder().put("xpack.ml.enabled", true).build() + Settings.builder().put("xpack.ml.enabled", true).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { List models = List.of(model); var latch = new CountDownLatch(1); service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); assertTrue(latch.await(30, TimeUnit.SECONDS)); - verify(model).updateNumAllocations(3); + verify(model).updateServiceSettings( + assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); }) + ); } } public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { - var model = new ElserInternalModel( + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + 
assertThat(exception.getMessage(), is("failed")); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("504")); + })); + } + } + + public void testStart_OnFailure_WhenDeploymentTimeoutOccurs() throws IOException { + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchTimeoutException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ModelDeploymentTimeoutException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + assertThat( + exception.getMessage(), + is( + "Timed out after [30s] waiting for trained model deployment for inference endpoint [inference_id] to start. " + + "The inference endpoint can not be used to perform inference until the deployment has started. " + + "Use the trained model stats API to track the state of the deployment." 
+ ) + ); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("408")); + })); + } + } + + public void testStart() throws IOException { + var model = mockModel(); + + var client = mockClientForStart(listener -> { + var response = mock(CreateTrainedModelAssignmentAction.Response.class); + when(response.getTrainedModelAssignment()).thenReturn(TrainedModelAssignmentTests.randomInstance()); + listener.onResponse(response); + }); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + assertTrue(actionListener.actionGet(TimeValue.timeValueSeconds(30))); + + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertNull(attributes.get("error.type")); + assertThat(attributes.get("status_code"), is(200)); + })); + } + } + + private ElserInternalModel mockModel() { + return new ElserInternalModel( "inference_id", TaskType.SPARSE_EMBEDDING, "elasticsearch", @@ -1873,7 +1967,9 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { new ElserMlNodeTaskSettings(), null ); + } + private Client mockClientForStart(Consumer> startModelListener) { var client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -1889,27 +1985,18 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(2); - listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)); + startModelListener.accept(listener); return Void.TYPE; }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any()); - try (var service = createService(client)) { - var actionListener = new PlainActionFuture(); - service.start(model, 
TimeValue.timeValueSeconds(30), actionListener); - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) - ); - - assertThat(exception.getMessage(), is("failed")); - } + return client; } private ElasticsearchInternalService createService(Client client) { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); when(cs.getClusterSettings()).thenReturn(cSettings); - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY, inferenceStats); return new ElasticsearchInternalService(context); } @@ -1918,7 +2005,8 @@ private ElasticsearchInternalService createService(Client client, BaseElasticsea client, threadPool, mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + inferenceStats ); return new ElasticsearchInternalService(context, l -> l.onResponse(modelVariant)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java index 5b21717ac03e4..3fee80b1fbe5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java @@ -7,8 +7,17 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import 
org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class ElserInternalModelTests extends ESTestCase { public void testUpdateNumAllocation() { @@ -21,10 +30,22 @@ public void testUpdateNumAllocation() { null ); - model.updateNumAllocations(1); - assertEquals(1, model.getServiceSettings().getNumAllocations().intValue()); + AssignmentStats assignmentStats = mock(); + when(assignmentStats.getNumberOfAllocations()).thenReturn(1); + model.updateServiceSettings(assignmentStats); + + assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1)); + assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings()); - model.updateNumAllocations(null); - assertNull(model.getServiceSettings().getNumAllocations()); + TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance(); + CreateTrainedModelAssignmentAction.Response response = mock(); + when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment); + model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response); + + assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1)); + assertThat( + model.getServiceSettings().getAdaptiveAllocationsSettings(), + equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings()) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 41175581df1cf..435ea9de5911b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -658,7 +658,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -696,7 +696,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD var model = GoogleAiStudioEmbeddingsModelTests.createModel("model", getUrl(webServer), "secret"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -730,7 +730,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "candidates": [ @@ -818,7 +818,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = 
new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -897,7 +897,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -998,7 +998,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1033,7 +1033,7 @@ public void testInfer_ResourceNotFound() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1052,7 +1052,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel 
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = GoogleAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1171,6 +1171,10 @@ private Map getRequestConfigMap( } private GoogleAiStudioService createGoogleAiStudioService() { - return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new GoogleAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 99a09b983787d..26fd076e72462 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -1043,7 +1043,7 @@ public void testGetConfiguration() throws Exception { private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 3be4b72c1237f..2cdf3f5263751 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -29,6 +29,7 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; import static org.mockito.Mockito.mock; @@ -92,7 +93,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep private static final class TestService extends HuggingFaceService { TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 814d533129439..93156d4331263 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -81,7 +81,7 @@ public void shutdown() throws IOException { public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -137,7 +137,8 @@ public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ) ) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index e2850910ac64a..c770672c5d5f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -258,7 
+258,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -328,7 +328,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -357,7 +357,7 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -486,7 +486,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var 
senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -548,7 +548,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -621,7 +621,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1009,7 +1009,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, 
createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1060,7 +1060,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1087,7 +1087,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { public void testInfer_SendsElserRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1139,7 +1139,7 @@ public void testInfer_SendsElserRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1158,7 +1158,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void 
testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = HuggingFaceEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1179,7 +1179,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1233,7 +1233,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1340,7 +1340,11 @@ public void testGetConfiguration() throws Exception { } private HuggingFaceService createHuggingFaceService() { - return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new HuggingFaceService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map 
getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 3295ecfd4ece5..ddc62b5a412b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -597,7 +597,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -635,7 +635,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer)); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1018,12 +1018,12 @@ private Map getRequestConfigMap( } private IbmWatsonxService createIbmWatsonxService() { - return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { 
IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index eca76bc1a702a..d36c574e0aa99 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -778,7 +778,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -819,7 +819,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var model = JinaAIEmbeddingsModelTests.createModel( @@ -846,7 +846,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void 
testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -889,7 +889,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -923,7 +923,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -994,7 +994,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_Embedding_Get_Response_Search() throws 
IOException { public void testInfer_Embedding_Get_Response_clustering() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, @@ -1120,7 +1120,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1210,7 +1210,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1295,7 +1295,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new 
JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1392,7 +1392,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1475,7 +1475,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1540,7 +1540,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new 
JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1637,7 +1637,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1800,7 +1800,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1841,7 +1841,7 @@ private Map getRequestConfigMap(Map serviceSetti } private JinaAIService createJinaAIService() { - return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java new file mode 100644 index 0000000000000..dd68c43f5e62d --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -0,0 +1,840 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import 
org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import 
java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class LlamaServiceTests extends AbstractInferenceServiceTests { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer 
webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public LlamaServiceTests() { + super(createTestConfiguration()); + } + + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) { + + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return LlamaServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return LlamaServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return LlamaServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType) { + LlamaServiceTests.assertModel(model, taskType); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION -> assertCompletionModel(model); + case CHAT_COMPLETION -> assertChatCompletionModel(model); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model) { + var llamaModel = assertCommonModelFields(model); + + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); + } + + private static LlamaModel assertCommonModelFields(Model model) { + 
assertThat(model, instanceOf(LlamaModel.class)); + + var llamaModel = (LlamaModel) model; + assertThat(llamaModel.getServiceSettings().modelId(), is("model_id")); + assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com")); + assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); + assertThat( + ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), + Matchers.is(new SecureString("secret".toCharArray())) + ); + + return llamaModel; + } + + private static void assertCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); + } + + private static void assertChatCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType taskType) { + Map settingsMap = new HashMap<>( + Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id") + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ); + } + + return settingsMap; + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of("api_key", "secret")); + } + + private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { + var inferenceId = "inference_id"; + + return new LlamaEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + 
LlamaService.NAME, + new LlamaEmbeddingsServiceSettings( + "model_id", + "http://www.abc.com", + 1536, + similarityMeasure, + 512, + new RateLimitSettings(10_000) + ), + ChunkingSettingsTests.createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } + + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "Failed to parse stored model [id] for [llama] service, please delete and add the service again"; + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var 
service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException { + var url = "url"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(LlamaChatCompletionModel.class)); + + var chatCompletionModel = (LlamaChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutUrl() throws 
IOException { + var model = "model"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(LlamaChatCompletionModel.class)); + + var chatCompletionModel = (LlamaChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(model)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = createChatCompletionModel("model", getUrl(webServer), 
"secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", + "choices": [{ + "delta": { + "content": "Deep", + "role": "assistant" + }, + "index": 0 + } + ], + "model": "llama3.2:3b", + "object": "chat.completion.chunk" + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new 
RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [id] status \ + [404]. Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "llama_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "stream_error", + "message": "Received an error response for request from inference entity id [id].\ + Error message: [400: Invalid value: Model 'llama3.12:3b' not found]", + "type": "llama_error" + } + } + """)); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + private void testStreamError(String 
expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "detail": "Not Found" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); + assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """ + Resource not found at [%s] for request from inference entity id [id] status [404]. 
Error message: [{ + "detail": "Not Found" + }]""", getUrl(webServer)))); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new LlamaService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service") + ); + } + ); + + 
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModelWithChunkingSettings("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "embeddings": [ + [ + 0.010060793, + -0.0017529363 + ], + [ + 0.110060793, + -0.1017529363 + ] + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.010060793f, -0.0017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + { + 
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.110060793f, -0.1017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("contents"), Matchers.is(List.of("abc", "def"))); + assertThat(requestMap.get("model_id"), Matchers.is("id")); + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "llama", + "name": "Llama", + "task_types": ["text_embedding", "completion", "chat_completion"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "model_id": { + "description": "Refer to the Llama models documentation for the list of available models.", + "label": "Model", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + 
"rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "url": { + "description": "The URL endpoint to use for the requests.", + "label": "URL", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private LlamaService createService() { + return new LlamaService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), 
mockClusterServiceEmpty()); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private static Map getEmbeddingsServiceSettingsMap() { + return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java new file mode 100644 index 0000000000000..366e0926f0daa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; 
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class LlamaActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "embeddings": [ + [ + -0.0123, + 0.123 + ] + ] + { + """; + webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + [ + { + "embeddings": [ + [ + -0.0123, + 0.123 + ] + ] + { + ] + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello there, how may I assist you today?", + "refusal": null, + "role": "assistant", + "annotations": null, + "audio": null, + "function_call": null, + 
"tool_calls": null + } + } + ], + "created": 1750157476, + "model": "llama3.2:3b", + "object": "chat.completion", + "service_tier": null, + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 30, + "total_tokens": 40, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertCompletionRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Llama completion request from inference entity id [id]. 
Cause: Required [choices]") + ); + + assertCompletionRequest(); + } + } + + private PlainActionFuture createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + return listener; + } + + private PlainActionFuture createCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertCompletionRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + + @SuppressWarnings("unchecked") + private void assertEmbeddingsRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("contents"), instanceOf(List.class)); + var inputList = (List) requestMap.get("contents"); + assertThat(inputList, 
contains("abc")); + } + + private void assertCommonRequestProperties() { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java new file mode 100644 index 0000000000000..844d17addac6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionModelTests extends ESTestCase { + + public static LlamaChatCompletionModel createCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModelNoAuth(String modelId, String url) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + EmptySecretSettings.INSTANCE + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "model_name", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + 
assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel(null, "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel(null, "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", 
null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..c9b6069d383ed --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LlamaChatCompletionResponseHandlerTests extends ESTestCase { + private final LlamaChatCompletionResponseHandler responseHandler = new LlamaChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \ + status [404]. 
Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "llama_error" + } + }"""))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "error": { + "detail": { + "errors": [{ + "loc": [ + "body", + "messages" + ], + "msg": "Field required", + "type": "missing" + } + ] + } + } + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a bad request status code for request from inference entity id [id] status [400].\ + Error message: [{\\"error\\":{\\"detail\\":{\\"errors\\":[{\\"loc\\":[\\"body\\",\\"messages\\"],\\"msg\\":\\"Field\ + required\\",\\"type\\":\\"missing\\"}]}}}]", + "type": "llama_error" + } + } + """))); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson, 500); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \ + [what? 
this isn't a json\\n]", + "type": "llama_error" + } + } + """))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions")); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..21b42453d9c39 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final 
int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT)))); + } + + public void testFromMap_MissingModelId_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, 
is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return LlamaChatCompletionServiceSettings::new; + } + + @Override + protected LlamaChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstance(LlamaChatCompletionServiceSettings instance) throws IOException { + return 
randomValueOtherThan(instance, LlamaChatCompletionServiceSettingsTests::createRandom); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstanceForVersion( + LlamaChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static LlamaChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + return new LlamaChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model, String url) { + var map = new HashMap(); + + map.put(ServiceFields.MODEL_ID, model); + map.put(ServiceFields.URL, url); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java new file mode 100644 index 0000000000000..4e75cab196a6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; + +public class LlamaEmbeddingsModelTests extends ESTestCase { + public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelWithChunkingSettings(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, String url) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + EmptySecretSettings.INSTANCE + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..5fd3ce704540c --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java @@ -0,0 +1,479 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final String MODEL_ID = "some model"; + private static final String CORRECT_URL = "https://www.elastic.co"; + private static final int DIMENSIONS = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE = 
SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS = 128; + private static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_NoModelId_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + null, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_NoUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + null, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_EmptyUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> 
LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;") + ); + } + + public void testFromMap_InvalidUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "^^^", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. " + + "Error: unable to parse url [^^^]. 
Reason: Illegal character in path;" + ) + ); + } + + public void testFromMap_NoSimilarity_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + null, + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + null, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + "by_size", + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. 
" + + "[similarity] must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testFromMap_NoDimensions_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + null, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + null, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + 0, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NegativeDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + -10, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NoInputTokens_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + null, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + null, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + 0, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NegativeInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + -10, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NoRateLimit_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3000) + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "dimensions": 384, + "similarity": "dot_product", + "max_input_tokens": 128, + "rate_limit": { + "requests_per_minute": 3 + } + } + """))); + } + + public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { + var outputBuffer = new BytesStreamOutput(); + var settings = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + settings.writeTo(outputBuffer); + + var outputBufferRef = outputBuffer.bytes(); + var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); + + var settingsFromBuffer = new LlamaEmbeddingsServiceSettings(inputBuffer); + + assertEquals(settings, settingsFromBuffer); + } + + @Override + protected Writeable.Reader instanceReader() { + return LlamaEmbeddingsServiceSettings::new; + } + + @Override + protected LlamaEmbeddingsServiceSettings 
createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaEmbeddingsServiceSettings mutateInstance(LlamaEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, LlamaEmbeddingsServiceSettingsTests::createRandom); + } + + private static LlamaEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + var similarityMeasure = randomFrom(SimilarityMeasure.values()); + var dimensions = randomIntBetween(32, 256); + var maxInputTokens = randomIntBetween(128, 256); + return new LlamaEmbeddingsServiceSettings( + modelId, + url, + dimensions, + similarityMeasure, + maxInputTokens, + RateLimitSettingsTests.createRandom() + ); + } + + public static HashMap buildServiceSettingsMap( + @Nullable String modelId, + @Nullable String url, + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable HashMap rateLimitSettings + ) { + HashMap result = new HashMap<>(); + if (modelId != null) { + result.put(ServiceFields.MODEL_ID, modelId); + } + if (url != null) { + result.put(ServiceFields.URL, url); + } + if (similarity != null) { + result.put(ServiceFields.SIMILARITY, similarity); + } + if (dimensions != null) { + result.put(ServiceFields.DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (rateLimitSettings != null) { + result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); + } + return result; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..dd8b3d7dfa38c --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.ArrayList; + +public class LlamaChatCompletionRequestEntityTests extends ESTestCase { + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key"); + + LlamaChatCompletionRequestEntity entity = new 
LlamaChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String expectedJson = """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java new file mode 100644 index 0000000000000..6f0701a810fb1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequest("model", "url", "secret", input, true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(true)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + } + + public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequestWithNoAuth("model", "url", input, false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) 
httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(false)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertNull(requestMap.get("stream_options")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_DoesNotReduceInputTextSize() { + String input = randomAlphaOfLength(5); + var request = createRequest("model", "url", "secret", input, true); + assertThat(request.truncate(), is(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true); + assertNull(request.getTruncationInfo()); + } + + public static LlamaChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { + var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModel(modelId, url, apiKey); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + + public static LlamaChatCompletionRequest createRequestWithNoAuth(String modelId, String url, String input, boolean stream) { + var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModelNoAuth(modelId, url); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..a055a0870e30d --- 
/dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class LlamaEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_Success() throws IOException { + var entity = new LlamaEmbeddingsRequestEntity("llama-embed", List.of("ABDC")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "model_id": "llama-embed", + "contents": ["ABDC"] + } + """))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..ab24fa9a0bc56 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsRequestTests extends ESTestCase { + + public void testCreateRequest_WithAuth_Success() throws IOException { + var request = createRequest(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + } + + public void testCreateRequest_NoAuth_Success() throws IOException { + var request = createRequestNoAuth(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + 
assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest(); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("AB"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest(); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + return httpPost; + } + + private static LlamaEmbeddingsRequest createRequest() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey"); + return new LlamaEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + + private static LlamaEmbeddingsRequest createRequestNoAuth() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModelNoAuth("llama-embed", "url"); + return new LlamaEmbeddingsRequest( + 
TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java new file mode 100644 index 0000000000000..aa3c6f6c20b6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +public class LlamaErrorResponseTests extends ESTestCase { + + public static final String ERROR_RESPONSE_JSON = """ + { + "error": "A valid user token is required" + } + """; + + public void testFromResponse() { + var errorResponse = LlamaErrorResponse.fromResponse( + new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 4ba9b8aa24394..8e170b25393e6 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -249,7 +249,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -308,7 +308,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -353,7 +353,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new 
CountDownLatch(1); service.unifiedCompletionInfer( @@ -421,7 +421,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -459,7 +459,7 @@ public void testInfer_StreamRequest_ErrorResponse() { } public void testSupportsStreaming() throws IOException { - try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) { + try (var service = new MistralService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -942,7 +942,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = new Model(ModelConfigurationsTests.createRandomInstance()); assertThrows( @@ -962,7 +962,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = MistralEmbeddingModelTests.createModel( randomAlphaOfLength(10), @@ -990,7 +990,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1028,7 +1028,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1086,7 +1086,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1173,7 +1173,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { public void 
testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1276,7 +1276,7 @@ public void testGetConfiguration() throws Exception { // ---------------------------------------------------------------- private MistralService createService() { - return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java index 6f8b40fd7f19c..9aa076e224efe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -37,7 +36,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, 
rateLimitSettings), - EmptyTaskSettings.INSTANCE, chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -57,7 +55,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), - EmptyTaskSettings.INSTANCE, null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java index 4a70861932d28..2c8fb4fd48698 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java @@ -49,7 +49,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(4)); - // We do not truncate for Hugging Face chat completions + // We do not truncate for Mistral chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c19eb664e88ac..83455861198d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -847,7 +847,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -885,7 +885,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -924,7 +924,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -965,7 +965,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new 
PlainActionFuture<>(); service.infer( mockModel, @@ -1003,7 +1003,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1099,7 +1099,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1132,7 +1132,7 @@ public void testUnifiedCompletionError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -1189,7 +1189,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { private void 
testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1267,7 +1267,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1344,7 +1344,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1400,7 +1400,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = 
new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1485,7 +1485,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // response with 2 embeddings String responseJson = """ @@ -1656,6 +1656,6 @@ public void testGetConfiguration() throws Exception { } private OpenAiService createOpenAiService() { - return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index d7d9473f18084..bf883a6345398 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -47,6 +47,7 @@ import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS; import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; +import static 
org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -84,7 +85,7 @@ public void init() { ThreadPool threadPool = mock(); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of); + sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty()); } public void testSupportedTaskTypes() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java index adffbb366fb02..90b5042d3dec4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java @@ -119,7 +119,7 @@ public final void testWithUnknownApiTaskSettings() { } } - public final void testUpdate() throws IOException { + public void testUpdate() throws IOException { var taskSettings = randomApiTaskSettings(); if (taskSettings != SageMakerStoredTaskSchema.NO_OP) { var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java index 
65dcd62bb149a..9e4cfc52e9568 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java @@ -18,8 +18,8 @@ import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; +import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -50,6 +50,7 @@ protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) { return model; } + @Override public void testApiTaskSettings() { { var validationException = new ValidationException(); @@ -67,14 +68,21 @@ public void testApiTaskSettings() { var validationException = new ValidationException(); var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException); assertTrue(actualApiTaskSettings.isEmpty()); - assertFalse(validationException.validationErrors().isEmpty()); - assertThat( - validationException.validationErrors().get(0), - is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint.")) - ); + assertTrue(validationException.validationErrors().isEmpty()); } } + @Override + public void testUpdate() { + var taskSettings = randomApiTaskSettings(); + var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings); + var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings))); + assertThat( + e.getMessage(), + containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint") + ); + } + public void testRequestWithRequiredFields() 
throws Exception { var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED); var sdkByes = payload.requestBytes(mockModel(), request); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8602621e9eb78..72a3b530ab647 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -718,7 +718,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -763,7 +763,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept "voyage-3-large" ); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -806,7 +806,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + 
try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = VoyageAIEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -831,7 +831,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -873,7 +873,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -907,7 +907,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -989,7 +989,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1071,7 +1071,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1163,7 +1163,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1251,7 +1251,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); 
var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1345,7 +1345,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1423,7 +1423,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1490,7 +1490,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, 
createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1599,7 +1599,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1745,7 +1745,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1786,7 +1786,7 @@ private Map getRequestConfigMap(Map serviceSetti } private VoyageAIService createVoyageAIService() { - return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/45_semantic_text_match.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/45_semantic_text_match.yml index 28093ba49e6cc..3898eb7de7c29 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/45_semantic_text_match.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/45_semantic_text_match.yml @@ -277,3 
+277,126 @@ setup: query: "inference test" - match: { hits.total.value: 0 } + +--- +"Apply boost and query name on single index": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name for semantic text match queries. + + - skip: + features: [ "headers", "close_to" ] + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: [ "It was a beautiful game", "Very competitive" ] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + match: + inference_field: + query: "soccer" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 5.700229E18, error: 1e15 } } + - not_exists: hits.hits.0.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + match: + inference_field: + query: "soccer" + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 2.8501142E19, error: 1e16 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } + +--- +"Apply boost and query name on multiple indices": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name for semantic text match queries. 
+ + - skip: + features: [ "headers", "close_to" ] + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: [ "It was a beautiful game", "Very competitive" ] + non_inference_field: "non inference test" + refresh: true + + - do: + index: + index: test-text-only-index + id: doc_2 + body: + inference_field: [ "It was a beautiful game", "Very competitive" ] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index,test-text-only-index + body: + query: + match: + inference_field: + query: "beautiful" + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + - close_to: { hits.hits.0._score: { value: 1.1140361E19, error: 1e16 } } + - not_exists: hits.hits.0.matched_queries + - close_to: { hits.hits.1._score: { value: 0.2876821, error: 1e-4 } } + - not_exists: hits.hits.1.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index,test-text-only-index + body: + query: + match: + inference_field: + query: "beautiful" + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + - close_to: { hits.hits.0._score: { value: 5.5701804E19, error: 1e16 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } + - close_to: { hits.hits.1._score: { value: 1.4384103, error: 1e-4 } } + - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml 
b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml index f1cff512fd209..cc67b9235f0b4 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml @@ -247,3 +247,100 @@ setup: - match: { hits.total.value: 2 } +--- +"Apply boost and query name on single index": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name for semantic text sparse vector queries. + + - skip: + features: [ "headers", "close_to" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332E17, error: 1e14 } } + - not_exists: hits.hits.0.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.8918664E18, error: 1e15 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } + +--- +"Apply boost and query name on multiple indices": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name 
for semantic text sparse vector queries. + + - skip: + features: [ "headers", "close_to" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index,test-sparse-vector-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + inference_id: sparse-inference-id + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + - close_to: { hits.hits.0._score: { value: 3.7837332E17, error: 1e14 } } + - not_exists: hits.hits.0.matched_queries + - close_to: { hits.hits.1._score: { value: 7.314424E8, error: 1e5 } } + - not_exists: hits.hits.1.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index,test-sparse-vector-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + inference_id: sparse-inference-id + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } + - close_to: { hits.hits.0._score: { value: 1.8918664E18, error: 1e15 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } + - close_to: { hits.hits.1._score: { value: 3.657212E9, error: 1e6 } } + - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml index 64ecb0f2d882c..d49e3a63848e3 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml +++ 
b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml @@ -404,4 +404,116 @@ setup: - match: { hits.total.value: 4 } +--- +"Apply boost and query name on single index": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name for semantic text knn queries. + + - skip: + features: [ "headers", "close_to" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + knn: + field: inference_field + k: 2 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-4 } } + - not_exists: hits.hits.0.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + knn: + field: inference_field + k: 2 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 4.9952416, error: 1e-3 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } +--- +"Apply boost and query name on multiple indices": + - requires: + cluster_features: "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix" + reason: fix boosting and query name for semantic text knn queries. 
+ + - skip: + features: [ "headers", "close_to" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index,test-dense-vector-index + body: + query: + knn: + field: inference_field + k: 2 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + model_id: dense-inference-id + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } + - close_to: { hits.hits.0._score: { value: 0.9990483, error: 1e-4 } } + - not_exists: hits.hits.0.matched_queries + - close_to: { hits.hits.1._score: { value: 0.9439374, error: 1e-4 } } + - not_exists: hits.hits.1.matched_queries + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index,test-dense-vector-index + body: + query: + knn: + field: inference_field + k: 2 + num_candidates: 100 + query_vector_builder: + text_embedding: + model_text: test + model_id: dense-inference-id + boost: 5.0 + _name: i-like-naming-my-queries + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } + - close_to: { hits.hits.0._score: { value: 4.9952416, error: 1e-3 } } + - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] } + - close_to: { hits.hits.1._score: { value: 4.719687, error: 1e-3 } } + - match: { hits.hits.1.matched_queries: [ "i-like-naming-my-queries" ] } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml index 021dfe320d78e..60dea800ca624 100644 --- 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml @@ -35,6 +35,23 @@ setup: } } + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-compatible-with-bbq + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 64, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-sparse-index @@ -70,7 +87,7 @@ setup: id: doc_1 body: title: "Elasticsearch" - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] refresh: true - do: @@ -89,14 +106,14 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: - another_body: {} + another_body: { } - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } - - not_exists: hits.hits.0.highlight.another_body + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - not_exists: hits.hits.0.highlight.another_body --- "Highlighting using a sparse embedding model": @@ -114,10 +131,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -133,11 +150,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" 
} - do: search: @@ -154,10 +171,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -196,10 +213,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -215,11 +232,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" 
} - do: search: @@ -236,10 +253,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -256,17 +273,17 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } --- "Default highlighter for fields": - requires: - cluster_features: "semantic_text.highlighter.default" - reason: semantic text field defaults to the semantic highlighter + cluster_features: "semantic_text.highlighter.default" + reason: semantic text field defaults to the semantic highlighter - do: search: @@ -281,11 +298,11 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" 
} - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } --- "semantic highlighter ignores non-inference fields": @@ -306,8 +323,8 @@ setup: type: semantic number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - not_exists: hits.hits.0.highlight.title --- @@ -333,7 +350,7 @@ setup: index: test-multi-chunk-index id: doc_1 body: - semantic_text_field: ["some test data", " ", "now with chunks"] + semantic_text_field: [ "some test data", " ", "now with chunks" ] refresh: true - do: @@ -367,25 +384,25 @@ setup: index: test-sparse-index body: query: - match_all: {} + match_all: { } highlight: fields: body: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" 
} - do: search: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: body: @@ -432,18 +449,18 @@ setup: index: test-index-sparse body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.semantic_text_field: 2 } - - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } + - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } - do: indices.create: @@ -473,7 +490,7 @@ setup: index: test-index-dense body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: @@ -485,3 +502,172 @@ setup: - length: { hits.hits.0.highlight.semantic_text_field: 2 } - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + +--- +"Highlighting with flat quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-flat + body: + settings: + index.mapping.semantic_text.use_legacy_format: false + mappings: + properties: + flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: flat + int4_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_flat + int8_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_flat + bbq_flat_field: + type: semantic_text + inference_id: 
dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_flat + + + - do: + index: + index: test-dense-index-flat + id: doc_1 + body: + flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int8_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-flat + body: + query: + match_all: { } + highlight: + fields: + flat_field: + type: "semantic" + number_of_fragments: 1 + int4_flat_field: + type: "semantic" + number_of_fragments: 1 + int8_flat_field: + type: "semantic" + number_of_fragments: 1 + bbq_flat_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.flat_field: 1 } + - match: { hits.hits.0.highlight.flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_flat_field: 1 } + - match: { hits.hits.0.highlight.int4_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + - length: { hits.hits.0.highlight.int8_flat_field: 1 } + - match: { hits.hits.0.highlight.int8_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_flat_field: 1 } + - match: { hits.hits.0.highlight.bbq_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + +--- +"Highlighting with HNSW quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-hnsw + body: + settings: + index.mapping.semantic_text.use_legacy_format: false + mappings: + properties: + hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: hnsw + int4_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_hnsw + int8_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_hnsw + bbq_hnsw_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_hnsw + + + - do: + index: + index: test-dense-index-hnsw + id: doc_1 + body: + hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + int8_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-hnsw + body: + query: + match_all: { } + highlight: + fields: + hnsw_field: + type: "semantic" + number_of_fragments: 1 + int4_hnsw_field: + type: "semantic" + number_of_fragments: 1 + int8_hnsw_field: + type: "semantic" + number_of_fragments: 1 + bbq_hnsw_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.hnsw_field: 1 } + - match: { hits.hits.0.highlight.hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int4_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int8_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_hnsw_field: 1 } + - match: { hits.hits.0.highlight.bbq_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml index 1e874d60a016c..4675977842973 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml @@ -35,6 +35,23 @@ setup: } } + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-compatible-with-bbq + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 64, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-sparse-index @@ -65,12 +82,12 @@ setup: --- "Highlighting empty field": - do: - index: - index: test-dense-index - id: doc_1 - body: - body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] - refresh: true + index: + index: test-dense-index + id: doc_1 + body: + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + refresh: true - match: { result: created } @@ -79,14 +96,14 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: - another_body: {} + another_body: { } - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } - - not_exists: hits.hits.0.highlight.another_body + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - not_exists: hits.hits.0.highlight.another_body --- "Highlighting using a sparse embedding model": @@ -95,7 +112,7 @@ setup: index: test-sparse-index id: doc_1 body: - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] refresh: true - match: { result: created } @@ -114,10 +131,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} - do: search: @@ -133,11 +150,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -154,10 +171,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -187,7 +204,7 @@ setup: index: test-dense-index id: doc_1 body: - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] refresh: true - match: { result: created } @@ -206,10 +223,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -225,11 +242,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -246,10 +263,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" 
} - do: search: @@ -266,11 +283,11 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } --- "Highlighting and multi chunks with empty input": @@ -295,7 +312,7 @@ setup: index: test-multi-chunk-index id: doc_1 body: - semantic_text_field: ["some test data", " ", "now with chunks"] + semantic_text_field: [ "some test data", " ", "now with chunks" ] refresh: true - do: @@ -337,18 +354,18 @@ setup: index: test-sparse-index body: query: - match_all: {} + match_all: { } highlight: fields: body: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" 
} - do: index: @@ -363,7 +380,7 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: body: @@ -410,18 +427,18 @@ setup: index: test-index-sparse body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.semantic_text_field: 2 } - - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } + - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } - do: indices.create: @@ -451,7 +468,7 @@ setup: index: test-index-dense body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: @@ -464,3 +481,173 @@ setup: - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } +--- +"Highlighting with flat quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-flat + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: flat + int4_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_flat + int8_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_flat + bbq_flat_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + 
index_options: + dense_vector: + type: bbq_flat + + + - do: + index: + index: test-dense-index-flat + id: doc_1 + body: + flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int8_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-flat + body: + query: + match_all: { } + highlight: + fields: + flat_field: + type: "semantic" + number_of_fragments: 1 + int4_flat_field: + type: "semantic" + number_of_fragments: 1 + int8_flat_field: + type: "semantic" + number_of_fragments: 1 + bbq_flat_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.flat_field: 1 } + - match: { hits.hits.0.highlight.flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_flat_field: 1 } + - match: { hits.hits.0.highlight.int4_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + - length: { hits.hits.0.highlight.int8_flat_field: 1 } + - match: { hits.hits.0.highlight.int8_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_flat_field: 1 } + - match: { hits.hits.0.highlight.bbq_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + +--- +"Highlighting with HNSW quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-hnsw + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: hnsw + int4_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_hnsw + int8_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_hnsw + bbq_hnsw_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_hnsw + + + - do: + index: + index: test-dense-index-hnsw + id: doc_1 + body: + hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + int8_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-hnsw + body: + query: + match_all: { } + highlight: + fields: + hnsw_field: + type: "semantic" + number_of_fragments: 1 + int4_hnsw_field: + type: "semantic" + number_of_fragments: 1 + int8_hnsw_field: + type: "semantic" + number_of_fragments: 1 + bbq_hnsw_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.hnsw_field: 1 } + - match: { hits.hits.0.highlight.hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int4_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int8_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_hnsw_field: 1 } + - match: { hits.hits.0.highlight.bbq_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + + + diff --git a/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java b/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java index 5eafb858eacbe..3658313642700 100644 --- a/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java +++ b/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java @@ -571,20 +571,24 @@ public String toString() { } @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - try (var builder = factory.aggregateMetricDoubleBuilder(docs.count())) { - copyDoubleValuesToBuilder(docs, builder.min(), minValues); - copyDoubleValuesToBuilder(docs, builder.max(), maxValues); - copyDoubleValuesToBuilder(docs, builder.sum(), sumValues); - copyIntValuesToBuilder(docs, builder.count(), valueCountValues); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (var builder = factory.aggregateMetricDoubleBuilder(docs.count() - offset)) { + copyDoubleValuesToBuilder(docs, offset, builder.min(), minValues); + copyDoubleValuesToBuilder(docs, offset, builder.max(), maxValues); + copyDoubleValuesToBuilder(docs, offset, builder.sum(), sumValues); + copyIntValuesToBuilder(docs, offset, builder.count(), valueCountValues); return builder.build(); } } - private void copyDoubleValuesToBuilder(Docs docs, BlockLoader.DoubleBuilder builder, NumericDocValues values) - throws IOException { + private void copyDoubleValuesToBuilder( + Docs docs, + int offset, + BlockLoader.DoubleBuilder builder, + NumericDocValues values + ) throws IOException { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < 
lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -600,10 +604,10 @@ private void copyDoubleValuesToBuilder(Docs docs, BlockLoader.DoubleBuilder buil } } - private void copyIntValuesToBuilder(Docs docs, BlockLoader.IntBuilder builder, NumericDocValues values) + private void copyIntValuesToBuilder(Docs docs, int offset, BlockLoader.IntBuilder builder, NumericDocValues values) throws IOException { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); diff --git a/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java b/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java index c0c2db53b97e9..1a94ca1b8d40a 100644 --- a/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java +++ b/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java @@ -276,7 +276,7 @@ public FieldNamesFieldMapper.FieldNamesFieldType fieldNames() { iw.close(); try (DirectoryReader reader = DirectoryReader.open(directory)) { TestBlock block = (TestBlock) loader.columnAtATimeReader(reader.leaves().get(0)) - .read(TestBlock.factory(reader.numDocs()), new BlockLoader.Docs() { + .read(TestBlock.factory(), new BlockLoader.Docs() { @Override public int count() { return 1; @@ -286,7 +286,7 @@ public int count() { public int get(int i) { return 0; } - }); + }, 0); assertThat(block.get(0), nullValue()); } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java 
b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java index e15a1d36bdb9f..dad16d3cfa83b 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java @@ -9,10 +9,12 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.indices.SystemIndices; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTaskParams; @@ -86,10 +88,11 @@ protected AllocatedPersistentTask createTask( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( SystemIndexMigrationTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { // This should select from master-eligible nodes because we already require all master-eligible nodes to have all plugins installed. // However, due to a misunderstanding, this code as-written needs to run on the master node in particular. 
This is not a fundamental diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index a15a733cac6c7..be905caeacba0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -30,6 +31,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.injection.guice.Inject; @@ -690,10 +692,11 @@ protected AllocatedPersistentTask createTask( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TaskParams params, Collection candidateNodes, - @SuppressWarnings("HiddenField") ClusterState clusterState + @SuppressWarnings("HiddenField") ClusterState clusterState, + @Nullable ProjectId projectId ) { boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); Optional optionalAssignment = getPotentialAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java index f45c92d3466c6..7a636e18017e1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -494,10 +495,11 @@ public StartDatafeedPersistentTasksExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( StartDatafeedAction.DatafeedParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { return new DatafeedNodeSelector( clusterState, @@ -510,7 +512,7 @@ public PersistentTasksCustomMetadata.Assignment getAssignment( } @Override - public void validate(StartDatafeedAction.DatafeedParams params, ClusterState clusterState) { + public void validate(StartDatafeedAction.DatafeedParams params, ClusterState clusterState, @Nullable ProjectId projectId) { new DatafeedNodeSelector( clusterState, resolver, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java index 9862c82fcec6f..04a6396f5d088 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java @@ -9,19 +9,27 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.client.internal.ParentTaskAssigningClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -30,7 +38,7 @@ import java.util.Objects; -import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; public class TransportUpdateTrainedModelDeploymentAction extends TransportMasterNodeAction< UpdateTrainedModelDeploymentAction.Request, @@ -41,6 +49,7 @@ public class TransportUpdateTrainedModelDeploymentAction extends 
TransportMaster private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService; private final InferenceAuditor auditor; private final ProjectResolver projectResolver; + private final Client client; @Inject public TransportUpdateTrainedModelDeploymentAction( @@ -50,7 +59,8 @@ public TransportUpdateTrainedModelDeploymentAction( ActionFilters actionFilters, TrainedModelAssignmentClusterService trainedModelAssignmentClusterService, InferenceAuditor auditor, - ProjectResolver projectResolver + ProjectResolver projectResolver, + Client client ) { super( UpdateTrainedModelDeploymentAction.NAME, @@ -65,6 +75,7 @@ public TransportUpdateTrainedModelDeploymentAction( this.trainedModelAssignmentClusterService = Objects.requireNonNull(trainedModelAssignmentClusterService); this.auditor = Objects.requireNonNull(auditor); this.projectResolver = Objects.requireNonNull(projectResolver); + this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN); } @Override @@ -75,25 +86,70 @@ protected void masterOperation( ActionListener listener ) throws Exception { logger.debug( - () -> format( - "[%s] received request to update number of allocations to [%s]", - request.getDeploymentId(), - request.getNumberOfAllocations() - ) + "[{}] received request to update number of allocations to [{}]", + request.getDeploymentId(), + request.getNumberOfAllocations() ); + checkIfUsedByDefaultInferenceEndpoint(task, request, listener.delegateFailureAndWrap((l, unused) -> updateDeployment(request, l))); + } + + private void checkIfUsedByDefaultInferenceEndpoint( + Task task, + UpdateTrainedModelDeploymentAction.Request request, + ActionListener listener + ) { + if (request.isInternal()) { + listener.onResponse(null); + return; + } + + var deploymentId = request.getDeploymentId(); + var parentClient = new ParentTaskAssigningClient(client, clusterService.localNode(), task); + var getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY); + if 
(request.ackTimeout() != null) { + getAllEndpoints.ackTimeout(request.ackTimeout()); + } + + // if this deployment was created by an inference endpoint, then it must be updated by the inference endpoint _update API + parentClient.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> { + response.getEndpoints() + .stream() + .filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser")) + .map(ModelConfigurations::getInferenceEntityId) + .filter(deploymentId::equals) + .findAny() + .ifPresentOrElse( + endpointId -> l.onFailure( + new ElasticsearchStatusException( + "Cannot update deployment [{}] as it was created by inference endpoint [{}]. " + + "This model deployment must be updated through the inference API.", + RestStatus.CONFLICT, + deploymentId, + endpointId + ) + ), + () -> l.onResponse(null) + ); + })); + } + + private void updateDeployment( + UpdateTrainedModelDeploymentAction.Request request, + ActionListener listener + ) { trainedModelAssignmentClusterService.updateDeployment( request.getDeploymentId(), request.getNumberOfAllocations(), request.getAdaptiveAllocationsSettings(), request.isInternal(), - ActionListener.wrap(updatedAssignment -> { + listener.delegateFailureAndWrap((l, updatedAssignment) -> { auditor.info( request.getDeploymentId(), Messages.getMessage(Messages.INFERENCE_DEPLOYMENT_UPDATED_NUMBER_OF_ALLOCATIONS, request.getNumberOfAllocations()) ); listener.onResponse(new CreateTrainedModelAssignmentAction.Response(updatedAssignment)); - }, listener::onFailure) + }) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java index 5686f3734f36e..47e20d1a56bf2 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java @@ -162,6 +162,13 @@ public BytesRef[] getKeyTokens() { return Arrays.stream(keyTokenIndexes).mapToObj(index -> baseTokens[index]).toArray(BytesRef[]::new); } + public String getKeyTokensString() { + return Arrays.stream(keyTokenIndexes) + .mapToObj(index -> baseTokens[index]) + .map(BytesRef::utf8ToString) + .collect(Collectors.joining(" ")); + } + public String getRegex() { if (keyTokenIndexes.length == 0 || orderedCommonTokenBeginIndex == orderedCommonTokenEndIndex) { return ".*"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java index 42f722e330a19..00370dde3e089 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java @@ -15,9 +15,11 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTaskState; @@ -88,10 +90,11 @@ public SnapshotUpgradeTaskExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + 
protected PersistentTasksCustomMetadata.Assignment doGetAssignment( SnapshotUpgradeTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); Optional optionalAssignment = getPotentialAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java index 0e517b63f6f60..5621da489da7d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java @@ -17,9 +17,11 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.license.XPackLicenseState; @@ -121,7 +123,12 @@ public OpenJobPersistentTasksExecutor( } @Override - public Assignment getAssignment(OpenJobAction.JobParams params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + OpenJobAction.JobParams params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { Job job = params.getJob(); // If the task parameters do not have a job field then the job // was first opened on a pre v6.6 node and has not been migrated @@ -210,13 +217,13 @@ static void 
validateJobAndId(String jobId, Job job) { } @Override - public void validate(OpenJobAction.JobParams params, ClusterState clusterState) { + public void validate(OpenJobAction.JobParams params, ClusterState clusterState, @Nullable ProjectId projectId) { final Job job = params.getJob(); final String jobId = params.getJobId(); validateJobAndId(jobId, job); // If we already know that we can't find an ml node because all ml nodes are running at capacity or // simply because there are no ml nodes in the cluster then we fail quickly here: - PersistentTasksCustomMetadata.Assignment assignment = getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + var assignment = getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, projectId); if (assignment.equals(AWAITING_UPGRADE)) { throw makeCurrentlyBeingUpgradedException(logger, params.getJobId()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java index 33fae40f80db6..550352954bfbc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -62,7 +63,7 @@ public void testGetAssignment_UpgradeModeIsEnabled() { .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build())) .build(); - 
Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat(assignment.getExplanation(), is(equalTo("persistent task cannot be assigned while upgrade mode is enabled."))); } @@ -75,7 +76,7 @@ public void testGetAssignment_NoNodes() { .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat(assignment.getExplanation(), is(emptyString())); } @@ -94,7 +95,7 @@ public void testGetAssignment_NoMlNodes() { ) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat( assignment.getExplanation(), @@ -116,7 +117,7 @@ public void testGetAssignment_MlNodeIsNewerThanTheMlJobButTheAssignmentSuceeds() .nodes(DiscoveryNodes.builder().add(createNode(0, true, Version.V_7_10_0, MlConfigVersion.V_7_10_0))) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(equalTo("_node_id0"))); assertThat(assignment.getExplanation(), is(emptyString())); } diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java index d88e1235241d8..4b1ed557ef287 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.OperationRouting; @@ -173,7 +174,7 @@ public void testGetAssignment_GivenUnavailableIndicesWithLazyNode() { assertEquals( "Not opening [unavailable_index_with_lazy_node], " + "because not all primary shards are active for the following indices [.ml-state]", - executor.getAssignment(params, csBuilder.nodes().getAllNodes(), csBuilder.build()).getExplanation() + executor.getAssignment(params, csBuilder.nodes().getAllNodes(), csBuilder.build(), ProjectId.DEFAULT).getExplanation() ); } @@ -195,7 +196,8 @@ public void testGetAssignment_GivenLazyJobAndNoGlobalLazyNodes() { PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment( params, csBuilder.nodes().getAllNodes(), - csBuilder.build() + csBuilder.build(), + ProjectId.DEFAULT ); assertNotNull(assignment); assertNull(assignment.getExecutorNode()); @@ -216,7 +218,8 @@ public void testGetAssignment_GivenResetInProgress() { PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment( params, csBuilder.nodes().getAllNodes(), - csBuilder.build() + csBuilder.build(), + ProjectId.DEFAULT ); 
assertNotNull(assignment); assertNull(assignment.getExecutorNode()); diff --git a/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml b/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml new file mode 100644 index 0000000000000..0d3b1b25a7ea9 --- /dev/null +++ b/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml @@ -0,0 +1,8 @@ +version: ${xpack.oteldata.template.version} +_meta: + description: Default settings for all OpenTelemetry data streams + managed: true +template: + data_stream_options: + failure_store: + enabled: true diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml index 6772ec5bc65d4..929d26e1c30af 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml @@ -11,6 +11,7 @@ composed_of: - logs@mappings - logs@settings - otel@mappings + - otel@settings - logs-otel@mappings - semconv-resource-to-ecs@mappings - logs@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml index f8489605ad1bf..a042fc77e6fa3 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml 
b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml index f5033135120bc..60739559cc9eb 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml index 9168062f30bfb..9464936f5e1e5 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml index 47c2d7d014322..888a2145073fd 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml 
b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml index c9438e8c27402..36be8cb78d851 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml index b29caa3fe34a7..20d1e3ca65e88 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml index 4cab3e41a1dfa..9bb62ae9edd3b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml 
b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml index 037f3546205d6..ff4780744e216 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml index 303ac2c406fd0..b1037535754f3 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml index ea42079ced4dd..15088a2198abc 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git 
a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml index 81e70cc3361fc..2f6f7e28ffc22 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml index c54b90bf8b683..5cc1828d3285b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml index 8afe8b87951c0..906e535e2c05b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git 
a/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml index 370b9351c16f5..c2e9a68bc72ad 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml @@ -11,6 +11,7 @@ composed_of: - traces@mappings - traces@settings - otel@mappings + - otel@settings - traces-otel@mappings - semconv-resource-to-ecs@mappings - traces@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/resources.yaml b/x-pack/plugin/otel-data/src/main/resources/resources.yaml index 6aadfde1683dc..608dc369c34eb 100644 --- a/x-pack/plugin/otel-data/src/main/resources/resources.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/resources.yaml @@ -1,10 +1,11 @@ # "version" holds the version of the templates and ingest pipelines installed # by xpack-plugin otel-data. This must be increased whenever an existing template is # changed, in order for it to be updated on Elasticsearch upgrade. -version: 9 +version: 10 component-templates: - otel@mappings + - otel@settings - logs-otel@mappings - semconv-resource-to-ecs@mappings - metrics-otel@mappings diff --git a/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml b/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml new file mode 100644 index 0000000000000..dfc6d0fc050b0 --- /dev/null +++ b/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml @@ -0,0 +1,73 @@ +--- +setup: + - do: + cluster.health: + wait_for_events: languid +--- +teardown: + - do: + indices.delete_data_stream: + name: logs-generic.otel-default + ignore: 404 +--- +"Test logs-*.otel-* data streams have failure store enabled by default": + # Index a valid document (string message). 
+ - do: + index: + index: logs-generic.otel-default + refresh: true + body: + '@timestamp': '2023-01-01T12:00:00Z' + severity_text: "INFO" + text: "Application started successfully" + - match: { result: created } + + # Assert empty failure store. + - do: + indices.get_data_stream: + name: logs-generic.otel-default + - match: { data_streams.0.name: logs-generic.otel-default } + - length: { data_streams.0.indices: 1 } + - match: { data_streams.0.failure_store.enabled: true } + - length: { data_streams.0.failure_store.indices: 0 } + + # Index a document with naming alias, causing an error. + - do: + index: + index: logs-generic.otel-default + refresh: true + body: + '@timestamp': '2023-01-01T12:01:00Z' + severity_text: "ERROR" + message: "Application started successfully" + - match: { result: 'created' } + - match: { failure_store: used} + + # Assert failure store containing 1 item. + - do: + indices.get_data_stream: + name: logs-generic.otel-default + - length: { data_streams.0.failure_store.indices: 1 } + + # Assert valid document. + - do: + search: + index: logs-generic.otel-default::data + body: + query: + match_all: {} + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.severity_text: "INFO" } + - match: { hits.hits.0._source.text: "Application started successfully" } + + # Assert invalid document. 
+ - do: + search: + index: logs-generic.otel-default::failures + body: + query: + match_all: {} + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.document.source.severity_text: "ERROR" } + - match: { hits.hits.0._source.document.source.message: "Application started successfully" } + - match: { hits.hits.0._source.error.type: "document_parsing_exception" } diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponse.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponse.java index 96b434914ca75..3af4a4bc1ec12 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponse.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponse.java @@ -32,6 +32,7 @@ public class GetStackTracesResponse extends ActionResponse implements ChunkedToX private final Map stackTraceEvents; private final int totalFrames; private final double samplingRate; + private final double samplingFrequency; private final long totalSamples; public GetStackTracesResponse( @@ -42,6 +43,28 @@ public GetStackTracesResponse( int totalFrames, double samplingRate, long totalSamples + ) { + this( + stackTraces, + stackFrames, + executables, + stackTraceEvents, + totalFrames, + samplingRate, + totalSamples, + TransportGetStackTracesAction.DEFAULT_SAMPLING_FREQUENCY + ); + } + + public GetStackTracesResponse( + Map stackTraces, + Map stackFrames, + Map executables, + Map stackTraceEvents, + int totalFrames, + double samplingRate, + long totalSamples, + double samplingFrequency ) { this.stackTraces = stackTraces; this.stackFrames = stackFrames; @@ -50,6 +73,7 @@ public GetStackTracesResponse( this.totalFrames = totalFrames; this.samplingRate = samplingRate; this.totalSamples = totalSamples; + this.samplingFrequency = samplingFrequency; } @Override @@ -101,7 +125,7 @@ public Iterator 
toXContentChunked(ToXContent.Params params Iterators.map(v.entrySet().iterator(), e -> (b, p) -> b.field(e.getKey().stacktraceID(), e.getValue().count)) ) ), - Iterators.single((b, p) -> b.field("sampling_rate", samplingRate).endObject()) + Iterators.single((b, p) -> b.field("sampling_rate", samplingRate).field("sampling_frequency", samplingFrequency).endObject()) // the following fields are intentionally not written to the XContent representation (only needed on the transport layer): // // * start @@ -129,6 +153,7 @@ public boolean equals(Object o) { GetStackTracesResponse response = (GetStackTracesResponse) o; return totalFrames == response.totalFrames && samplingRate == response.samplingRate + && samplingFrequency == response.samplingFrequency && Objects.equals(stackTraces, response.stackTraces) && Objects.equals(stackFrames, response.stackFrames) && Objects.equals(executables, response.executables) @@ -137,6 +162,6 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(stackTraces, stackFrames, executables, stackTraceEvents, totalFrames, samplingRate); + return Objects.hash(stackTraces, stackFrames, executables, stackTraceEvents, totalFrames, samplingRate, samplingFrequency); } } diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java index d014fe72ebe37..4ddbd5f32f854 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java @@ -82,6 +82,14 @@ public double getSamplingRate() { return samplingRate; } + public void setSamplingFrequency(long samplingFrequency) { + this.samplingFrequency = samplingFrequency; + } + + public long getSamplingFrequency() { + 
return samplingFrequency; + } + public void setRequestedDuration(Double requestedDuration) { this.requestedDuration = requestedDuration; } @@ -154,14 +162,15 @@ public GetStackTracesResponse build() { } } } - return new GetStackTracesResponse(stackTraces, stackFrames, executables, stackTraceEvents, totalFrames, samplingRate, totalSamples); - } - - public void setSamplingFrequency(long samplingFrequency) { - this.samplingFrequency = samplingFrequency; - } - - public long getSamplingFrequency() { - return samplingFrequency; + return new GetStackTracesResponse( + stackTraces, + stackFrames, + executables, + stackTraceEvents, + totalFrames, + samplingRate, + totalSamples, + samplingFrequency + ); } } diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java index bc554302ce8e7..852cea9c5e054 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java @@ -400,10 +400,10 @@ private void searchEventGroupedByStackTrace( StopWatch watch = new StopWatch("createStackTraceEvents"); SingleBucketAggregation sample = searchResponse.getAggregations().get("sample"); InternalComposite stacktraces = sample.getAggregations().get("group_by"); + RandomGenerator rng = new Random(rngSeed); long indexSamplingFactor = Math.round(1 / eventsIndex.getSampleRate()); // for example, 5^2 = 25 for profiling-events-5pow02 int bucketCount = stacktraces.getBuckets().size(); long eventCount = sample.getDocCount(); - AtomicLong upscaledEventCount = new AtomicLong(eventCount * indexSamplingFactor); long maxSamplingFrequency = getAggValueAsLong(searchResponse, "max_freq") <= 0 ? 
(long) DEFAULT_SAMPLING_FREQUENCY : getAggValueAsLong(searchResponse, "max_freq"); @@ -413,13 +413,35 @@ private void searchEventGroupedByStackTrace( eventCount, bucketCount, indexSamplingFactor, - upscaledEventCount.get() + eventCount * indexSamplingFactor ); + // Since the random sampler aggregation does not support sampling rates between 0.5 and 1.0, + // we can have up to 2x more events in the response as requested by the user. + // In order to reduce latency for stacktrace and stackframe lookups, we add another sampling factor + // to reduce the number of events to match the user request (which reduces the number of unique stacktrace ids). + boolean needAdditionalDownsampling = eventCount > request.getSampleSize(); + double downSamplingRate = needAdditionalDownsampling + ? (double) request.getSampleSize() / eventCount + : responseBuilder.getSamplingRate(); + + eventCount = 0; boolean mixedFrequency = false; Map stackTraceEvents = new HashMap<>(bucketCount); for (InternalComposite.InternalBucket stacktraceBucket : stacktraces.getBuckets()) { - long count = stacktraceBucket.getDocCount() * indexSamplingFactor; + long sampledCount; + if (needAdditionalDownsampling) { + sampledCount = downsampleEvents(rng, downSamplingRate, stacktraceBucket.getDocCount()); + if (sampledCount <= 0) { + bucketCount--; + continue; // skip bucket + } + } else { + sampledCount = stacktraceBucket.getDocCount(); + } + + long count = roundWithRandom((sampledCount * indexSamplingFactor) / downSamplingRate, rng); + eventCount += sampledCount; TraceEventID eventID = getTraceEventID(stacktraceBucket); stackTraceEvents.compute(eventID, (k, event) -> { @@ -450,46 +472,52 @@ private void searchEventGroupedByStackTrace( } } + AtomicLong upscaledEventCount = new AtomicLong(eventCount * indexSamplingFactor); if (mixedFrequency) { - RandomGenerator r = new Random(rngSeed); upscaledEventCount.set(0); // Events have different frequencies. 
- // Now upscale the count values to the max sampling frequency, - // also taking into account the stratified downsampling factor (5, 25, 125, etc.). + // Scale the count up to the maximum sampling frequency. stackTraceEvents.forEach((eventID, event) -> { - if (eventID.samplingFrequency() == maxSamplingFrequency) { - upscaledEventCount.addAndGet(event.count); - return; // no need to upscale + if (eventID.samplingFrequency() != maxSamplingFrequency) { + double samplingFactor = maxSamplingFrequency / eventID.samplingFrequency(); + event.count = roundWithRandom(event.count * samplingFactor, rng); } - - // Use randomization, to avoid a systematic rounding issue that would happen - // if we naively do `event.count = Math.round(event.count * samplingFactor)`. - // For example, think of event.count = 1 and samplingFactor = 1.4: the naive approach would not change anything. - double samplingFactor = maxSamplingFrequency / eventID.samplingFrequency(); - double newCount = event.count * samplingFactor; - long integerPart = (long) newCount; - double fractionalPart = newCount - integerPart; - event.count = integerPart + (r.nextDouble() < fractionalPart ? 1 : 0); upscaledEventCount.addAndGet(event.count); }); } log.debug(watch::report); + responseBuilder.setSamplingRate(downSamplingRate); responseBuilder.setSamplingFrequency(maxSamplingFrequency); responseBuilder.setTotalSamples(upscaledEventCount.get()); - log.debug( - "Found [{}] events in [{}] buckets, upscaled to [{}] events).", - eventCount, - bucketCount, - upscaledEventCount.get() - ); + log.debug("Use [{}] events in [{}] buckets, upscaled to [{}] events).", eventCount, bucketCount, upscaledEventCount.get()); return stackTraceEvents; })); } + private static long roundWithRandom(double value, RandomGenerator r) { + // Use randomization, to avoid a systematic rounding issue that would happen + // if we naively do `Math.round(value)`. 
+ // For example, think of rounding value = 1.4: the naive approach would always drop 0.4 and return 1. + long integerPart = (long) value; + double fractionalPart = value - integerPart; + return integerPart + (r.nextDouble() < fractionalPart ? 1 : 0); + } + + private static long downsampleEvents(RandomGenerator r, double samplingRate, long count) { + // Downsampling needs to be applied to each event individually. + long sampledCount = 0; + for (long i = 0; i < count; i++) { + if (r.nextDouble() < samplingRate) { + sampledCount++; + } + } + return sampledCount; + } + private static TraceEventID getTraceEventID(InternalComposite.InternalBucket stacktraceBucket) { Map key = stacktraceBucket.getKey(); Object samplingFrequency = key.get("freq"); diff --git a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseTests.java b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseTests.java index c8342f82edb11..74dc21c335aad 100644 --- a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseTests.java +++ b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseTests.java @@ -58,7 +58,7 @@ private GetStackTracesResponse createTestInstance() { public void testChunking() { AbstractChunkedSerializingTestCase.assertChunkCount(createTestInstance(), instance -> { - // start and {sampling_rate; end}; see GetStackTracesResponse.toXContentChunked() + // start and {sampling_rate; sampling_freq; end}; see GetStackTracesResponse.toXContentChunked() int chunks = 2; chunks += size(instance.getExecutables()); chunks += size(instance.getStackFrames()); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java 
b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 1ff58f4e64078..5b22da00173ad 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -73,7 +73,6 @@ public class Constants { "cluster:admin/scripts/painless/context", "cluster:admin/scripts/painless/execute", "cluster:admin/streams/logs/toggle", - "cluster:admin/streams/status", "cluster:admin/synonyms/delete", "cluster:admin/synonyms/get", "cluster:admin/synonyms/put", @@ -373,6 +372,7 @@ public class Constants { "cluster:monitor/settings", "cluster:monitor/state", "cluster:monitor/stats", + "cluster:monitor/streams/status", "cluster:monitor/task", "cluster:monitor/task/get", "cluster:monitor/tasks/lists", diff --git a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/ssl/SslEntitlementRestIT.java b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/ssl/SslEntitlementRestIT.java index f661bb04dc3da..c34e7d24e1fe7 100644 --- a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/ssl/SslEntitlementRestIT.java +++ b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/ssl/SslEntitlementRestIT.java @@ -32,7 +32,6 @@ public class SslEntitlementRestIT extends ESRestTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .apply(SecurityOnTrialLicenseRestTestCase.commonTrialSecurityClusterConfig) .settings(settingsProvider) - .systemProperty("es.entitlements.enabled", "true") .build(); @Override diff --git 
a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java index 3d640cd962c19..5de501e42d1f2 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java @@ -69,6 +69,7 @@ * {@link SecurityIntegTestCase} due to simplicity and improved speed from not needing to start * multiple nodes and wait for the cluster to form. */ +@ESTestCase.WithoutEntitlements // requires entitlement delegation ES-12382 public abstract class SecuritySingleNodeTestCase extends ESSingleNodeTestCase { private static SecuritySettingsSource SECURITY_DEFAULT_SETTINGS = null; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java index ff8b6f5eaac39..fa16de22c865c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java @@ -240,6 +240,7 @@ protected void initChannel(Channel ch) throws Exception { SSLEngine serverEngine = sslService.createSSLEngine(configuration, null, -1); serverEngine.setUseClientMode(false); final SslHandler sslHandler = new SslHandler(serverEngine); + sslHandler.setHandshakeTimeoutMillis(configuration.handshakeTimeoutMillis()); ch.pipeline().addFirst("sslhandler", sslHandler); super.initChannel(ch); assert ch.pipeline().first() == sslHandler : "SSL handler must be first handler in pipeline"; @@ -340,6 +341,7 @@ public void connect(ChannelHandlerContext 
ctx, SocketAddress remoteAddress, Sock } final ChannelPromise connectPromise = ctx.newPromise(); final SslHandler sslHandler = new SslHandler(sslEngine); + sslHandler.setHandshakeTimeoutMillis(sslConfiguration.handshakeTimeoutMillis()); ctx.pipeline().replace(this, "ssl", sslHandler); final Future handshakePromise = sslHandler.handshakeFuture(); Netty4Utils.addListener(connectPromise, result -> { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributes.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributes.java index 320a39018d599..d1649ce209b9c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributes.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributes.java @@ -92,11 +92,14 @@ static class SamlAttribute { @Override public String toString() { + StringBuilder str = new StringBuilder(); if (Strings.isNullOrEmpty(friendlyName)) { - return name + '=' + values; + str.append(name); } else { - return friendlyName + '(' + name + ")=" + values; + str.append(friendlyName).append('(').append(name).append(')'); } + str.append("=").append(values).append("(len=").append(values.size()).append(')'); + return str.toString(); } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/QueryableBuiltInRolesSynchronizer.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/QueryableBuiltInRolesSynchronizer.java index 2c684e7e49ffd..65595f85d931b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/QueryableBuiltInRolesSynchronizer.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/QueryableBuiltInRolesSynchronizer.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.coordination.FailedToCommitClusterStateException; import 
org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.common.Priority; @@ -515,7 +516,8 @@ public Map getNewRoleDigests() { } Tuple> execute(ClusterState state) { - IndexMetadata indexMetadata = state.metadata().getProject().index(concreteSecurityIndexName); + final var project = state.metadata().getProject(); + IndexMetadata indexMetadata = project.index(concreteSecurityIndexName); if (indexMetadata == null) { throw new IndexNotFoundException(concreteSecurityIndexName); } @@ -528,10 +530,12 @@ Tuple> execute(ClusterState state) { indexMetadataBuilder.removeCustom(METADATA_QUERYABLE_BUILT_IN_ROLES_DIGEST_KEY); } indexMetadataBuilder.version(indexMetadataBuilder.version() + 1); - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(state.metadata().getProject().indices()); + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(project.indices()); builder.put(concreteSecurityIndexName, indexMetadataBuilder.build()); return new Tuple<>( - ClusterState.builder(state).metadata(Metadata.builder(state.metadata()).indices(builder.build()).build()).build(), + ClusterState.builder(state) + .putProjectMetadata(ProjectMetadata.builder(project).indices(builder.build()).build()) + .build(), newRoleDigests ); } else { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/test/SecurityIntegTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/test/SecurityIntegTestCase.java index 0b39b166bd128..5e39a94220571 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/test/SecurityIntegTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/test/SecurityIntegTestCase.java @@ -60,6 +60,7 @@ * * @see SecuritySettingsSource */ +@ESTestCase.WithoutEntitlements // 
requires entitlement delegation ES-12382 public abstract class SecurityIntegTestCase extends ESIntegTestCase { private static SecuritySettingsSource SECURITY_DEFAULT_SETTINGS; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributesTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributesTests.java new file mode 100644 index 0000000000000..2964552f9fa4d --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/saml/SamlAttributesTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.authc.saml; + +import org.hamcrest.Matchers; +import org.opensaml.saml.saml2.core.NameID; + +import java.util.List; + +public class SamlAttributesTests extends SamlTestCase { + + public void testToString() { + final String nameFormat = randomFrom(NameID.TRANSIENT, NameID.PERSISTENT, NameID.EMAIL); + final String nameId = randomIdentifier(); + final String session = randomAlphaOfLength(16); + final SamlAttributes attributes = new SamlAttributes( + new SamlNameId(nameFormat, nameId, null, null, null), + session, + List.of( + new SamlAttributes.SamlAttribute("urn:oid:0.9.2342.19200300.100.1.1", null, List.of("peter.ng")), + new SamlAttributes.SamlAttribute("urn:oid:2.5.4.3", "name", List.of("Peter Ng")), + new SamlAttributes.SamlAttribute( + "urn:oid:1.3.6.1.4.1.5923.1.5.1.1", + "groups", + List.of("employees", "engineering", "managers") + ) + ) + ); + assertThat( + attributes.toString(), + Matchers.equalTo( + "SamlAttributes(" + + ("NameId(" + nameFormat + ")=" + nameId) + + ")[" + + session + + "]{[" + + "urn:oid:0.9.2342.19200300.100.1.1=[peter.ng](len=1)" + + ", " + + 
"name(urn:oid:2.5.4.3)=[Peter Ng](len=1)" + + ", " + + "groups(urn:oid:1.3.6.1.4.1.5923.1.5.1.1)=[employees, engineering, managers](len=3)" + + "]}" + ) + ); + } + +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java index 094751f51dd5b..68586a973bc8a 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java @@ -1118,7 +1118,8 @@ public void testProfileFiltersCreatedDifferentlyForDifferentTransportAndRemoteCl randomFrom(SslVerificationMode.values()), SslClientAuthenticationMode.REQUIRED, List.of("TLS_AES_256_GCM_SHA384"), - List.of("TLSv1.3") + List.of("TLSv1.3"), + randomLongBetween(1, 100000) ) ); @@ -1131,7 +1132,8 @@ public void testProfileFiltersCreatedDifferentlyForDifferentTransportAndRemoteCl randomFrom(SslVerificationMode.values()), SslClientAuthenticationMode.NONE, List.of(Runtime.version().feature() < 24 ? 
"TLS_RSA_WITH_AES_256_GCM_SHA384" : "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"), - List.of("TLSv1.2") + List.of("TLSv1.2"), + randomLongBetween(1, 100000) ) ); doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) @@ -1181,7 +1183,8 @@ public void testNoProfileFilterForRemoteClusterWhenTheFeatureIsDisabled() { randomFrom(SslVerificationMode.values()), SslClientAuthenticationMode.REQUIRED, List.of("TLS_AES_256_GCM_SHA384"), - List.of("TLSv1.3") + List.of("TLSv1.3"), + randomLongBetween(1, 100000) ) ); doThrow(new AssertionError("profile filters should not be configured for remote cluster server when the port is disabled")).when( diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java index b984295155c1f..7a7896eb08d83 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java @@ -11,6 +11,7 @@ import io.netty.channel.socket.nio.NioChannelOption; import io.netty.handler.ssl.SslHandshakeTimeoutException; +import org.apache.logging.log4j.Level; import org.apache.lucene.util.Constants; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; @@ -35,6 +36,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; @@ -42,6 +44,9 @@ import 
org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.mocksocket.MockServerSocket; +import org.elasticsearch.mocksocket.MockSocket; +import org.elasticsearch.test.MockLog; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.StubbableTransport; import org.elasticsearch.threadpool.ThreadPool; @@ -65,6 +70,7 @@ import org.elasticsearch.xpack.security.transport.SSLEngineUtils; import org.elasticsearch.xpack.security.transport.filter.IPFilter; +import java.io.EOFException; import java.io.IOException; import java.io.UncheckedIOException; import java.net.InetAddress; @@ -902,7 +908,15 @@ public void testTcpHandshakeTimeout() throws IOException { } } + @TestLogging(reason = "inbound timeout is reported at TRACE", value = "org.elasticsearch.transport.netty4.ESLoggingHandler:TRACE") public void testTlsHandshakeTimeout() throws IOException { + runOutboundTlsHandshakeTimeoutTest(null); + runOutboundTlsHandshakeTimeoutTest(randomLongBetween(1, 500)); + runInboundTlsHandshakeTimeoutTest(null); + runInboundTlsHandshakeTimeoutTest(randomLongBetween(1, 500)); + } + + private void runOutboundTlsHandshakeTimeoutTest(@Nullable /* to use the default */ Long handshakeTimeoutMillis) throws IOException { final CountDownLatch doneLatch = new CountDownLatch(1); try (ServerSocket socket = new MockServerSocket()) { socket.bind(getLocalEphemeral(), 1); @@ -928,16 +942,56 @@ public void testTlsHandshakeTimeout() throws IOException { TransportRequestOptions.Type.REG, TransportRequestOptions.Type.STATE ); - final var future = new TestPlainActionFuture(); - serviceA.connectToNode(dummy, builder.build(), future); - final var ex = expectThrows(ExecutionException.class, ConnectTransportException.class, future::get); // long wait - assertEquals("[][" + dummy.getAddress() + "] connect_exception", 
ex.getMessage()); - assertNotNull(ExceptionsHelper.unwrap(ex, SslHandshakeTimeoutException.class)); + final ConnectTransportException exception; + final var transportSettings = Settings.builder(); + if (handshakeTimeoutMillis == null) { + handshakeTimeoutMillis = 10000L; // default + } else { + transportSettings.put("xpack.security.transport.ssl.handshake_timeout", TimeValue.timeValueMillis(handshakeTimeoutMillis)); + } + try (var service = buildService(getTestName(), version0, transportVersion0, transportSettings.build())) { + final var future = new TestPlainActionFuture(); + service.connectToNode(dummy, builder.build(), future); + exception = expectThrows(ExecutionException.class, ConnectTransportException.class, future::get); // long wait + assertEquals("[][" + dummy.getAddress() + "] connect_exception", exception.getMessage()); + assertThat( + asInstanceOf(SslHandshakeTimeoutException.class, exception.getCause()).getMessage(), + equalTo("handshake timed out after " + handshakeTimeoutMillis + "ms") + ); + } } finally { doneLatch.countDown(); } } + @SuppressForbidden(reason = "test needs a simple TCP connection") + private void runInboundTlsHandshakeTimeoutTest(@Nullable /* to use the default */ Long handshakeTimeoutMillis) throws IOException { + final var transportSettings = Settings.builder(); + if (handshakeTimeoutMillis == null) { + handshakeTimeoutMillis = 10000L; // default + } else { + transportSettings.put("xpack.security.transport.ssl.handshake_timeout", TimeValue.timeValueMillis(handshakeTimeoutMillis)); + } + try ( + var service = buildService(getTestName(), version0, transportVersion0, transportSettings.build()); + Socket clientSocket = new MockSocket(); + MockLog mockLog = MockLog.capture("org.elasticsearch.transport.netty4.ESLoggingHandler") + ) { + mockLog.addExpectation( + new MockLog.SeenEventExpectation( + "timeout event message", + "org.elasticsearch.transport.netty4.ESLoggingHandler", + Level.TRACE, + "SslHandshakeTimeoutException: handshake 
timed out after " + handshakeTimeoutMillis + "ms" + ) + ); + + clientSocket.connect(service.boundAddress().boundAddresses()[0].address()); + expectThrows(EOFException.class, () -> clientSocket.getInputStream().skipNBytes(Long.MAX_VALUE)); + mockLog.assertAllExpectationsMatched(); + } + } + public void testTcpHandshakeConnectionReset() throws IOException, InterruptedException { assumeFalse("Can't run in a FIPS JVM, TrustAllConfig is not a SunJSSE TrustManagers", inFipsJvm()); SSLService sslService = createSSLService(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/ssl/SSLErrorMessageFileTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/ssl/SSLErrorMessageFileTests.java index 4f64b780e1f97..2ac2d4ebf0c32 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/ssl/SSLErrorMessageFileTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/ssl/SSLErrorMessageFileTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.core.PathUtils; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; -import org.elasticsearch.jdk.RuntimeVersionFeature; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ssl.SSLService; import org.junit.Before; @@ -363,11 +362,6 @@ private void checkBlockedResource( String configKey, BiConsumer configure ) throws Exception { - assumeTrue( - "Requires Security Manager to block access, entitlements are not checked for unit tests", - RuntimeVersionFeature.isSecurityManagerAvailable() - ); - final String prefix = randomSslPrefix(); final Settings.Builder settings = Settings.builder(); configure.accept(prefix, settings); diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java index 784f1c1fbe23e..10c1a4321f1e5 100644 --- 
a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; @@ -166,13 +167,14 @@ protected TaskExecutor(Client client, ClusterService clusterService, ThreadPool } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TestTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + ProjectId projectId ) { candidates.set(candidateNodes); - return super.getAssignment(params, candidateNodes, clusterState); + return super.doGetAssignment(params, candidateNodes, clusterState, projectId); } @Override diff --git a/x-pack/plugin/slm/src/internalClusterTest/java/org/elasticsearch/xpack/slm/SLMFileSettingsIT.java b/x-pack/plugin/slm/src/internalClusterTest/java/org/elasticsearch/xpack/slm/SLMFileSettingsIT.java index 54a390f55cc35..1bc0d56ad3626 100644 --- a/x-pack/plugin/slm/src/internalClusterTest/java/org/elasticsearch/xpack/slm/SLMFileSettingsIT.java +++ b/x-pack/plugin/slm/src/internalClusterTest/java/org/elasticsearch/xpack/slm/SLMFileSettingsIT.java @@ -186,6 +186,7 @@ private void assertClusterStateSaveOK(CountDownLatch savedClusterState, AtomicLo boolean awaitSuccessful = savedClusterState.await(20, TimeUnit.SECONDS); assertTrue(awaitSuccessful); + awaitMasterNode(); final ClusterStateResponse clusterStateResponse = clusterAdmin().state( new 
ClusterStateRequest(TEST_REQUEST_TIMEOUT).waitForMetadataVersion(metadataVersion.get()) ).get(); @@ -229,8 +230,7 @@ public void testSettingsApplied() throws Exception { writeJSONFile(dataNode, testJSON); logger.info("--> start master node"); - final String masterNode = internalCluster().startMasterOnlyNode(); - awaitMasterNode(internalCluster().getNonMasterNodeName(), masterNode); + internalCluster().startMasterOnlyNode(); assertClusterStateSaveOK(savedClusterState.v1(), savedClusterState.v2()); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java index 399a9eee0d752..40d1c12af9061 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java @@ -76,6 +76,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.matchesPattern; import static org.hamcrest.Matchers.nullValue; public class RepositoryAnalysisFailureIT extends AbstractSnapshotIntegTestCase { @@ -385,6 +386,55 @@ public BytesReference onContendedCompareAndExchange(BytesRegister register, Byte assertAnalysisFailureMessage(analyseRepositoryExpectFailure(request).getMessage()); } + public void testFailsOnLostIncrement() { + final RepositoryAnalyzeAction.Request request = new RepositoryAnalyzeAction.Request("test-repo"); + final AtomicBoolean registerWasCorrupted = new AtomicBoolean(); + + blobStore.setDisruption(new Disruption() { + @Override + 
public BytesReference onContendedCompareAndExchange(BytesRegister register, BytesReference expected, BytesReference updated) { + if (expected.equals(updated) == false // not the initial read + && updated.length() == Long.BYTES // not the final write + && randomBoolean() + && register.get().equals(expected) // would have succeeded + && registerWasCorrupted.compareAndSet(false, true)) { + + // indicate success without actually applying the update + return expected; + } + + return register.compareAndExchange(expected, updated); + } + }); + + safeAwait((ActionListener l) -> analyseRepository(request, l.delegateResponse((ll, e) -> { + if (ExceptionsHelper.unwrapCause(e) instanceof RepositoryVerificationException repositoryVerificationException) { + assertAnalysisFailureMessage(repositoryVerificationException.getMessage()); + assertTrue( + "did not lose increment, so why did the verification fail?", + // clear flag for final assertion + registerWasCorrupted.compareAndSet(true, false) + ); + assertThat( + asInstanceOf( + RepositoryVerificationException.class, + ExceptionsHelper.unwrapCause(repositoryVerificationException.getCause()) + ).getMessage(), + matchesPattern(""" + \\[test-repo] Successfully completed all \\[.*] atomic increments of register \\[test-register-contended-.*] \ + so its expected value is \\[OptionalBytesReference\\[.*]], but reading its value with \\[.*] unexpectedly \ + yielded \\[OptionalBytesReference\\[.*]]\\. 
This anomaly may indicate an atomicity failure amongst concurrent \ + compare-and-exchange operations on registers in this repository\\.""") + ); + ll.onResponse(null); + } else { + ll.onFailure(e); + } + }))); + + assertFalse(registerWasCorrupted.get()); + } + public void testFailsIfRegisterHoldsSpuriousValue() { final RepositoryAnalyzeAction.Request request = new RepositoryAnalyzeAction.Request("test-repo"); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java index 5418a5081c443..35e11ae40d51a 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java @@ -42,6 +42,7 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThrottledIterator; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; @@ -76,7 +77,6 @@ import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.LongSupplier; @@ -379,7 +379,6 @@ public static class AsyncAction { // choose the blob path nondeterministically to avoid clashes, assuming that the actual path doesn't matter for reproduction private final String blobPath = "temp-analysis-" + UUIDs.randomBase64UUID(); - private final 
AtomicLong expectedRegisterValue = new AtomicLong(); private final Queue> queue = ConcurrentCollections.newQueue(); private final AtomicReference failure = new AtomicReference<>(); private final Semaphore innerFailures = new Semaphore(5); // limit the number of suppressed failures @@ -485,16 +484,17 @@ public void run() { if (minClusterTransportVersion.onOrAfter(TransportVersions.V_8_8_0)) { final String contendedRegisterName = CONTENDED_REGISTER_NAME_PREFIX + UUIDs.randomBase64UUID(random); final AtomicBoolean contendedRegisterAnalysisComplete = new AtomicBoolean(); + final int registerOperations = Math.max(nodes.size(), request.getRegisterOperationCount()); try ( var registerRefs = new RefCountingRunnable( finalRegisterValueVerifier( contendedRegisterName, + registerOperations, random, Releasables.wrap(requestRefs.acquire(), () -> contendedRegisterAnalysisComplete.set(true)) ) ) ) { - final int registerOperations = Math.max(nodes.size(), request.getRegisterOperationCount()); for (int i = 0; i < registerOperations; i++) { final ContendedRegisterAnalyzeAction.Request registerAnalyzeRequest = new ContendedRegisterAnalyzeAction.Request( request.getRepositoryName(), @@ -630,9 +630,7 @@ private void runContendedRegisterAnalysis(Releasable ref, ContendedRegisterAnaly TransportRequestOptions.EMPTY, new ActionListenerResponseHandler<>(ActionListener.releaseAfter(new ActionListener<>() { @Override - public void onResponse(ActionResponse.Empty response) { - expectedRegisterValue.incrementAndGet(); - } + public void onResponse(ActionResponse.Empty response) {} @Override public void onFailure(Exception exp) { @@ -646,68 +644,108 @@ public void onFailure(Exception exp) { } } - private Runnable finalRegisterValueVerifier(String registerName, Random random, Releasable ref) { - return () -> { - if (isRunning()) { - final var expectedFinalRegisterValue = expectedRegisterValue.get(); - transportService.getThreadPool() - .executor(ThreadPool.Names.SNAPSHOT) - 
.execute(ActionRunnable.wrap(ActionListener.releaseAfter(new ActionListener() { - @Override - public void onResponse(OptionalBytesReference actualFinalRegisterValue) { - if (actualFinalRegisterValue.isPresent() == false - || longFromBytes(actualFinalRegisterValue.bytesReference()) != expectedFinalRegisterValue) { - fail( - new RepositoryVerificationException( - request.getRepositoryName(), - Strings.format( - "register [%s] should have value [%d] but instead had value [%s]", - registerName, - expectedFinalRegisterValue, - actualFinalRegisterValue + private Runnable finalRegisterValueVerifier(String registerName, int expectedFinalRegisterValue, Random random, Releasable ref) { + return new Runnable() { + + final CheckedConsumer, Exception> finalValueReader = switch (random.nextInt(3)) { + case 0 -> new CheckedConsumer, Exception>() { + @Override + public void accept(ActionListener listener) { + getBlobContainer().getRegister(OperationPurpose.REPOSITORY_ANALYSIS, registerName, listener); + } + + @Override + public String toString() { + return "getRegister"; + } + }; + case 1 -> new CheckedConsumer, Exception>() { + @Override + public void accept(ActionListener listener) { + getBlobContainer().compareAndExchangeRegister( + OperationPurpose.REPOSITORY_ANALYSIS, + registerName, + bytesFromLong(expectedFinalRegisterValue), + new BytesArray(new byte[] { (byte) 0xff }), + listener + ); + } + + @Override + public String toString() { + return "compareAndExchangeRegister"; + } + }; + case 2 -> new CheckedConsumer, Exception>() { + @Override + public void accept(ActionListener listener) { + getBlobContainer().compareAndSetRegister( + OperationPurpose.REPOSITORY_ANALYSIS, + registerName, + bytesFromLong(expectedFinalRegisterValue), + new BytesArray(new byte[] { (byte) 0xff }), + listener.map( + b -> b + ? 
OptionalBytesReference.of(bytesFromLong(expectedFinalRegisterValue)) + : OptionalBytesReference.MISSING + ) + ); + } + + @Override + public String toString() { + return "compareAndSetRegister"; + } + }; + default -> { + assert false; + throw new IllegalStateException(); + } + }; + + @Override + public void run() { + if (isRunning()) { + transportService.getThreadPool() + .executor(ThreadPool.Names.SNAPSHOT) + .execute(ActionRunnable.wrap(ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(OptionalBytesReference actualFinalRegisterValue) { + if (actualFinalRegisterValue.isPresent() == false + || longFromBytes(actualFinalRegisterValue.bytesReference()) != expectedFinalRegisterValue) { + fail( + new RepositoryVerificationException( + request.getRepositoryName(), + Strings.format( + """ + Successfully completed all [%d] atomic increments of register [%s] so its expected \ + value is [%s], but reading its value with [%s] unexpectedly yielded [%s]. This \ + anomaly may indicate an atomicity failure amongst concurrent compare-and-exchange \ + operations on registers in this repository.""", + expectedFinalRegisterValue, + registerName, + OptionalBytesReference.of(bytesFromLong(expectedFinalRegisterValue)), + finalValueReader.toString(), + actualFinalRegisterValue + ) ) - ) - ); + ); + } } - } - @Override - public void onFailure(Exception exp) { - // Registers are not supported on all repository types, and that's ok. 
- if (exp instanceof UnsupportedOperationException == false) { - fail(exp); - } - } - }, ref), listener -> { - switch (random.nextInt(3)) { - case 0 -> getBlobContainer().getRegister(OperationPurpose.REPOSITORY_ANALYSIS, registerName, listener); - case 1 -> getBlobContainer().compareAndExchangeRegister( - OperationPurpose.REPOSITORY_ANALYSIS, - registerName, - bytesFromLong(expectedFinalRegisterValue), - new BytesArray(new byte[] { (byte) 0xff }), - listener - ); - case 2 -> getBlobContainer().compareAndSetRegister( - OperationPurpose.REPOSITORY_ANALYSIS, - registerName, - bytesFromLong(expectedFinalRegisterValue), - new BytesArray(new byte[] { (byte) 0xff }), - listener.map( - b -> b - ? OptionalBytesReference.of(bytesFromLong(expectedFinalRegisterValue)) - : OptionalBytesReference.MISSING - ) - ); - default -> { - assert false; - throw new IllegalStateException(); + @Override + public void onFailure(Exception exp) { + // Registers are not supported on all repository types, and that's ok. 
+ if (exp instanceof UnsupportedOperationException == false) { + fail(exp); + } } - } - })); - } else { - ref.close(); + }, ref), finalValueReader)); + } else { + ref.close(); + } } + }; } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml index 5b0492f9e847e..b9d20d4cd40cf 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml @@ -555,11 +555,11 @@ FROM EVAL SORT LIMIT with documents_found: - method: POST path: /_query parameters: [ ] - capabilities: [ parameter_for_limit ] + capabilities: [ normalized_limit_error_message ] reason: "named or positional parameters for field names" - do: - catch: "/Invalid value for LIMIT \\[foo: String\\], expecting a non negative integer/" + catch: "/value of \\[limit \\?l\\] must be a non negative integer, found value \\[\\?l\\] type \\[keyword\\]/" esql.query: body: query: 'from test | limit ?l' diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/46_downsample.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/46_downsample.yml index ee1a381c6e589..23e6772d26d08 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/46_downsample.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/46_downsample.yml @@ -251,3 +251,381 @@ setup: - match: {values.0.1: 800479.0} - match: {values.0.2: 4812452.0} - match: {values.0.3: 6} + +--- +"Stats from downsampled and non-downsampled index simultaneously with implicit casting": + - requires: + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_query + parameters: [] + capabilities: [aggregate_metric_double_implicit_casting_in_aggs] + reason: "Support for casting aggregate metric double implicitly when present in aggregations" + + - do: + 
indices.downsample: + index: test + target_index: test-downsample + body: > + { + "fixed_interval": "1h" + } + - is_true: acknowledged + + - do: + indices.create: + index: test-2 + body: + settings: + number_of_shards: 1 + index: + mode: time_series + routing_path: [ metricset, k8s.pod.uid ] + time_series: + start_time: 2021-04-29T00:00:00Z + end_time: 2021-04-30T00:00:00Z + mappings: + properties: + "@timestamp": + type: date + metricset: + type: keyword + time_series_dimension: true + k8s: + properties: + pod: + properties: + uid: + type: keyword + time_series_dimension: true + name: + type: keyword + created_at: + type: date_nanos + running: + type: boolean + number_of_containers: + type: integer + ip: + type: ip + tags: + type: keyword + values: + type: integer + network: + properties: + tx: + type: long + time_series_metric: gauge + rx: + type: long + time_series_metric: gauge + + - do: + bulk: + refresh: true + index: test-2 + body: + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:04.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.1", "network": {"tx": 2001810, "rx": 802339}, "created_at": "2021-04-28T19:34:00.000Z", "running": false, "number_of_containers": 2, "tags": ["backend", "prod"], "values": [2, 3, 6]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:24.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.26", "network": {"tx": 2000177, "rx": 800479}, "created_at": "2021-04-28T19:35:00.000Z", "running": true, "number_of_containers": 2, "tags": ["backend", "prod", "us-west1"], "values": [1, 1, 3]}}}' + - '{"index": {}}' + + - do: + esql.query: + body: + query: "FROM test-* | + WHERE k8s.pod.uid == \"947e4ced-1786-4e53-9e0c-5c447e959507\" | + STATS max(k8s.pod.network.rx), min(k8s.pod.network.rx), sum(k8s.pod.network.rx), count(k8s.pod.network.rx), avg(k8s.pod.network.rx) | + LIMIT 100" + + - 
length: {values: 1} + - length: {values.0: 5} + - match: {columns.0.name: "max(k8s.pod.network.rx)"} + - match: {columns.0.type: "double"} + - match: {columns.1.name: "min(k8s.pod.network.rx)"} + - match: {columns.1.type: "double"} + - match: {columns.2.name: "sum(k8s.pod.network.rx)"} + - match: {columns.2.type: "double"} + - match: {columns.3.name: "count(k8s.pod.network.rx)"} + - match: {columns.3.type: "long"} + - match: {columns.4.name: "avg(k8s.pod.network.rx)"} + - match: {columns.4.type: "double"} + - match: {values.0.0: 803685.0} + - match: {values.0.1: 800479.0} + - match: {values.0.2: 4812452.0} + - match: {values.0.3: 6} + - match: {values.0.4: 802075.3333333334} + + - do: + esql.query: + body: + query: "TS test-* | STATS max = max(k8s.pod.network.rx) | LIMIT 100" + - length: {values: 1} + - length: {values.0: 1} + - match: {columns.0.name: "max"} + - match: {columns.0.type: "double"} + - match: {values.0.0: 803685.0} + +--- +"Over time functions from downsampled and non-downsampled indices simultaneously, no grouping": + - requires: + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_query + parameters: [] + capabilities: [aggregate_metric_double_implicit_casting_in_aggs] + reason: "Support for casting aggregate metric double implicitly when present in aggregations" + + - do: + indices.downsample: + index: test + target_index: test-downsample + body: > + { + "fixed_interval": "1h" + } + - is_true: acknowledged + + - do: + indices.create: + index: test-2 + body: + settings: + number_of_shards: 1 + index: + mode: time_series + routing_path: [ metricset, k8s.pod.uid ] + time_series: + start_time: 2021-04-29T00:00:00Z + end_time: 2021-04-30T00:00:00Z + mappings: + properties: + "@timestamp": + type: date + metricset: + type: keyword + time_series_dimension: true + k8s: + properties: + pod: + properties: + uid: + type: keyword + time_series_dimension: true + name: + type: keyword + created_at: + type: date_nanos + running: + 
type: boolean + number_of_containers: + type: integer + ip: + type: ip + tags: + type: keyword + values: + type: integer + network: + properties: + tx: + type: long + time_series_metric: gauge + rx: + type: long + time_series_metric: gauge + + - do: + bulk: + refresh: true + index: test-2 + body: + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:04.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.10", "network": {"tx": 2005820, "rx": 802339}, "created_at": "2021-04-29T21:34:00.000Z", "running": false, "number_of_containers": 2, "tags": ["backend", "prod"], "values": [2, 3, 6]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:24.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.28", "network": {"tx": 2000481, "rx": 800479}, "created_at": "2021-04-29T21:35:00.000Z", "running": true, "number_of_containers": 2, "tags": ["backend", "prod", "us-west1"], "values": [1, 1, 3]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:14.467Z", "metricset": "pod", "k8s": {"pod": {"name": "dog", "uid":"df3145b3-0563-4d3b-a0f7-897eb2876ea9", "ip": "10.10.55.192", "network": {"tx": 1458377, "rx": 530184}, "created_at": "2021-04-29T21:36:00.000Z", "running": false, "number_of_containers": 2, "tags": ["backend", "test"], "values": [3, 3, 1]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:44.467Z", "metricset": "pod", "k8s": {"pod": {"name": "dog", "uid":"df3145b3-0563-4d3b-a0f7-897eb2876ea9", "ip": "10.10.55.206", "network": {"tx": 1434104, "rx": 535020}, "created_at": "2021-04-29T21:35:00.000Z", "running": true, "number_of_containers": 2, "tags": ["backend", "prod", "us-west2"], "values": [4, 1, 3]}}}' + - '{"index": {}}' + + - do: + esql.query: + body: + query: "TS test-* | + STATS avg = sum(avg_over_time(k8s.pod.network.rx)), + count = sum(count_over_time(k8s.pod.network.rx)), + sum = 
sum(sum_over_time(k8s.pod.network.rx)) + BY time_bucket = bucket(@timestamp, 1 hour) | + SORT time_bucket | LIMIT 10" + + - length: {values: 4} + - length: {values.0: 4} + - match: {columns.0.name: "avg"} + - match: {columns.0.type: "double"} + - match: {columns.1.name: "count"} + - match: {columns.1.type: "long"} + - match: {columns.2.name: "sum"} + - match: {columns.2.type: "double"} + - match: {columns.3.name: "time_bucket"} + - match: {columns.3.type: "date"} + - match: {values.0.0: 1332393.5} + - match: {values.0.1: 4} + - match: {values.0.2: 2664787.0} + - match: {values.0.3: "2021-04-28T18:00:00.000Z"} + - match: {values.1.0: 530604.5} + - match: {values.1.1: 2} + - match: {values.1.2: 1061209.0} + - match: {values.1.3: "2021-04-28T19:00:00.000Z"} + - match: {values.2.0: 803011.0} + - match: {values.2.1: 2} + - match: {values.2.2: 1606022.0} + - match: {values.2.3: "2021-04-28T20:00:00.000Z"} + - match: {values.3.0: 1334011.0} + - match: {values.3.1: 4} + - match: {values.3.2: 2668022.0} + - match: {values.3.3: "2021-04-29T21:00:00.000Z"} + +--- +"Over time functions from downsampled and non-downsampled indices simultaneously, with grouping": + - requires: + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_query + parameters: [] + capabilities: [aggregate_metric_double_implicit_casting_in_aggs] + reason: "Support for casting aggregate metric double implicitly when present in aggregations" + + - do: + indices.downsample: + index: test + target_index: test-downsample + body: > + { + "fixed_interval": "1h" + } + - is_true: acknowledged + + - do: + indices.create: + index: test-2 + body: + settings: + number_of_shards: 1 + index: + mode: time_series + routing_path: [ metricset, k8s.pod.uid ] + time_series: + start_time: 2021-04-29T00:00:00Z + end_time: 2021-04-30T00:00:00Z + mappings: + properties: + "@timestamp": + type: date + metricset: + type: keyword + time_series_dimension: true + k8s: + properties: + pod: + properties: + 
uid: + type: keyword + time_series_dimension: true + name: + type: keyword + created_at: + type: date_nanos + running: + type: boolean + number_of_containers: + type: integer + ip: + type: ip + tags: + type: keyword + values: + type: integer + network: + properties: + tx: + type: long + time_series_metric: gauge + rx: + type: long + time_series_metric: gauge + + - do: + bulk: + refresh: true + index: test-2 + body: + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:04.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.10", "network": {"tx": 2005820, "rx": 802339}, "created_at": "2021-04-29T21:34:00.000Z", "running": false, "number_of_containers": 2, "tags": ["backend", "prod"], "values": [2, 3, 6]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:24.467Z", "metricset": "pod", "k8s": {"pod": {"name": "cat", "uid":"947e4ced-1786-4e53-9e0c-5c447e959507", "ip": "10.10.55.28", "network": {"tx": 2000481, "rx": 800479}, "created_at": "2021-04-29T21:35:00.000Z", "running": true, "number_of_containers": 2, "tags": ["backend", "prod", "us-west1"], "values": [1, 1, 3]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:14.467Z", "metricset": "pod", "k8s": {"pod": {"name": "dog", "uid":"df3145b3-0563-4d3b-a0f7-897eb2876ea9", "ip": "10.10.55.192", "network": {"tx": 1458377, "rx": 530184}, "created_at": "2021-04-29T21:36:00.000Z", "running": false, "number_of_containers": 2, "tags": ["backend", "test"], "values": [3, 3, 1]}}}' + - '{"index": {}}' + - '{"@timestamp": "2021-04-29T21:50:44.467Z", "metricset": "pod", "k8s": {"pod": {"name": "dog", "uid":"df3145b3-0563-4d3b-a0f7-897eb2876ea9", "ip": "10.10.55.206", "network": {"tx": 1434104, "rx": 535020}, "created_at": "2021-04-29T21:35:00.000Z", "running": true, "number_of_containers": 2, "tags": ["backend", "prod", "us-west2"], "values": [4, 1, 3]}}}' + - '{"index": {}}' + + - do: + esql.query: + body: + query: "TS test-* | + STATS 
avg = sum(avg_over_time(k8s.pod.network.rx)), + count = sum(count_over_time(k8s.pod.network.rx)), + sum = sum(sum_over_time(k8s.pod.network.rx)) + BY k8s.pod.name, time_bucket = bucket(@timestamp, 1 hour) | + SORT time_bucket, k8s.pod.name | + LIMIT 10" + + - length: {values: 6} + - length: {values.0: 5} + - match: {columns.0.name: "avg"} + - match: {columns.0.type: "double"} + - match: {columns.1.name: "count"} + - match: {columns.1.type: "long"} + - match: {columns.2.name: "sum"} + - match: {columns.2.type: "double"} + - match: {columns.3.name: "k8s.pod.name"} + - match: {columns.3.type: "keyword"} + - match: {columns.4.name: "time_bucket"} + - match: {columns.4.type: "date"} + - match: {values.0.0: 801806.0} + - match: {values.0.1: 2} + - match: {values.0.2: 1603612.0} + - match: {values.0.3: "cat"} + - match: {values.0.4: "2021-04-28T18:00:00.000Z"} + - match: {values.1.0: 530587.5} + - match: {values.1.1: 2} + - match: {values.1.2: 1061175.0} + - match: {values.1.3: "dog"} + - match: {values.1.4: "2021-04-28T18:00:00.000Z"} + - match: {values.2.0: 530604.5} + - match: {values.2.1: 2} + - match: {values.2.2: 1061209.0} + - match: {values.2.3: "dog"} + - match: {values.2.4: "2021-04-28T19:00:00.000Z"} + - match: {values.3.0: 803011.0} + - match: {values.3.1: 2} + - match: {values.3.2: 1606022.0} + - match: {values.3.3: "cat"} + - match: {values.3.4: "2021-04-28T20:00:00.000Z"} + - match: {values.4.0: 801409.0} + - match: {values.4.1: 2} + - match: {values.4.2: 1602818.0} + - match: {values.4.3: "cat"} + - match: {values.4.4: "2021-04-29T21:00:00.000Z"} + - match: {values.5.0: 532602.0} + - match: {values.5.1: 2} + - match: {values.5.2: 1065204.0} + - match: {values.5.3: "dog"} + - match: {values.5.4: "2021-04-29T21:00:00.000Z"} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_enrich.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_enrich.yml index d05b14a390c4e..c8f3c15400b9e 100644 --- 
a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_enrich.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_enrich.yml @@ -225,7 +225,7 @@ teardown: - method: POST path: /_query parameters: [] - capabilities: [ no_brackets_in_unquoted_index_names ] + capabilities: [ no_brackets_in_unquoted_index_names, fork_v9 ] reason: "Change in the grammar" - do: diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 72b518e2228ee..94f56c3c85367 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -41,6 +41,7 @@ setup: - sum_over_time - count_over_time - distinct_over_time + - cosine_vector_similarity_function reason: "Test that should only be executed on snapshot versions" - do: {xpack.usage: {}} @@ -130,7 +131,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. 
- - length: {esql.functions: 156} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 157} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java index b7bd434194b80..495a0db966343 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; @@ -113,10 +114,11 @@ public TransformPersistentTasksExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TransformTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { /* Note: * diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/TransformMetadataTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/TransformMetadataTests.java index 108bbab85935e..8f2008dab55db 100644 --- 
a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/TransformMetadataTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/TransformMetadataTests.java @@ -7,11 +7,17 @@ package org.elasticsearch.xpack.transform; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.transform.TransformMetadata; +import static org.hamcrest.Matchers.equalTo; + public class TransformMetadataTests extends AbstractChunkedSerializingTestCase { @Override @@ -35,4 +41,39 @@ protected TransformMetadata mutateInstance(TransformMetadata instance) { .upgradeMode(instance.upgradeMode() == false) .build(); } + + public void testTransformMetadataFromClusterState() { + var expectedTransformMetadata = new TransformMetadata.Builder().resetMode(true).upgradeMode(true).build(); + var projectId = randomUniqueProjectId(); + var clusterState = ClusterState.builder(new ClusterName("_name")) + .metadata( + Metadata.builder().put(ProjectMetadata.builder(projectId).putCustom(TransformMetadata.TYPE, expectedTransformMetadata)) + ) + .build(); + + assertThat(TransformMetadata.transformMetadata(clusterState, projectId), equalTo(expectedTransformMetadata)); + assertThat(TransformMetadata.getTransformMetadata(clusterState), equalTo(expectedTransformMetadata)); + } + + public void testTransformMetadataFromMissingClusterState() { + assertThat(TransformMetadata.transformMetadata(null, randomUniqueProjectId()), equalTo(TransformMetadata.EMPTY_METADATA)); + assertThat(TransformMetadata.getTransformMetadata(null), equalTo(TransformMetadata.EMPTY_METADATA)); + } + + public void 
testTransformMetadataFromMissingProjectId() { + assertThat( + TransformMetadata.transformMetadata(ClusterState.builder(new ClusterName("_name")).build(), null), + equalTo(TransformMetadata.EMPTY_METADATA) + ); + } + + public void testTransformMetadataWhenAbsentFromClusterState() { + var projectId = randomUniqueProjectId(); + var clusterState = ClusterState.builder(new ClusterName("_name")) + .metadata(Metadata.builder().put(ProjectMetadata.builder(projectId))) + .build(); + + assertThat(TransformMetadata.transformMetadata(clusterState, projectId), equalTo(TransformMetadata.EMPTY_METADATA)); + assertThat(TransformMetadata.getTransformMetadata(clusterState), equalTo(TransformMetadata.EMPTY_METADATA)); + } } diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java index fa509143f9ba9..ec4122b3da7f2 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java @@ -12,10 +12,14 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.project.TestProjectResolvers; import org.elasticsearch.cluster.routing.IndexRoutingTable; import 
org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RecoverySource; @@ -83,6 +87,7 @@ public class TransformPersistentTasksExecutorTests extends ESTestCase { private static ThreadPool threadPool; private TransformConfigAutoMigration autoMigration; + private ProjectId projectId; @BeforeClass public static void setUpThreadPool() { @@ -106,13 +111,15 @@ public static void tearDownThreadPool() { } @Before - public void initMocks() { + public void setUp() throws Exception { + super.setUp(); autoMigration = mock(); doAnswer(ans -> { ActionListener listener = ans.getArgument(1); listener.onResponse(ans.getArgument(0)); return null; }).when(autoMigration).migrateAndSave(any(), any()); + projectId = randomUniqueProjectId(); } public void testNodeVersionAssignment() { @@ -124,7 +131,8 @@ public void testNodeVersionAssignment() { executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, true), cs.nodes().getAllNodes(), - cs + cs, + projectId ).getExecutorNode(), equalTo("current-data-node-with-1-tasks") ); @@ -132,7 +140,8 @@ public void testNodeVersionAssignment() { executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ).getExecutorNode(), equalTo("current-data-node-with-0-tasks-transform-remote-disabled") ); @@ -140,7 +149,8 @@ public void testNodeVersionAssignment() { executor.getAssignment( new TransformTaskParams("new-old-task-id", TransformConfigVersion.V_7_7_0, null, true), cs.nodes().getAllNodes(), - cs + cs, + projectId ).getExecutorNode(), equalTo("past-data-node-1") ); @@ -154,7 +164,8 @@ public void testNodeAssignmentProblems() { Assignment assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false), List.of(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -173,7 +184,8 @@ 
public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false), List.of(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -189,7 +201,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -205,7 +218,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNotNull(assignment.getExecutorNode()); assertThat(assignment.getExecutorNode(), equalTo("dedicated-transform-node")); @@ -218,7 +232,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.V_8_0_0, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -235,7 +250,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNotNull(assignment.getExecutorNode()); assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1")); @@ -248,7 +264,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -264,7 +281,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", 
TransformConfigVersion.CURRENT, null, false), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNotNull(assignment.getExecutorNode()); assertThat(assignment.getExecutorNode(), equalTo("current-data-node-with-0-tasks-transform-remote-disabled")); @@ -277,7 +295,8 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNull(assignment.getExecutorNode()); assertThat( @@ -299,29 +318,27 @@ public void testNodeAssignmentProblems() { assignment = executor.getAssignment( new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true), cs.nodes().getAllNodes(), - cs + cs, + projectId ); assertNotNull(assignment.getExecutorNode()); assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1")); } public void testVerifyIndicesPrimaryShardsAreActive() { - Metadata.Builder metadata = Metadata.builder(); + Metadata.Builder metadata = metadataWithProject(); RoutingTable.Builder routingTable = RoutingTable.builder(); addIndices(metadata, routingTable); ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")); - csBuilder.routingTable(routingTable.build()); + csBuilder.putRoutingTable(projectId, routingTable.build()); csBuilder.metadata(metadata); ClusterState cs = csBuilder.build(); - assertEquals( - 0, - TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(cs, TestIndexNameExpressionResolver.newInstance()).size() - ); + assertEquals(0, TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(cs, indexNameExpressionResolver()).size()); metadata = Metadata.builder(cs.metadata()); - routingTable = new RoutingTable.Builder(cs.routingTable()); + routingTable = new RoutingTable.Builder(cs.routingTable(projectId)); String indexToRemove = TransformInternalIndexConstants.LATEST_INDEX_NAME; if (randomBoolean()) { 
routingTable.remove(indexToRemove); @@ -342,11 +359,11 @@ public void testVerifyIndicesPrimaryShardsAreActive() { } csBuilder = ClusterState.builder(cs); - csBuilder.routingTable(routingTable.build()); + csBuilder.putRoutingTable(projectId, routingTable.build()); csBuilder.metadata(metadata); List result = TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive( csBuilder.build(), - TestIndexNameExpressionResolver.newInstance() + indexNameExpressionResolver() ); assertEquals(1, result.size()); assertEquals(indexToRemove, result.get(0)); @@ -441,7 +458,7 @@ private void addIndices(Metadata.Builder metadata, RoutingTable.Builder routingT for (String indexName : indices) { IndexMetadata.Builder indexMetadata = IndexMetadata.builder(indexName); indexMetadata.settings(indexSettings(IndexVersion.current(), 1, 0).put(IndexMetadata.SETTING_INDEX_UUID, "_uuid")); - metadata.put(indexMetadata); + metadata.getProject(projectId).put(indexMetadata); Index index = new Index(indexName, "_uuid"); ShardId shardId = new ShardId(index, 0); ShardRouting shardRouting = ShardRouting.newUnassigned( @@ -556,7 +573,7 @@ private DiscoveryNodes.Builder buildNodes( } private ClusterState buildClusterState(DiscoveryNodes.Builder nodes) { - Metadata.Builder metadata = Metadata.builder().clusterUUID("cluster-uuid"); + Metadata.Builder metadata = metadataWithProject().clusterUUID("cluster-uuid"); RoutingTable.Builder routingTable = RoutingTable.builder(); addIndices(metadata, routingTable); PersistentTasksCustomMetadata.Builder pTasksBuilder = PersistentTasksCustomMetadata.builder() @@ -580,15 +597,19 @@ private ClusterState buildClusterState(DiscoveryNodes.Builder nodes) { ); PersistentTasksCustomMetadata pTasks = pTasksBuilder.build(); - metadata.putCustom(PersistentTasksCustomMetadata.TYPE, pTasks); + metadata.getProject(projectId).putCustom(PersistentTasksCustomMetadata.TYPE, pTasks); ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")).nodes(nodes); - 
csBuilder.routingTable(routingTable.build()); + csBuilder.putRoutingTable(projectId, routingTable.build()); csBuilder.metadata(metadata); return csBuilder.build(); } + private Metadata.Builder metadataWithProject() { + return Metadata.builder().put(ProjectMetadata.builder(projectId)); + } + private TransformPersistentTasksExecutor buildTaskExecutor() { var transformServices = transformServices( new InMemoryTransformConfigManager(), @@ -622,11 +643,15 @@ private TransformPersistentTasksExecutor buildTaskExecutor(TransformServices tra clusterService(), Settings.EMPTY, new DefaultTransformExtension(), - TestIndexNameExpressionResolver.newInstance(), + indexNameExpressionResolver(), autoMigration ); } + private IndexNameExpressionResolver indexNameExpressionResolver() { + return TestIndexNameExpressionResolver.newInstance(TestProjectResolvers.singleProjectOnly(projectId)); + } + private ClusterService clusterService() { var clusterService = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(Transform.NUM_FAILURE_RETRIES_SETTING)); diff --git a/x-pack/plugin/voting-only-node/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/votingonly/VotingOnlyNodePluginTests.java b/x-pack/plugin/voting-only-node/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/votingonly/VotingOnlyNodePluginTests.java index 92297f7585128..1175b6b7ea299 100644 --- a/x-pack/plugin/voting-only-node/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/votingonly/VotingOnlyNodePluginTests.java +++ b/x-pack/plugin/voting-only-node/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/votingonly/VotingOnlyNodePluginTests.java @@ -96,12 +96,12 @@ public void testPreferFullMasterOverVotingOnlyNodes() throws Exception { internalCluster().setBootstrapMasterNodeIndex(0); internalCluster().startNodes(2); internalCluster().startNode(addRoles(Set.of(DiscoveryNodeRole.VOTING_ONLY_NODE_ROLE))); - 
internalCluster().startDataOnlyNodes(randomInt(2)); - assertBusy( - () -> assertThat( - clusterAdmin().prepareState(TEST_REQUEST_TIMEOUT).get().getState().getLastCommittedConfiguration().getNodeIds().size(), - equalTo(3) - ) + final int numDataNodes = randomInt(2); + internalCluster().startDataOnlyNodes(numDataNodes); + internalCluster().validateClusterFormed(); + + awaitClusterState( + state -> state.getLastCommittedConfiguration().getNodeIds().size() == 3 && state.nodes().size() == 3 + numDataNodes ); final String originalMaster = internalCluster().getMasterName(); @@ -157,15 +157,14 @@ public void testVotingOnlyNodesCannotBeMasterWithoutFullMasterNodes() throws Exc internalCluster().setBootstrapMasterNodeIndex(0); internalCluster().startNode(); internalCluster().startNodes(2, addRoles(Set.of(DiscoveryNodeRole.VOTING_ONLY_NODE_ROLE))); - internalCluster().startDataOnlyNodes(randomInt(2)); - assertBusy( - () -> assertThat( - clusterAdmin().prepareState(TEST_REQUEST_TIMEOUT).get().getState().getLastCommittedConfiguration().getNodeIds().size(), - equalTo(3) - ) + final int numDataNodes = randomInt(2); + internalCluster().startDataOnlyNodes(numDataNodes); + internalCluster().validateClusterFormed(); + + awaitClusterState( + state -> state.getLastCommittedConfiguration().getNodeIds().size() == 3 && state.nodes().size() == 3 + numDataNodes ); - awaitMasterNode(); - final String oldMasterId = clusterAdmin().prepareState(TEST_REQUEST_TIMEOUT).get().getState().nodes().getMasterNodeId(); + final String oldMasterId = internalCluster().getMasterName(); internalCluster().stopCurrentMasterNode(); awaitMasterNotFound(); diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/Watcher.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/Watcher.java index 68cf0984d3808..657c307897425 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/Watcher.java +++ 
b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/Watcher.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; @@ -35,6 +36,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.env.Environment; @@ -823,10 +825,11 @@ public String getFeatureName() { } @Override + @NotMultiProjectCapable(description = "Watcher is not available in serverless") public void prepareForIndicesMigration(ClusterService clusterService, Client client, ActionListener> listener) { Client originClient = new OriginSettingClient(client, WATCHER_ORIGIN); boolean manuallyStopped = Optional.ofNullable( - clusterService.state().metadata().getProject().custom(WatcherMetadata.TYPE) + clusterService.state().metadata().getProject(ProjectId.DEFAULT).custom(WatcherMetadata.TYPE) ).map(WatcherMetadata::manuallyStopped).orElse(false); if (manuallyStopped == false) { diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherIndexingListener.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherIndexingListener.java index e77c7aba6824d..b431334fe159c 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherIndexingListener.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherIndexingListener.java @@ -14,6 +14,7 @@ import 
org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.routing.AllocationId; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.Murmur3HashFunction; @@ -22,6 +23,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.shard.IndexingOperationListener; import org.elasticsearch.index.shard.ShardId; @@ -251,10 +253,12 @@ private void checkWatchIndexHasChanged(IndexMetadata metadata, ClusterChangedEve * @param event The cluster changed event containing the new cluster state */ private void reloadConfiguration(String watchIndex, List localShardRouting, ClusterChangedEvent event) { + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + ProjectId projectId = ProjectId.DEFAULT; // changed alias means to always read a new configuration boolean isAliasChanged = watchIndex.equals(configuration.index) == false; - if (isAliasChanged || hasShardAllocationIdChanged(watchIndex, event.state())) { - IndexRoutingTable watchIndexRoutingTable = event.state().routingTable().index(watchIndex); + if (isAliasChanged || hasShardAllocationIdChanged(projectId, watchIndex, event.state())) { + IndexRoutingTable watchIndexRoutingTable = event.state().routingTable(projectId).index(watchIndex); Map ids = getLocalShardAllocationIds(localShardRouting, watchIndexRoutingTable); configuration = new Configuration(watchIndex, ids); } @@ -267,9 +271,9 @@ private void reloadConfiguration(String watchIndex, List localShar * @param state The new cluster state * @return true if the routing tables has changed and local shards are 
affected */ - private boolean hasShardAllocationIdChanged(String watchIndex, ClusterState state) { - List allStartedRelocatedShards = state.getRoutingTable().index(watchIndex).shardsWithState(STARTED); - allStartedRelocatedShards.addAll(state.getRoutingTable().index(watchIndex).shardsWithState(RELOCATING)); + private boolean hasShardAllocationIdChanged(ProjectId projectId, String watchIndex, ClusterState state) { + List allStartedRelocatedShards = state.routingTable(projectId).index(watchIndex).shardsWithState(STARTED); + allStartedRelocatedShards.addAll(state.routingTable(projectId).index(watchIndex).shardsWithState(RELOCATING)); // exit early, when there are shards, but the current configuration is inactive if (allStartedRelocatedShards.isEmpty() == false && configuration == INACTIVE) { diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleService.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleService.java index f202ba46aa832..674e9686235ed 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleService.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleService.java @@ -13,12 +13,15 @@ import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.LifecycleListener; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.gateway.GatewayService; import 
org.elasticsearch.index.shard.ShardId; import org.elasticsearch.xpack.core.watcher.WatcherMetadata; @@ -153,8 +156,11 @@ public void clusterChanged(ClusterChangedEvent event) { // also check if non local shards have changed, as loosing a shard on a // remote node or adding a replica on a remote node needs to trigger a reload too Set localShardIds = localShards.stream().map(ShardRouting::shardId).collect(Collectors.toSet()); - List allShards = event.state().routingTable().index(watchIndex).shardsWithState(STARTED); - allShards.addAll(event.state().routingTable().index(watchIndex).shardsWithState(RELOCATING)); + + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + IndexRoutingTable routingTable = event.state().routingTable(ProjectId.DEFAULT).index(watchIndex); + List allShards = routingTable.shardsWithState(STARTED); + allShards.addAll(routingTable.shardsWithState(RELOCATING)); List localAffectedShardRoutings = allShards.stream() .filter(shardRouting -> localShardIds.contains(shardRouting.shardId())) // shardrouting is not comparable, so we need some order mechanism @@ -192,8 +198,9 @@ private void pauseExecution(String reason) { /** * check if watcher has been stopped manually via the stop API */ + @NotMultiProjectCapable(description = "Watcher is not available in serverless") private static boolean isWatcherStoppedManually(ClusterState state) { - WatcherMetadata watcherMetadata = state.getMetadata().getProject().custom(WatcherMetadata.TYPE); + WatcherMetadata watcherMetadata = state.getMetadata().getProject(ProjectId.DEFAULT).custom(WatcherMetadata.TYPE); return watcherMetadata != null && watcherMetadata.manuallyStopped(); } diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherService.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherService.java index 0ea9b432d3b0f..16f90e15d8afb 100644 --- 
a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherService.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/WatcherService.java @@ -18,6 +18,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.routing.AllocationId; import org.elasticsearch.cluster.routing.Murmur3HashFunction; import org.elasticsearch.cluster.routing.Preference; @@ -27,6 +28,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -134,6 +136,8 @@ public class WatcherService { * @return true if everything is good to go, so that the service can be started */ public boolean validate(ClusterState state) { + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + ProjectId projectId = ProjectId.DEFAULT; IndexMetadata watcherIndexMetadata = WatchStoreUtils.getConcreteIndex(Watch.INDEX, state.metadata()); IndexMetadata triggeredWatchesIndexMetadata = WatchStoreUtils.getConcreteIndex( TriggeredWatchStoreField.INDEX_NAME, @@ -160,7 +164,7 @@ public boolean validate(ClusterState state) { return watcherIndexMetadata == null || (watcherIndexMetadata.getState() == IndexMetadata.State.OPEN - && state.routingTable().index(watcherIndexMetadata.getIndex()).allPrimaryShardsActive()); + && state.routingTable(projectId).index(watcherIndexMetadata.getIndex()).allPrimaryShardsActive()); } catch (IllegalStateException e) { logger.warn("Validation error: cannot start watcher", e); return false; @@ -329,7 +333,8 @@ private Collection 
loadWatches(ClusterState clusterState) { List localShards = routingNode.shardsWithState(watchIndexName, RELOCATING, STARTED).toList(); // find out all allocation ids - List watchIndexShardRoutings = clusterState.getRoutingTable().allShards(watchIndexName); + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + List watchIndexShardRoutings = clusterState.routingTable(ProjectId.DEFAULT).allShards(watchIndexName); SearchRequest searchRequest = new SearchRequest(INDEX).scroll(scrollTimeout) .preference(Preference.ONLY_LOCAL.toString()) diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStore.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStore.java index dfa0c47493ed7..48b03f2c62e12 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStore.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStore.java @@ -24,8 +24,10 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.routing.Preference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.search.SearchHit; @@ -183,10 +185,11 @@ public Collection findTriggeredWatches(Collection watches return triggeredWatches; } + @NotMultiProjectCapable(description = "Watcher is not available in serverless") public static boolean validate(ClusterState state) { IndexMetadata indexMetadata = WatchStoreUtils.getConcreteIndex(TriggeredWatchStoreField.INDEX_NAME, state.metadata()); return indexMetadata == null || (indexMetadata.getState() == 
IndexMetadata.State.OPEN - && state.routingTable().index(indexMetadata.getIndex()).allPrimaryShardsActive()); + && state.routingTable(ProjectId.DEFAULT).index(indexMetadata.getIndex()).allPrimaryShardsActive()); } } diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/history/HistoryStore.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/history/HistoryStore.java index d8ba0c7e7a506..98bf3e7ab40a4 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/history/HistoryStore.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/history/HistoryStore.java @@ -13,9 +13,11 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.watcher.history.HistoryStoreField; @@ -91,10 +93,11 @@ public void forcePut(WatchRecord watchRecord) { * @param state The current cluster state * @return true, if history store is ready to be started */ + @NotMultiProjectCapable(description = "Watcher is not available in serverless") public static boolean validate(ClusterState state) { IndexMetadata indexMetadata = WatchStoreUtils.getConcreteIndex(HistoryStoreField.DATA_STREAM, state.metadata()); return indexMetadata == null || (indexMetadata.getState() == IndexMetadata.State.OPEN - && state.routingTable().index(indexMetadata.getIndex()).allPrimaryShardsActive()); + && state.routingTable(ProjectId.DEFAULT).index(indexMetadata.getIndex()).allPrimaryShardsActive()); } } diff --git 
a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistry.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistry.java index dca1f2bbc56ce..1c5232c7798b8 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistry.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistry.java @@ -9,8 +9,10 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.ilm.LifecyclePolicy; @@ -89,9 +91,10 @@ protected String getOrigin() { return WATCHER_ORIGIN; } + @NotMultiProjectCapable(description = "Watcher is not available in serverless") public static boolean validate(ClusterState state) { return state.getMetadata() - .getProject() + .getProject(ProjectId.DEFAULT) .templatesV2() .keySet() .stream() diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtils.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtils.java index f80eecae2ca8d..4b917e367d14c 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtils.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtils.java @@ -9,6 +9,9 @@ import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import 
org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; @@ -26,7 +29,9 @@ public class WatchStoreUtils { * @throws IndexNotFoundException If no index exists */ public static IndexMetadata getConcreteIndex(String name, Metadata metadata) { - IndexAbstraction indexAbstraction = metadata.getProject().getIndicesLookup().get(name); + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + ProjectMetadata projectMetadata = metadata.getProject(ProjectId.DEFAULT); + IndexAbstraction indexAbstraction = projectMetadata.getIndicesLookup().get(name); if (indexAbstraction == null) { return null; } @@ -48,7 +53,7 @@ public static IndexMetadata getConcreteIndex(String name, Metadata metadata) { if (concreteIndex == null) { concreteIndex = indexAbstraction.getIndices().get(indexAbstraction.getIndices().size() - 1); } - return metadata.getProject().index(concreteIndex); + return projectMetadata.index(concreteIndex); } } diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java index 9a4e315fd1db3..5327e30b98344 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import 
org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; @@ -30,6 +31,7 @@ import org.elasticsearch.cluster.routing.TestShardRouting; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.time.DateUtils; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.core.Strings; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; @@ -87,6 +89,8 @@ public class WatcherIndexingListenerTests extends ESTestCase { private Engine.IndexResult result = mock(Engine.IndexResult.class); private Engine.Index operation = mock(Engine.Index.class); private Engine.Delete delete = mock(Engine.Delete.class); + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private final ProjectId projectId = ProjectId.DEFAULT; @Before public void setup() throws Exception { @@ -324,7 +328,7 @@ public void testClusterChangedWatchAliasChanged() throws Exception { when(routingTable.hasIndex(eq(newActiveWatchIndex))).thenReturn(true); ClusterState currentClusterState = mockClusterState(newActiveWatchIndex); - when(currentClusterState.routingTable()).thenReturn(routingTable); + when(currentClusterState.routingTable(projectId)).thenReturn(routingTable); DiscoveryNodes nodes = DiscoveryNodes.builder().add(newNode("node_1")).localNodeId("node_1").build(); when(currentClusterState.getNodes()).thenReturn(nodes); RoutingNodes routingNodes = mock(RoutingNodes.class); @@ -347,7 +351,7 @@ public void testClusterChangedWatchAliasChanged() throws Exception { when(currentClusterState.getRoutingNodes()).thenReturn(routingNodes); ClusterState previousClusterState = mockClusterState(randomAlphaOfLength(8)); - when(previousClusterState.routingTable()).thenReturn(routingTable); + when(previousClusterState.routingTable(projectId)).thenReturn(routingTable); ClusterChangedEvent event = new ClusterChangedEvent("something", currentClusterState, previousClusterState); 
listener.clusterChanged(event); @@ -364,12 +368,12 @@ public void testClusterChangedNoRoutingChanges() throws Exception { IndexRoutingTable watchRoutingTable = IndexRoutingTable.builder(index).build(); ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) .build(); ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1")).add(newNode("node_2"))) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) .build(); Configuration configuration = listener.getConfiguration(); @@ -492,9 +496,9 @@ public void testOnNonDataNodes() { IndexMetadata.Builder indexMetadataBuilder = createIndexBuilder(Watch.INDEX, 1, 0); ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(indexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2).add(node3)) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) .build(); IndexMetadata.Builder newIndexMetadataBuilder = createIndexBuilder(Watch.INDEX, 1, 1); @@ -502,9 +506,9 @@ public void testOnNonDataNodes() { ShardRouting replicaShardRouting = TestShardRouting.newShardRouting(shardId, "node3", false, STARTED); IndexRoutingTable.Builder newRoutingTable = 
IndexRoutingTable.builder(index).addShard(shardRouting).addShard(replicaShardRouting); ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(newIndexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(newIndexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2).add(node3)) - .routingTable(RoutingTable.builder().add(newRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(newRoutingTable).build()) .build(); ClusterChangedEvent event = new ClusterChangedEvent("something", currentState, previousState); @@ -526,9 +530,9 @@ public void testListenerWorksIfOtherIndicesChange() throws Exception { .addShard(TestShardRouting.newShardRouting(firstShardId, "node_2", false, STARTED)); ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(indexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2)) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) .build(); IndexMetadata.Builder currentMetadataBuilder = createIndexBuilder(Watch.INDEX, 2, 1); @@ -543,9 +547,9 @@ public void testListenerWorksIfOtherIndicesChange() throws Exception { .addShard(TestShardRouting.newShardRouting(watchShardId, "node_2", false, STARTED)); ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(currentMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(currentMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2)) - 
.routingTable(RoutingTable.builder().add(currentRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(currentRoutingTable).build()) .build(); listener.setConfiguration(INACTIVE); @@ -588,9 +592,9 @@ public void testThatShardConfigurationIsNotReloadedNonAffectedShardsChange() { .addShard(secondShardRoutingReplica); ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(indexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId(localNode).add(node1).add(node2).add(node3).add(node4)) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) .build(); ClusterState emptyState = ClusterState.builder(new ClusterName("my-cluster")) @@ -611,9 +615,9 @@ public void testThatShardConfigurationIsNotReloadedNonAffectedShardsChange() { .addShard(secondShardRoutingPrimary); ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(newIndexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(newIndexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId(localNode).add(node1).add(node2).add(node3).add(node4)) - .routingTable(RoutingTable.builder().add(newRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(newRoutingTable).build()) .build(); ClusterChangedEvent nodeGoneEvent = new ClusterChangedEvent("something", currentState, previousState); @@ -636,9 +640,11 @@ public void testWithAliasPointingToTwoIndicesSetsWatcherInactive() { // regular cluster state with correct single alias pointing to watches index ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) - 
.metadata(Metadata.builder().put(createIndexBuilder("foo", 1, 0).putAlias(AliasMetadata.builder(Watch.INDEX)))) + .putProjectMetadata( + ProjectMetadata.builder(projectId).put(createIndexBuilder("foo", 1, 0).putAlias(AliasMetadata.builder(Watch.INDEX))) + ) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1)) - .routingTable(RoutingTable.builder().add(fooIndexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(fooIndexRoutingTable).build()) .build(); // index bar pointing to .watches @@ -650,11 +656,15 @@ public void testWithAliasPointingToTwoIndicesSetsWatcherInactive() { // cluster state with two indices pointing to the .watches index ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata( - Metadata.builder().put(createIndexBuilder("foo", 1, 0).putAlias(AliasMetadata.builder(Watch.INDEX))).put(barIndexMetadata) + .putProjectMetadata( + ProjectMetadata.builder(projectId) + .put(createIndexBuilder("foo", 1, 0).putAlias(AliasMetadata.builder(Watch.INDEX))) + .put(barIndexMetadata) + .build() ) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1)) - .routingTable( + .putRoutingTable( + projectId, RoutingTable.builder().add(IndexRoutingTable.builder(fooIndex).addShard(fooShardRouting)).add(barIndexRoutingTable).build() ) .build(); @@ -699,7 +709,7 @@ public void testThatIndexingListenerBecomesInactiveOnClusterBlock() { private ClusterState mockClusterState(String watchIndex) { Metadata metadata = mock(Metadata.class); ProjectMetadata projectMetadata = mock(ProjectMetadata.class); - when(metadata.getProject()).thenReturn(projectMetadata); + when(metadata.getProject(projectId)).thenReturn(projectMetadata); if (watchIndex == null) { when(projectMetadata.getIndicesLookup()).thenReturn(Collections.emptySortedMap()); } else { diff --git 
a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleServiceTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleServiceTests.java index 40bd6c1adb46f..0915b1a9fa4fb 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleServiceTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherLifeCycleServiceTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.cluster.coordination.NoMasterBlockService; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; -import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; @@ -29,6 +29,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; @@ -73,6 +74,8 @@ public class WatcherLifeCycleServiceTests extends ESTestCase { private WatcherService watcherService; private WatcherLifeCycleService lifeCycleService; + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private ProjectId projectId = ProjectId.DEFAULT; @Before public void prepareServices() { @@ -91,14 +94,18 @@ public void testNoRestartWithoutAllocationIdsConfigured() { IndexRoutingTable indexRoutingTable = IndexRoutingTable.builder(new Index("anything", "foo")).build(); ClusterState previousClusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new 
DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) .build(); IndexRoutingTable watchRoutingTable = IndexRoutingTable.builder(new Index(Watch.INDEX, "foo")).build(); ClusterState clusterState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())).build()) + .putProjectMetadata( + ProjectMetadata.builder(projectId) + .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())) + .build() + ) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) .build(); when(watcherService.validate(clusterState)).thenReturn(true); @@ -126,8 +133,12 @@ public void testShutdown() { IndexRoutingTable watchRoutingTable = IndexRoutingTable.builder(new Index(Watch.INDEX, "foo")).build(); ClusterState clusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) - .metadata(Metadata.builder().put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) + .putProjectMetadata( + ProjectMetadata.builder(projectId) + .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())) + .build() + ) .build(); when(watcherService.validate(clusterState)).thenReturn(true); @@ -152,18 +163,18 @@ public void testManualStartStop() { 
// required .numberOfShards(1) .numberOfReplicas(0); - Metadata.Builder metadataBuilder = Metadata.builder() + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId) .put(indexMetadataBuilder) .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())); if (randomBoolean()) { metadataBuilder.putCustom(WatcherMetadata.TYPE, new WatcherMetadata(false)); } - Metadata metadata = metadataBuilder.build(); + ProjectMetadata metadata = metadataBuilder.build(); IndexRoutingTable indexRoutingTable = indexRoutingTableBuilder.build(); ClusterState clusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) - .metadata(metadata) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) + .putProjectMetadata(metadata) .build(); when(watcherService.validate(clusterState)).thenReturn(true); @@ -171,8 +182,8 @@ public void testManualStartStop() { // mark watcher manually as stopped ClusterState stoppedClusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) - .metadata(Metadata.builder(metadata).putCustom(WatcherMetadata.TYPE, new WatcherMetadata(true)).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(metadata).putCustom(WatcherMetadata.TYPE, new WatcherMetadata(true)).build()) .build(); lifeCycleService.clusterChanged(new ClusterChangedEvent("foo", stoppedClusterState, clusterState)); @@ -210,25 +221,25 @@ public void testExceptionOnStart() { // required .numberOfShards(1) .numberOfReplicas(0); - Metadata.Builder metadataBuilder = Metadata.builder() + 
ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId) .put(indexMetadataBuilder) .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())); if (randomBoolean()) { metadataBuilder.putCustom(WatcherMetadata.TYPE, new WatcherMetadata(false)); } - Metadata metadata = metadataBuilder.build(); + ProjectMetadata metadata = metadataBuilder.build(); IndexRoutingTable indexRoutingTable = indexRoutingTableBuilder.build(); ClusterState clusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) - .metadata(metadata) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) + .putProjectMetadata(metadata) .build(); // mark watcher manually as stopped ClusterState stoppedClusterState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) - .metadata(Metadata.builder(metadata).putCustom(WatcherMetadata.TYPE, new WatcherMetadata(true)).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(metadata).putCustom(WatcherMetadata.TYPE, new WatcherMetadata(true)).build()) .build(); lifeCycleService.clusterChanged(new ClusterChangedEvent("foo", stoppedClusterState, clusterState)); @@ -323,23 +334,22 @@ private ClusterChangedEvent[] masterChangeScenario() { // required .numberOfShards(1) .numberOfReplicas(0); - Metadata metadata = Metadata.builder() + ProjectMetadata metadata = ProjectMetadata.builder(projectId) .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())) .put(indexMetadataBuilder) .build(); - ProjectMetadata project = 
metadata.projects().values().iterator().next(); - GlobalRoutingTable globalRoutingTable = GlobalRoutingTable.builder().put(project.id(), routingTable).build(); + GlobalRoutingTable globalRoutingTable = GlobalRoutingTable.builder().put(metadata.id(), routingTable).build(); - ClusterState emptyState = ClusterState.builder(new ClusterName("my-cluster")).nodes(nodes).metadata(metadata).build(); + ClusterState emptyState = ClusterState.builder(new ClusterName("my-cluster")).nodes(nodes).putProjectMetadata(metadata).build(); ClusterState stateWithMasterNode1 = ClusterState.builder(new ClusterName("my-cluster")) .nodes(nodes.withMasterNodeId("node_1")) - .metadata(metadata) + .putProjectMetadata(metadata) .routingTable(globalRoutingTable) .build(); ClusterState stateWithMasterNode2 = ClusterState.builder(new ClusterName("my-cluster")) .nodes(nodes.withMasterNodeId("node_2")) - .metadata(metadata) + .putProjectMetadata(metadata) .routingTable(globalRoutingTable) .build(); @@ -369,8 +379,8 @@ public void testNoLocalShards() { .build(); ClusterState clusterStateWithLocalShards = ClusterState.builder(new ClusterName("my-cluster")) .nodes(nodes) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); // shard moved over to node 2 @@ -383,8 +393,8 @@ public void testNoLocalShards() { .build(); ClusterState clusterStateWithoutLocalShards = ClusterState.builder(new ClusterName("my-cluster")) .nodes(nodes) - .routingTable(RoutingTable.builder().add(watchRoutingTableNode2).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTableNode2).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); // set current 
allocation ids @@ -431,8 +441,8 @@ public void testReplicaWasAddedOrRemoved() { ClusterState stateWithPrimaryShard = ClusterState.builder(new ClusterName("my-cluster")) .nodes(discoveryNodes) - .routingTable(RoutingTable.builder().add(previousWatchRoutingTable).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, RoutingTable.builder().add(previousWatchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); // add a replica in the local node @@ -452,8 +462,8 @@ public void testReplicaWasAddedOrRemoved() { ClusterState stateWithReplicaAdded = ClusterState.builder(new ClusterName("my-cluster")) .nodes(discoveryNodes) - .routingTable(RoutingTable.builder().add(currentWatchRoutingTable).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, RoutingTable.builder().add(currentWatchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); // randomize between addition or removal of a replica @@ -497,9 +507,9 @@ public void testNonDataNode() { .settings(indexSettings(IndexVersion.current(), 1, 0)); ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(indexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2).add(node3)) - .routingTable(RoutingTable.builder().add(indexRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTable).build()) .build(); IndexMetadata.Builder newIndexMetadataBuilder = IndexMetadata.builder(Watch.INDEX) @@ -508,9 +518,9 @@ public void testNonDataNode() { ShardRouting replicaShardRouting = TestShardRouting.newShardRouting(shardId, "node3", false, STARTED); IndexRoutingTable.Builder newRoutingTable = 
IndexRoutingTable.builder(index).addShard(shardRouting).addShard(replicaShardRouting); ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) - .metadata(Metadata.builder().put(newIndexMetadataBuilder)) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(newIndexMetadataBuilder)) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(node1).add(node2).add(node3)) - .routingTable(RoutingTable.builder().add(newRoutingTable).build()) + .putRoutingTable(projectId, RoutingTable.builder().add(newRoutingTable).build()) .build(); lifeCycleService.clusterChanged(new ClusterChangedEvent("any", currentState, previousState)); @@ -531,8 +541,8 @@ public void testThatMissingWatcherIndexMetadataOnlyResetsOnce() { ClusterState clusterStateWithWatcherIndex = ClusterState.builder(new ClusterName("my-cluster")) .nodes(nodes) - .routingTable(RoutingTable.builder().add(watchRoutingTable).build()) - .metadata(Metadata.builder().put(newIndexMetadataBuilder)) + .putRoutingTable(projectId, RoutingTable.builder().add(watchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(newIndexMetadataBuilder)) .build(); ClusterState clusterStateWithoutWatcherIndex = ClusterState.builder(new ClusterName("my-cluster")).nodes(nodes).build(); @@ -556,12 +566,12 @@ public void testThatMissingWatcherIndexMetadataOnlyResetsOnce() { public void testWatcherServiceDoesNotStartIfIndexTemplatesAreMissing() throws Exception { DiscoveryNodes nodes = new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1")).build(); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); boolean isHistoryTemplateAdded = randomBoolean(); if (isHistoryTemplateAdded) { metadataBuilder.put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())); } - ClusterState state = ClusterState.builder(new 
ClusterName("my-cluster")).nodes(nodes).metadata(metadataBuilder).build(); + ClusterState state = ClusterState.builder(new ClusterName("my-cluster")).nodes(nodes).putProjectMetadata(metadataBuilder).build(); when(watcherService.validate(eq(state))).thenReturn(true); lifeCycleService.clusterChanged(new ClusterChangedEvent("any", state, state)); @@ -602,12 +612,12 @@ public void testMasterOnlyNodeCanStart() { } public void testDataNodeWithoutDataCanStart() { - Metadata metadata = Metadata.builder() + ProjectMetadata metadata = ProjectMetadata.builder(projectId) .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())) .build(); ClusterState state = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .metadata(metadata) + .putProjectMetadata(metadata) .build(); lifeCycleService.clusterChanged(new ClusterChangedEvent("test", state, state)); @@ -641,8 +651,8 @@ public void testWatcherReloadsOnNodeOutageWithWatcherShard() { ClusterState previousState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(previousDiscoveryNodes) - .routingTable(RoutingTable.builder().add(previousWatchRoutingTable).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, RoutingTable.builder().add(previousWatchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); ShardRouting nowPrimaryShardRouting = replicaShardRouting.moveActiveReplicaToPrimary(); @@ -655,8 +665,8 @@ public void testWatcherReloadsOnNodeOutageWithWatcherShard() { ClusterState currentState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(currentDiscoveryNodes) - .routingTable(RoutingTable.builder().add(currentWatchRoutingTable).build()) - .metadata(Metadata.builder().put(indexMetadata, false)) + .putRoutingTable(projectId, 
RoutingTable.builder().add(currentWatchRoutingTable).build()) + .putProjectMetadata(ProjectMetadata.builder(projectId).put(indexMetadata, false)) .build(); // initialize the previous state, so all the allocation ids are loaded @@ -681,18 +691,18 @@ private void startWatcher() { // required .numberOfShards(1) .numberOfReplicas(0); - Metadata metadata = Metadata.builder() + ProjectMetadata metadata = ProjectMetadata.builder(projectId) .put(IndexTemplateMetadata.builder(HISTORY_TEMPLATE_NAME).patterns(randomIndexPatterns())) .put(indexMetadataBuilder) .build(); ClusterState state = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .routingTable(RoutingTable.builder().add(indexRoutingTableBuilder.build()).build()) - .metadata(metadata) + .putRoutingTable(projectId, RoutingTable.builder().add(indexRoutingTableBuilder.build()).build()) + .putProjectMetadata(metadata) .build(); ClusterState emptyState = ClusterState.builder(new ClusterName("my-cluster")) .nodes(new DiscoveryNodes.Builder().masterNodeId("node_1").localNodeId("node_1").add(newNode("node_1"))) - .metadata(metadata) + .putProjectMetadata(metadata) .build(); when(watcherService.validate(state)).thenReturn(true); diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java index 70caeabc4971e..8dd19d925cf87 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import 
org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -38,6 +39,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; @@ -84,6 +86,8 @@ public class WatcherServiceTests extends ESTestCase { private final Client client = mock(Client.class); + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private final ProjectId projectId = ProjectId.DEFAULT; @Before public void configureMockClient() { @@ -113,10 +117,10 @@ void stopExecutor() {} }; ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); Settings indexSettings = indexSettings(IndexVersion.current(), 1, 1).build(); metadataBuilder.put(IndexMetadata.builder(Watch.INDEX).state(IndexMetadata.State.CLOSE).settings(indexSettings)); - csBuilder.metadata(metadataBuilder); + csBuilder.putProjectMetadata(metadataBuilder); assertThat(service.validate(csBuilder.build()), is(false)); } @@ -142,10 +146,10 @@ void stopExecutor() {} // cluster state setup, with one node, one shard ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); Settings indexSettings = indexSettings(IndexVersion.current(), 1, 1).build(); metadataBuilder.put(IndexMetadata.builder(Watch.INDEX).settings(indexSettings)); - csBuilder.metadata(metadataBuilder); + 
csBuilder.putProjectMetadata(metadataBuilder); Index watchIndex = new Index(Watch.INDEX, "uuid"); ShardId shardId = new ShardId(watchIndex, 0); @@ -157,7 +161,7 @@ void stopExecutor() {} ) .build(); RoutingTable routingTable = RoutingTable.builder().add(indexRoutingTable).build(); - csBuilder.routingTable(routingTable); + csBuilder.putRoutingTable(projectId, routingTable); csBuilder.nodes(new DiscoveryNodes.Builder().masterNodeId("node").localNodeId("node").add(newNode())); ClusterState clusterState = csBuilder.build(); @@ -165,7 +169,7 @@ void stopExecutor() {} // response setup, successful refresh response BroadcastResponse refreshResponse = mock(BroadcastResponse.class); when(refreshResponse.getSuccessfulShards()).thenReturn( - clusterState.getMetadata().getProject().indices().get(Watch.INDEX).getNumberOfShards() + clusterState.getMetadata().getProject(ProjectId.DEFAULT).indices().get(Watch.INDEX).getNumberOfShards() ); doAnswer(invocation -> { ActionListener listener = (ActionListener) invocation.getArguments()[2]; @@ -261,10 +265,10 @@ void refreshWatches(IndexMetadata indexMetadata) { }; ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); Settings indexSettings = indexSettings(IndexVersion.current(), 1, 1).build(); metadataBuilder.put(IndexMetadata.builder(Watch.INDEX).settings(indexSettings)); - csBuilder.metadata(metadataBuilder); + csBuilder.putProjectMetadata(metadataBuilder); ClusterState clusterState = csBuilder.build(); AtomicReference exceptionReference = new AtomicReference<>(); @@ -358,7 +362,7 @@ void stopExecutor() {} ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); Metadata metadata = mock(Metadata.class); ProjectMetadata project = mock(ProjectMetadata.class); - when(metadata.getProject()).thenReturn(project); + 
when(metadata.getProject(projectId)).thenReturn(project); // simulate exception in WatcherService's private loadWatches() when(project.getIndicesLookup()).thenThrow(RuntimeException.class); diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java index 1cdb6debfbb80..d734d5816aa7d 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java @@ -29,7 +29,8 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RoutingTable; @@ -40,6 +41,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexVersion; @@ -114,6 +116,8 @@ public void afterBulk(long executionId, BulkRequest request, Exception failure) throw new ElasticsearchException(failure); } }; + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private final ProjectId projectId = ProjectId.DEFAULT; @Before public void init() { @@ -143,7 +147,7 @@ public void testValidateNoActivePrimaryShards() { 
ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("name")); RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); int numShards = 2 + randomInt(2); int numStartedShards = 1; @@ -173,8 +177,8 @@ public void testValidateNoActivePrimaryShards() { } routingTableBuilder.add(indexRoutingTableBuilder.build()); - csBuilder.metadata(metadataBuilder); - csBuilder.routingTable(routingTableBuilder.build()); + csBuilder.putProjectMetadata(metadataBuilder); + csBuilder.putRoutingTable(projectId, routingTableBuilder.build()); ClusterState cs = csBuilder.build(); assertThat(TriggeredWatchStore.validate(cs), is(false)); @@ -184,7 +188,7 @@ public void testFindTriggeredWatchesGoodCase() { ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); metadataBuilder.put(IndexMetadata.builder(TriggeredWatchStoreField.INDEX_NAME).settings(indexSettings)); final Index index = metadataBuilder.get(TriggeredWatchStoreField.INDEX_NAME).getIndex(); IndexRoutingTable.Builder indexRoutingTableBuilder = IndexRoutingTable.builder(index); @@ -196,8 +200,8 @@ public void testFindTriggeredWatchesGoodCase() { ); indexRoutingTableBuilder.addReplica(ShardRouting.Role.DEFAULT); routingTableBuilder.add(indexRoutingTableBuilder.build()); - csBuilder.metadata(metadataBuilder); - csBuilder.routingTable(routingTableBuilder.build()); + csBuilder.putProjectMetadata(metadataBuilder); + csBuilder.putRoutingTable(projectId, routingTableBuilder.build()); ClusterState cs = csBuilder.build(); doAnswer(invocation -> { @@ -298,7 +302,7 @@ public void testLoadStoreAsAlias() { ClusterState.Builder csBuilder = new 
ClusterState.Builder(new ClusterName("_name")); RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); metadataBuilder.put( IndexMetadata.builder("triggered-watches-alias") .settings(indexSettings) @@ -314,8 +318,8 @@ public void testLoadStoreAsAlias() { ); indexRoutingTableBuilder.addReplica(ShardRouting.Role.DEFAULT); routingTableBuilder.add(indexRoutingTableBuilder.build()); - csBuilder.metadata(metadataBuilder); - csBuilder.routingTable(routingTableBuilder.build()); + csBuilder.putProjectMetadata(metadataBuilder); + csBuilder.putRoutingTable(projectId, routingTableBuilder.build()); ClusterState cs = csBuilder.build(); assertThat(TriggeredWatchStore.validate(cs), is(true)); @@ -326,7 +330,7 @@ public void testLoadStoreAsAlias() { public void testLoadingFailsWithTwoAliases() { ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); metadataBuilder.put( IndexMetadata.builder("triggered-watches-alias") @@ -355,8 +359,8 @@ public void testLoadingFailsWithTwoAliases() { ) ); - csBuilder.metadata(metadataBuilder); - csBuilder.routingTable(routingTableBuilder.build()); + csBuilder.putProjectMetadata(metadataBuilder); + csBuilder.putRoutingTable(projectId, routingTableBuilder.build()); ClusterState cs = csBuilder.build(); IllegalStateException e = expectThrows(IllegalStateException.class, () -> TriggeredWatchStore.validate(cs)); @@ -367,11 +371,11 @@ public void testLoadingFailsWithTwoAliases() { public void testTriggeredWatchesIndexIsClosed() { ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder(); + 
ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); metadataBuilder.put( IndexMetadata.builder(TriggeredWatchStoreField.INDEX_NAME).settings(indexSettings).state(IndexMetadata.State.CLOSE) ); - csBuilder.metadata(metadataBuilder); + csBuilder.putProjectMetadata(metadataBuilder); assertThat(TriggeredWatchStore.validate(csBuilder.build()), is(false)); } @@ -387,9 +391,9 @@ public void testTriggeredWatchesIndexDoesNotExistOnStartup() { public void testIndexNotFoundButInMetadata() { ClusterState.Builder csBuilder = new ClusterState.Builder(new ClusterName("_name")); - Metadata.Builder metadataBuilder = Metadata.builder() + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId) .put(IndexMetadata.builder(TriggeredWatchStoreField.INDEX_NAME).settings(indexSettings)); - csBuilder.metadata(metadataBuilder); + csBuilder.putProjectMetadata(metadataBuilder); ClusterState cs = csBuilder.build(); Watch watch = mock(Watch.class); diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistryTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistryTests.java index 6c88380a88c68..c8a18e7fe9792 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistryTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/support/WatcherIndexTemplateRegistryTests.java @@ -21,6 +21,8 @@ import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -28,6 
+30,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -80,6 +83,8 @@ public class WatcherIndexTemplateRegistryTests extends ESTestCase { private ClusterService clusterService; private ThreadPool threadPool; private Client client; + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private final ProjectId projectId = ProjectId.DEFAULT; @SuppressWarnings("unchecked") @Before @@ -335,9 +340,10 @@ private ClusterState createClusterState( return ClusterState.builder(new ClusterName("test")) .metadata( Metadata.builder() - .templates(indexTemplates) .transientSettings(nodeSettings) - .putCustom(IndexLifecycleMetadata.TYPE, ilmMeta) + .put( + ProjectMetadata.builder(projectId).templates(indexTemplates).putCustom(IndexLifecycleMetadata.TYPE, ilmMeta).build() + ) .build() ) .blocks(new ClusterBlocks.Builder().build()) @@ -380,7 +386,7 @@ private ClusterState createClusterState(Map existingTemplates) when(indexTemplate.indexPatterns()).thenReturn(Arrays.asList(generateRandomStringArray(10, 100, false, false))); templates.put(template.getKey(), indexTemplate); } - metadataBuilder.indexTemplates(templates); + metadataBuilder.put(ProjectMetadata.builder(projectId).indexTemplates(templates)); return ClusterState.builder(new ClusterName("foo")).metadata(metadataBuilder.build()).build(); } diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtilsTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtilsTests.java index 8662c28926ac2..acb4be9d3c586 100644 --- 
a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtilsTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/watch/WatchStoreUtilsTests.java @@ -14,8 +14,11 @@ import org.elasticsearch.cluster.metadata.DataStreamTestHelper; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.NotMultiProjectCapable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.ESTestCase; @@ -27,10 +30,12 @@ import java.util.stream.Collectors; public class WatchStoreUtilsTests extends ESTestCase { + @NotMultiProjectCapable(description = "Watcher is not available in serverless") + private final ProjectId projectId = ProjectId.DEFAULT; public void testGetConcreteIndexForDataStream() { String dataStreamName = randomAlphaOfLength(20); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); Map customsBuilder = new HashMap<>(); Map dataStreams = new HashMap<>(); Map indexMetadataMapBuilder = new HashMap<>(); @@ -54,15 +59,18 @@ public void testGetConcreteIndexForDataStream() { dataStreamAliases ); customsBuilder.put(DataStreamMetadata.TYPE, dataStreamMetadata); - metadataBuilder.projectCustoms(customsBuilder); - IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(dataStreamName, metadataBuilder.build()); + metadataBuilder.customs(customsBuilder); + IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex( + dataStreamName, + Metadata.builder().put(metadataBuilder.build()).build() + ); assertNotNull(concreteIndex); assertEquals(indexNames.get(indexNames.size() - 1), 
concreteIndex.getIndex().getName()); } public void testGetConcreteIndexForAliasWithMultipleNonWritableIndices() { String aliasName = randomAlphaOfLength(20); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); AliasMetadata.Builder aliasMetadataBuilder = new AliasMetadata.Builder(aliasName); aliasMetadataBuilder.writeIndex(false); AliasMetadata aliasMetadata = aliasMetadataBuilder.build(); @@ -72,12 +80,15 @@ public void testGetConcreteIndexForAliasWithMultipleNonWritableIndices() { indexMetadataMapBuilder.put(indexName, createIndexMetaData(indexName, aliasMetadata)); } metadataBuilder.indices(indexMetadataMapBuilder); - expectThrows(IllegalStateException.class, () -> WatchStoreUtils.getConcreteIndex(aliasName, metadataBuilder.build())); + expectThrows( + IllegalStateException.class, + () -> WatchStoreUtils.getConcreteIndex(aliasName, Metadata.builder().put(metadataBuilder.build()).build()) + ); } public void testGetConcreteIndexForAliasWithMultipleIndicesWithWritable() { String aliasName = randomAlphaOfLength(20); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); AliasMetadata.Builder aliasMetadataBuilder = new AliasMetadata.Builder(aliasName); aliasMetadataBuilder.writeIndex(false); AliasMetadata nonWritableAliasMetadata = aliasMetadataBuilder.build(); @@ -100,14 +111,14 @@ public void testGetConcreteIndexForAliasWithMultipleIndicesWithWritable() { indexMetadataMapBuilder.put(indexName, createIndexMetaData(indexName, aliasMetadata)); } metadataBuilder.indices(indexMetadataMapBuilder); - IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(aliasName, metadataBuilder.build()); + IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(aliasName, Metadata.builder().put(metadataBuilder.build()).build()); assertNotNull(concreteIndex); assertEquals(indexNames.get(writableIndexIndex), 
concreteIndex.getIndex().getName()); } public void testGetConcreteIndexForAliasWithOneNonWritableIndex() { String aliasName = randomAlphaOfLength(20); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); AliasMetadata.Builder aliasMetadataBuilder = new AliasMetadata.Builder(aliasName); aliasMetadataBuilder.writeIndex(false); AliasMetadata aliasMetadata = aliasMetadataBuilder.build(); @@ -115,18 +126,18 @@ public void testGetConcreteIndexForAliasWithOneNonWritableIndex() { String indexName = aliasName + "_" + 0; indexMetadataMapBuilder.put(indexName, createIndexMetaData(indexName, aliasMetadata)); metadataBuilder.indices(indexMetadataMapBuilder); - IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(aliasName, metadataBuilder.build()); + IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(aliasName, Metadata.builder().put(metadataBuilder.build()).build()); assertNotNull(concreteIndex); assertEquals(indexName, concreteIndex.getIndex().getName()); } public void testGetConcreteIndexForConcreteIndex() { String indexName = randomAlphaOfLength(20); - Metadata.Builder metadataBuilder = Metadata.builder(); + ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(projectId); Map indexMetadataMapBuilder = new HashMap<>(); indexMetadataMapBuilder.put(indexName, createIndexMetaData(indexName, null)); metadataBuilder.indices(indexMetadataMapBuilder); - IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(indexName, metadataBuilder.build()); + IndexMetadata concreteIndex = WatchStoreUtils.getConcreteIndex(indexName, Metadata.builder().put(metadataBuilder.build()).build()); assertNotNull(concreteIndex); assertEquals(indexName, concreteIndex.getIndex().getName()); } diff --git a/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java 
b/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java index 3dbc4c1bf186f..b35d5669b1c3d 100644 --- a/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java +++ b/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java @@ -11,14 +11,11 @@ import org.elasticsearch.example.realm.CustomRealm; import org.elasticsearch.example.realm.CustomRoleMappingRealm; import org.elasticsearch.example.role.CustomInMemoryRolesProvider; -import org.elasticsearch.jdk.RuntimeVersionFeature; import org.elasticsearch.xpack.core.security.SecurityExtension; import org.elasticsearch.xpack.core.security.authc.AuthenticationFailureHandler; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authz.store.RoleRetrievalResult; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -35,17 +32,6 @@ */ public class ExampleSecurityExtension implements SecurityExtension { - static { - final boolean useEntitlements = true; - if (useEntitlements == false && RuntimeVersionFeature.isSecurityManagerAvailable()) { - // check that the extension's policy works. 
- AccessController.doPrivileged((PrivilegedAction) () -> { - System.getSecurityManager().checkPropertyAccess("myproperty"); - return null; - }); - } - } - @Override public String extensionName() { return "example"; diff --git a/x-pack/qa/smoke-test-plugins-ssl/build.gradle b/x-pack/qa/smoke-test-plugins-ssl/build.gradle index cbd837fc2ccf6..03e67bdf0dd4b 100644 --- a/x-pack/qa/smoke-test-plugins-ssl/build.gradle +++ b/x-pack/qa/smoke-test-plugins-ssl/build.gradle @@ -83,6 +83,8 @@ testClusters.matching { it.name == "yamlRestTest" }.configureEach { user username: "test_user", password: "x-pack-test-password" user username: "monitoring_agent", password: "x-pack-test-password", role: "remote_monitoring_agent" + systemProperty 'es.queryable_built_in_roles_enabled', 'false' + pluginPaths.each { pluginPath -> plugin pluginPath } diff --git a/x-pack/test/idp-fixture/src/main/resources/oidc/Dockerfile b/x-pack/test/idp-fixture/src/main/resources/oidc/Dockerfile index 858038d483349..92cd2f46436db 100644 --- a/x-pack/test/idp-fixture/src/main/resources/oidc/Dockerfile +++ b/x-pack/test/idp-fixture/src/main/resources/oidc/Dockerfile @@ -1,5 +1,5 @@ FROM c2id/c2id-server-demo:16.1.1 AS c2id -FROM openjdk:21-jdk-buster +FROM eclipse-temurin:17-noble # Using this to launch a fake server on container start; see `setup.sh` RUN apt-get update -qqy && apt-get install -qqy python3