diff --git a/.buildkite/pipelines/intake.yml b/.buildkite/pipelines/intake.yml index 31f6b6dce5fd2..ca4bdf7dfc951 100644 --- a/.buildkite/pipelines/intake.yml +++ b/.buildkite/pipelines/intake.yml @@ -56,7 +56,7 @@ steps: timeout_in_minutes: 300 matrix: setup: - BWC_VERSION: ["8.17.7", "8.18.2", "8.19.0", "9.0.2", "9.1.0"] + BWC_VERSION: ["8.17.6", "8.18.1", "8.19.0", "9.0.1", "9.1.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.buildkite/pipelines/periodic-packaging.yml b/.buildkite/pipelines/periodic-packaging.yml index e2189ec313374..d9396d66a5507 100644 --- a/.buildkite/pipelines/periodic-packaging.yml +++ b/.buildkite/pipelines/periodic-packaging.yml @@ -303,8 +303,8 @@ steps: env: BWC_VERSION: 8.16.6 - - label: "{{matrix.image}} / 8.17.7 / packaging-tests-upgrade" - command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.17.7 + - label: "{{matrix.image}} / 8.17.6 / packaging-tests-upgrade" + command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.17.6 timeout_in_minutes: 300 matrix: setup: @@ -317,10 +317,10 @@ steps: machineType: custom-16-32768 buildDirectory: /dev/shm/bk env: - BWC_VERSION: 8.17.7 + BWC_VERSION: 8.17.6 - - label: "{{matrix.image}} / 8.18.2 / packaging-tests-upgrade" - command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.18.2 + - label: "{{matrix.image}} / 8.18.1 / packaging-tests-upgrade" + command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.18.1 timeout_in_minutes: 300 matrix: setup: @@ -333,7 +333,7 @@ steps: machineType: custom-16-32768 buildDirectory: /dev/shm/bk env: - BWC_VERSION: 8.18.2 + BWC_VERSION: 8.18.1 - label: "{{matrix.image}} / 8.19.0 / packaging-tests-upgrade" command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.19.0 @@ -351,8 +351,8 @@ steps: env: BWC_VERSION: 8.19.0 - - label: "{{matrix.image}} / 9.0.2 / packaging-tests-upgrade" - command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v9.0.2 + - label: "{{matrix.image}} / 9.0.1 / packaging-tests-upgrade" + command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v9.0.1 timeout_in_minutes: 300 matrix: setup: @@ -365,7 +365,7 @@ steps: machineType: custom-16-32768 buildDirectory: /dev/shm/bk env: - BWC_VERSION: 9.0.2 + BWC_VERSION: 9.0.1 - label: "{{matrix.image}} / 9.1.0 / packaging-tests-upgrade" command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v9.1.0 diff --git a/.buildkite/pipelines/periodic.template.yml b/.buildkite/pipelines/periodic.template.yml index 14f5a8105996b..dac8147df8de3 100644 --- a/.buildkite/pipelines/periodic.template.yml +++ b/.buildkite/pipelines/periodic.template.yml @@ -78,8 +78,8 @@ steps: BWC_VERSION: "{{matrix.BWC_VERSION}}" - group: java-matrix steps: - - label: "{{matrix.ES_RUNTIME_JAVA}} / {{matrix.GRADLE_TASK}} / java-matrix" - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true $$GRADLE_TASK + - label: "{{matrix.ES_RUNTIME_JAVA}} / entitlements={{matrix.ENTITLEMENTS_ENABLED}} / {{matrix.GRADLE_TASK}} / java-matrix" + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true -Dtests.jvm.argline="-Des.entitlements.enabled=$$ENTITLEMENTS_ENABLED" $$GRADLE_TASK timeout_in_minutes: 300 matrix: setup: @@ -94,6 +94,9 @@ steps: - checkPart4 - checkPart5 - checkRestCompat + 
ENTITLEMENTS_ENABLED: + - "true" + - "false" agents: provider: gcp image: family/elasticsearch-ubuntu-2004 @@ -102,6 +105,7 @@ steps: env: ES_RUNTIME_JAVA: "{{matrix.ES_RUNTIME_JAVA}}" GRADLE_TASK: "{{matrix.GRADLE_TASK}}" + ENTITLEMENTS_ENABLED: "{{matrix.ENTITLEMENTS_ENABLED}}" - label: "{{matrix.ES_RUNTIME_JAVA}} / {{matrix.BWC_VERSION}} / java-matrix-bwc" command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v$$BWC_VERSION#bwcTest timeout_in_minutes: 300 diff --git a/.buildkite/pipelines/periodic.yml b/.buildkite/pipelines/periodic.yml index c007452bb4f6d..46b4c6b8d9dba 100644 --- a/.buildkite/pipelines/periodic.yml +++ b/.buildkite/pipelines/periodic.yml @@ -325,8 +325,8 @@ steps: - signal_reason: agent_stop limit: 3 - - label: 8.17.7 / bwc - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.17.7#bwcTest + - label: 8.17.6 / bwc + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.17.6#bwcTest timeout_in_minutes: 300 agents: provider: gcp @@ -335,7 +335,7 @@ steps: buildDirectory: /dev/shm/bk preemptible: true env: - BWC_VERSION: 8.17.7 + BWC_VERSION: 8.17.6 retry: automatic: - exit_status: "-1" @@ -344,8 +344,8 @@ steps: - signal_reason: agent_stop limit: 3 - - label: 8.18.2 / bwc - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.18.2#bwcTest + - label: 8.18.1 / bwc + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.18.1#bwcTest timeout_in_minutes: 300 agents: provider: gcp @@ -354,7 +354,7 @@ steps: buildDirectory: /dev/shm/bk preemptible: true env: - BWC_VERSION: 8.18.2 + BWC_VERSION: 8.18.1 retry: automatic: - exit_status: "-1" @@ -382,8 +382,8 @@ steps: - signal_reason: agent_stop limit: 3 - - label: 9.0.2 / bwc - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v9.0.2#bwcTest + - label: 9.0.1 / bwc + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v9.0.1#bwcTest timeout_in_minutes: 300 agents: provider: gcp @@ -392,7 +392,7 @@ steps: buildDirectory: /dev/shm/bk preemptible: true env: - BWC_VERSION: 9.0.2 + BWC_VERSION: 9.0.1 retry: automatic: - exit_status: "-1" @@ -486,7 +486,7 @@ steps: setup: ES_RUNTIME_JAVA: - openjdk21 - BWC_VERSION: ["8.17.7", "8.18.2", "8.19.0", "9.0.2", "9.1.0"] + BWC_VERSION: ["8.17.6", "8.18.1", "8.19.0", "9.0.1", "9.1.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 @@ -497,8 +497,8 @@ steps: BWC_VERSION: "{{matrix.BWC_VERSION}}" - group: java-matrix steps: - - label: "{{matrix.ES_RUNTIME_JAVA}} / {{matrix.GRADLE_TASK}} / java-matrix" - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true $$GRADLE_TASK + - label: "{{matrix.ES_RUNTIME_JAVA}} / entitlements={{matrix.ENTITLEMENTS_ENABLED}} / {{matrix.GRADLE_TASK}} / java-matrix" + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true -Dtests.jvm.argline="-Des.entitlements.enabled=$$ENTITLEMENTS_ENABLED" $$GRADLE_TASK timeout_in_minutes: 300 matrix: setup: @@ -513,6 +513,9 @@ steps: - checkPart4 - checkPart5 - checkRestCompat + ENTITLEMENTS_ENABLED: + - "true" + - "false" agents: provider: gcp image: family/elasticsearch-ubuntu-2004 @@ -521,6 +524,7 @@ steps: env: ES_RUNTIME_JAVA: "{{matrix.ES_RUNTIME_JAVA}}" GRADLE_TASK: "{{matrix.GRADLE_TASK}}" + ENTITLEMENTS_ENABLED: "{{matrix.ENTITLEMENTS_ENABLED}}" - label: "{{matrix.ES_RUNTIME_JAVA}} / {{matrix.BWC_VERSION}} / java-matrix-bwc" command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v$$BWC_VERSION#bwcTest timeout_in_minutes: 300 @@ -529,7 +533,7 @@ steps: ES_RUNTIME_JAVA: - openjdk21 - openjdk23 - BWC_VERSION: ["8.17.7", 
"8.18.2", "8.19.0", "9.0.2", "9.1.0"] + BWC_VERSION: ["8.17.6", "8.18.1", "8.19.0", "9.0.1", "9.1.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.buildkite/pipelines/pull-request/part-1-entitlements.yml b/.buildkite/pipelines/pull-request/part-1-entitlements.yml new file mode 100644 index 0000000000000..abb9edf67484f --- /dev/null +++ b/.buildkite/pipelines/pull-request/part-1-entitlements.yml @@ -0,0 +1,11 @@ +config: + allow-labels: "test-entitlements" +steps: + - label: part-1-entitlements + command: .ci/scripts/run-gradle.sh -Dignore.tests.seed -Dtests.jvm.argline="-Des.entitlements.enabled=true" checkPart1 + timeout_in_minutes: 300 + agents: + provider: gcp + image: family/elasticsearch-ubuntu-2004 + machineType: custom-32-98304 + buildDirectory: /dev/shm/bk diff --git a/.buildkite/pipelines/pull-request/part-2-entitlements.yml b/.buildkite/pipelines/pull-request/part-2-entitlements.yml new file mode 100644 index 0000000000000..ef889f3819c5b --- /dev/null +++ b/.buildkite/pipelines/pull-request/part-2-entitlements.yml @@ -0,0 +1,11 @@ +config: + allow-labels: "test-entitlements" +steps: + - label: part-2-entitlements + command: .ci/scripts/run-gradle.sh -Dignore.tests.seed -Dtests.jvm.argline="-Des.entitlements.enabled=true" checkPart2 + timeout_in_minutes: 300 + agents: + provider: gcp + image: family/elasticsearch-ubuntu-2004 + machineType: custom-32-98304 + buildDirectory: /dev/shm/bk diff --git a/.buildkite/pipelines/pull-request/part-3-entitlements.yml b/.buildkite/pipelines/pull-request/part-3-entitlements.yml new file mode 100644 index 0000000000000..c31ae5e6a4ce3 --- /dev/null +++ b/.buildkite/pipelines/pull-request/part-3-entitlements.yml @@ -0,0 +1,11 @@ +config: + allow-labels: "test-entitlements" +steps: + - label: part-3-entitlements + command: .ci/scripts/run-gradle.sh -Dignore.tests.seed -Dtests.jvm.argline="-Des.entitlements.enabled=true" checkPart3 + timeout_in_minutes: 300 + agents: + provider: gcp + image: family/elasticsearch-ubuntu-2004 + machineType: custom-32-98304 + buildDirectory: /dev/shm/bk diff --git a/.buildkite/pipelines/pull-request/part-4-entitlements.yml b/.buildkite/pipelines/pull-request/part-4-entitlements.yml new file mode 100644 index 0000000000000..67172f891b4b6 --- /dev/null +++ b/.buildkite/pipelines/pull-request/part-4-entitlements.yml @@ -0,0 +1,11 @@ +config: + allow-labels: "test-entitlements" +steps: + - label: part-4-entitlements + command: .ci/scripts/run-gradle.sh -Dignore.tests.seed -Dtests.jvm.argline="-Des.entitlements.enabled=true" checkPart4 + timeout_in_minutes: 300 + agents: + provider: gcp + image: family/elasticsearch-ubuntu-2004 + machineType: n1-standard-32 + buildDirectory: /dev/shm/bk diff --git a/.buildkite/pipelines/pull-request/part-5-entitlements.yml b/.buildkite/pipelines/pull-request/part-5-entitlements.yml new file mode 100644 index 0000000000000..5a92282361576 --- /dev/null +++ b/.buildkite/pipelines/pull-request/part-5-entitlements.yml @@ -0,0 +1,11 @@ +config: + allow-labels: "test-entitlements" +steps: + - label: part-5-entitlements + command: .ci/scripts/run-gradle.sh -Dignore.tests.seed -Dtests.jvm.argline="-Des.entitlements.enabled=true" checkPart5 + timeout_in_minutes: 300 + agents: + provider: gcp + image: family/elasticsearch-ubuntu-2004 + machineType: custom-32-98304 + buildDirectory: /dev/shm/bk diff --git a/.ci/bwcVersions b/.ci/bwcVersions index 2cd0d1da12f2e..f4130faa228e5 100644 --- a/.ci/bwcVersions +++ b/.ci/bwcVersions @@ -16,8 +16,8 @@ BWC_VERSION: - "8.14.3" - 
"8.15.5" - "8.16.6" - - "8.17.7" - - "8.18.2" + - "8.17.6" + - "8.18.1" - "8.19.0" - - "9.0.2" + - "9.0.1" - "9.1.0" diff --git a/.ci/snapshotBwcVersions b/.ci/snapshotBwcVersions index 3f69c8fff9457..286c69b082017 100644 --- a/.ci/snapshotBwcVersions +++ b/.ci/snapshotBwcVersions @@ -1,6 +1,6 @@ BWC_VERSION: - - "8.17.7" - - "8.18.2" + - "8.17.6" + - "8.18.1" - "8.19.0" - - "9.0.2" + - "9.0.1" - "9.1.0" diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java index 6ed1294e16299..afac7204f110b 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java @@ -123,7 +123,7 @@ private LogicalPlan plan(String query) { } @Benchmark - public void manyFields(Blackhole blackhole) { + public void run(Blackhole blackhole) { blackhole.consume(plan("FROM test | LIMIT 10")); } } diff --git a/build-tools-internal/gradle/wrapper/gradle-wrapper.properties b/build-tools-internal/gradle/wrapper/gradle-wrapper.properties index f373f37ad8290..2a6e21b2ba89a 100644 --- a/build-tools-internal/gradle/wrapper/gradle-wrapper.properties +++ b/build-tools-internal/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=efe9a3d147d948d7528a9887fa35abcf24ca1a43ad06439996490f77569b02d1 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-all.zip +distributionSha256Sum=fba8464465835e74f7270bbf43d6d8a8d7709ab0a43ce1aa3323f73e9aa0c612 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionModuleCheckTaskProvider.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionModuleCheckTaskProvider.java index 0b71460e8d92b..92a8db6b5b913 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionModuleCheckTaskProvider.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionModuleCheckTaskProvider.java @@ -60,6 +60,7 @@ public class InternalDistributionModuleCheckTaskProvider { "org.elasticsearch.nativeaccess", "org.elasticsearch.plugin", "org.elasticsearch.plugin.analysis", + "org.elasticsearch.securesm", "org.elasticsearch.server", "org.elasticsearch.simdvec", "org.elasticsearch.tdigest", diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestApiTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestApiTask.java index aa932a717858f..1617f317d52c9 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestApiTask.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestApiTask.java @@ -25,7 +25,8 @@ import org.gradle.api.tasks.SkipWhenEmpty; import org.gradle.api.tasks.TaskAction; import org.gradle.api.tasks.util.PatternFilterable; -import org.gradle.api.tasks.util.internal.PatternSetFactory; +import org.gradle.api.tasks.util.PatternSet; +import org.gradle.internal.Factory; import java.io.File; import java.io.IOException; @@ -64,14 +65,14 @@ public class CopyRestApiTask 
extends DefaultTask { @Inject public CopyRestApiTask( ProjectLayout projectLayout, - PatternSetFactory patternSetFactory, + Factory<PatternSet> patternSetFactory, FileSystemOperations fileSystemOperations, ObjectFactory objectFactory ) { this.include = objectFactory.listProperty(String.class); this.outputResourceDir = objectFactory.directoryProperty(); this.additionalYamlTestsDir = objectFactory.directoryProperty(); - this.patternSet = patternSetFactory.createPatternSet(); + this.patternSet = patternSetFactory.create(); this.projectLayout = projectLayout; this.fileSystemOperations = fileSystemOperations; } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java index d6e888e33a3d5..6890cfb652952 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/CopyRestTestsTask.java @@ -29,7 +29,8 @@ import org.gradle.api.tasks.SkipWhenEmpty; import org.gradle.api.tasks.TaskAction; import org.gradle.api.tasks.util.PatternFilterable; -import org.gradle.api.tasks.util.internal.PatternSetFactory; +import org.gradle.api.tasks.util.PatternSet; +import org.gradle.internal.Factory; import java.io.File; import java.util.Map; @@ -65,25 +66,25 @@ public abstract class CopyRestTestsTask extends DefaultTask { private final ProjectLayout projectLayout; private final FileSystemOperations fileSystemOperations; + @Inject + public abstract FileOperations getFileOperations(); + @Inject public CopyRestTestsTask( ProjectLayout projectLayout, - PatternSetFactory patternSetFactory, + Factory<PatternSet> patternSetFactory, FileSystemOperations fileSystemOperations, ObjectFactory objectFactory ) { this.includeCore = objectFactory.listProperty(String.class); this.includeXpack = objectFactory.listProperty(String.class); this.outputResourceDir = objectFactory.directoryProperty(); - this.corePatternSet = patternSetFactory.createPatternSet(); - this.xpackPatternSet = patternSetFactory.createPatternSet(); + this.corePatternSet = patternSetFactory.create(); + this.xpackPatternSet = patternSetFactory.create(); this.projectLayout = projectLayout; this.fileSystemOperations = fileSystemOperations; } - @Inject - public abstract FileOperations getFileOperations(); - @Input public ListProperty<String> getIncludeCore() { return includeCore; } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java index 6d6590429feb1..ba242a8e23861 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/test/rest/compat/compat/RestCompatTestTransformTask.java @@ -58,7 +58,7 @@ import org.gradle.api.tasks.TaskAction; import org.gradle.api.tasks.util.PatternFilterable; import org.gradle.api.tasks.util.PatternSet; -import org.gradle.api.tasks.util.internal.PatternSetFactory; +import org.gradle.internal.Factory; import java.io.File; import java.io.IOException; @@ -98,13 +98,18 @@ public abstract class RestCompatTestTransformTask extends DefaultTask { // PatternFilterable -> list of full test names and reasons.
Needed for 1 pattern may include many tests and reasons private final Map<PatternFilterable, List<Pair<String, String>>> skippedTestByTestNameTransformations = new HashMap<>(); + @Inject + protected Factory<PatternSet> getPatternSetFactory() { + throw new UnsupportedOperationException(); + } + @Inject public RestCompatTestTransformTask(FileSystemOperations fileSystemOperations, ObjectFactory objectFactory) { this.fileSystemOperations = fileSystemOperations; this.compatibleVersion = Version.fromString(VersionProperties.getVersions().get("elasticsearch")).getMajor() - 1; this.sourceDirectory = objectFactory.directoryProperty(); this.outputDirectory = objectFactory.directoryProperty(); - this.testPatternSet = getPatternSetFactory().createPatternSet(); + this.testPatternSet = getPatternSetFactory().create(); this.testPatternSet.include("/*" + "*/*.yml"); // concat these strings to keep build from thinking this is invalid javadoc // always inject compat headers headers.put("Content-Type", "application/vnd.elasticsearch+json;compatible-with=" + compatibleVersion); @@ -112,9 +117,6 @@ public RestCompatTestTransformTask(FileSystemOperations fileSystemOperations, Ob getTransformations().add(new InjectHeaders(headers, Sets.newHashSet(RestCompatTestTransformTask::doesNotHaveCatOperation))); } - @Inject - protected abstract PatternSetFactory getPatternSetFactory(); - private static boolean doesNotHaveCatOperation(ObjectNode doNodeValue) { final Iterator<String> fieldNamesIterator = doNodeValue.fieldNames(); while (fieldNamesIterator.hasNext()) { @@ -142,7 +144,7 @@ public void skipTest(String fullTestName, String reason) { ); } - PatternSet skippedPatternSet = getPatternSetFactory().createPatternSet(); + PatternSet skippedPatternSet = getPatternSetFactory().create(); // create file patterns for all a1/a2/a3/b.yml possibilities.
for (int i = testParts.length - 1; i > 1; i--) { final String lastPart = testParts[i]; @@ -156,7 +158,7 @@ public void skipTest(String fullTestName, String reason) { } public void skipTestsByFilePattern(String filePattern, String reason) { - PatternSet skippedPatternSet = getPatternSetFactory().createPatternSet(); + PatternSet skippedPatternSet = getPatternSetFactory().create(); skippedPatternSet.include(filePattern); skippedTestByFilePatternTransformations.put(skippedPatternSet, reason); } diff --git a/build-tools-internal/src/main/resources/checkstyle.xml b/build-tools-internal/src/main/resources/checkstyle.xml index d50af511ebeac..9ed31d993909e 100644 --- a/build-tools-internal/src/main/resources/checkstyle.xml +++ b/build-tools-internal/src/main/resources/checkstyle.xml @@ -68,11 +68,6 @@ - - - - - diff --git a/build-tools-internal/src/main/resources/minimumGradleVersion b/build-tools-internal/src/main/resources/minimumGradleVersion index b9d71048250a3..4e28b0862495c 100644 --- a/build-tools-internal/src/main/resources/minimumGradleVersion +++ b/build-tools-internal/src/main/resources/minimumGradleVersion @@ -1 +1 @@ -8.14 \ No newline at end of file +8.12.1 \ No newline at end of file diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/plugin/BasePluginBuildPlugin.java b/build-tools/src/main/java/org/elasticsearch/gradle/plugin/BasePluginBuildPlugin.java index 9e20ce64ed88e..a1c003c4c315d 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/plugin/BasePluginBuildPlugin.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/plugin/BasePluginBuildPlugin.java @@ -183,7 +183,11 @@ private static CopySpec createBundleSpec( ) { var bundleSpec = project.copySpec(); bundleSpec.from(buildProperties); - bundleSpec.from(pluginMetadata); + bundleSpec.from(pluginMetadata, copySpec -> { + // metadata (eg custom security policy) + // the codebases properties file is only for tests and not needed in production + copySpec.exclude("plugin-security.codebases"); + }); bundleSpec.from( (Callable<TaskProvider<Task>>) () -> project.getPluginManager().hasPlugin("com.gradleup.shadow") ?
project.getTasks().named("shadowJar") diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/test/SystemPropertyCommandLineArgumentProvider.java b/build-tools/src/main/java/org/elasticsearch/gradle/test/SystemPropertyCommandLineArgumentProvider.java index 02146ee454d3c..70be689ca637f 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/test/SystemPropertyCommandLineArgumentProvider.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/test/SystemPropertyCommandLineArgumentProvider.java @@ -8,6 +8,7 @@ */ package org.elasticsearch.gradle.test; +import org.gradle.api.provider.Provider; import org.gradle.api.tasks.Input; import org.gradle.process.CommandLineArgumentProvider; @@ -19,6 +20,10 @@ public class SystemPropertyCommandLineArgumentProvider implements CommandLineArgumentProvider { private final Map<String, Supplier<String>> systemProperties = new LinkedHashMap<>(); + public void systemProperty(String key, Provider<Object> value) { + systemProperties.put(key, (Supplier<String>) () -> String.valueOf(value.get())); + } + public void systemProperty(String key, Supplier<String> value) { systemProperties.put(key, value); } diff --git a/distribution/archives/integ-test-zip/src/javaRestTest/resources/plugin-security.policy b/distribution/archives/integ-test-zip/src/javaRestTest/resources/plugin-security.policy new file mode 100644 index 0000000000000..f0cb0d58d3c1a --- /dev/null +++ b/distribution/archives/integ-test-zip/src/javaRestTest/resources/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // Needed to read the log file + permission java.io.FilePermission "@tests.logfile@", "read"; +}; diff --git a/distribution/docker/src/docker/Dockerfile.ess-fips b/distribution/docker/src/docker/Dockerfile.ess-fips index 27f03a40a056f..58ecf45820d63 100644 --- a/distribution/docker/src/docker/Dockerfile.ess-fips +++ b/distribution/docker/src/docker/Dockerfile.ess-fips @@ -96,6 +96,15 @@ COPY fips/resources/fips_java.policy /usr/share/elasticsearch/config/fips_java.p WORKDIR /usr/share/elasticsearch/config +## Add fips specific JVM options +RUN cat <<EOF > /usr/share/elasticsearch/config/jvm.options.d/fips.options +-Djavax.net.ssl.keyStoreType=BCFKS +-Dorg.bouncycastle.fips.approved_only=true +-Djava.security.properties=config/fips_java.security +-Djava.security.policy=config/fips_java.policy +EOF + + ################################################################################ # Build stage 2 (the actual Elasticsearch image): # @@ -127,10 +136,6 @@ ENV ELASTIC_CONTAINER=true WORKDIR /usr/share/elasticsearch COPY --from=builder --chown=0:0 /usr/share/elasticsearch /usr/share/elasticsearch -COPY --from=builder --chown=0:0 /fips/libs/*.jar /usr/share/elasticsearch/lib/ -COPY --from=builder --chown=0:0 /opt /opt - -ENV ES_PLUGIN_ARCHIVE_DIR=/opt/plugins/archive ENV PATH=/usr/share/elasticsearch/bin:\$PATH ENV SHELL=/bin/bash COPY ${bin_dir}/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh @@ -154,28 +159,6 @@ RUN chmod g=u /etc/passwd && \\ RUN ln -sf /etc/ssl/certs/java/cacerts /usr/share/elasticsearch/jdk/lib/security/cacerts -# Convert cacerts (PKCS12) to BCFKS format using POSIX-compatible shell syntax -RUN printf "\\n" | jdk/bin/keytool -importkeystore \ - -srckeystore /usr/share/elasticsearch/jdk/lib/security/cacerts \ - -srcstoretype PKCS12 \ - -destkeystore config/cacerts.bcfks \ - -deststorepass passwordcacert \ - -deststoretype BCFKS \ - -providerclass org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider \ - -providerpath lib/bc-fips-1.0.2.5.jar \ - -destprovidername BCFIPS - - ## Add fips specific JVM
options -RUN cat <<EOF > /usr/share/elasticsearch/config/jvm.options.d/fips.options --Djavax.net.ssl.keyStoreType=BCFKS --Dorg.bouncycastle.fips.approved_only=true --Djava.security.properties=config/fips_java.security --Djava.security.policy=config/fips_java.policy --Djavax.net.ssl.trustStore=config/cacerts.bcfks --Djavax.net.ssl.trustStorePassword=passwordcacert -EOF - EXPOSE 9200 9300 LABEL org.label-schema.build-date="${build_date}" \\ @@ -208,16 +191,16 @@ LABEL name="Elasticsearch" \\ RUN mkdir /licenses && ln LICENSE.txt /licenses/LICENSE -# Generate a stub command that will be overwritten at runtime -RUN mkdir /app && \\ - echo -e '#!/bin/bash\\nexec /usr/local/bin/docker-entrypoint.sh eswrapper' > /app/elasticsearch.sh && \\ - chmod 0555 /app/elasticsearch.sh - ENTRYPOINT ["/sbin/tini", "--"] CMD ["/app/elasticsearch.sh"] USER 1000:0 +COPY --from=builder --chown=0:0 /opt /opt +ENV ES_PLUGIN_ARCHIVE_DIR=/opt/plugins/archive +WORKDIR /usr/share/elasticsearch +COPY --from=builder --chown=0:0 /fips/libs/*.jar /usr/share/elasticsearch/lib/ + ################################################################################ # End of multi-stage Dockerfile ################################################################################ diff --git a/distribution/tools/plugin-cli/src/main/java/org/elasticsearch/plugins/cli/InstallPluginAction.java b/distribution/tools/plugin-cli/src/main/java/org/elasticsearch/plugins/cli/InstallPluginAction.java index 2798b3353259b..0733fce0f5c77 100644 --- a/distribution/tools/plugin-cli/src/main/java/org/elasticsearch/plugins/cli/InstallPluginAction.java +++ b/distribution/tools/plugin-cli/src/main/java/org/elasticsearch/plugins/cli/InstallPluginAction.java @@ -922,7 +922,7 @@ void jarHellCheck(PluginDescriptor candidateInfo, Path candidateDir, Path plugin private PluginDescriptor installPlugin(InstallablePlugin descriptor, Path tmpRoot, List<Path> deleteOnFailure) throws Exception { final PluginDescriptor info = loadPluginInfo(tmpRoot); - Path legacyPolicyFile = tmpRoot.resolve("plugin-security.policy"); + Path legacyPolicyFile = tmpRoot.resolve(PluginDescriptor.ES_PLUGIN_POLICY); if (Files.exists(legacyPolicyFile)) { terminal.errorPrintln( "WARNING: this plugin contains a legacy Security Policy file.
Starting with version 8.18, " diff --git a/docs/changelog/120488.yaml b/docs/changelog/120488.yaml deleted file mode 100644 index 8d5b07ad21634..0000000000000 --- a/docs/changelog/120488.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 120488 -summary: Publish queue latency metrics from tracked thread pools -area: "Infra/Metrics" -type: enhancement -issues: [] diff --git a/docs/changelog/125679.yaml b/docs/changelog/125679.yaml new file mode 100644 index 0000000000000..401d25317e096 --- /dev/null +++ b/docs/changelog/125679.yaml @@ -0,0 +1,5 @@ +pr: 125679 +summary: Adding support for generic Inference services +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/125694.yaml b/docs/changelog/125694.yaml new file mode 100644 index 0000000000000..c4c7a622dbdf6 --- /dev/null +++ b/docs/changelog/125694.yaml @@ -0,0 +1,5 @@ +pr: 125694 +summary: LTR score bounding +area: Ranking +type: bug +issues: [] diff --git a/docs/changelog/125922.yaml b/docs/changelog/125922.yaml new file mode 100644 index 0000000000000..9cbf0e8ef4e82 --- /dev/null +++ b/docs/changelog/125922.yaml @@ -0,0 +1,5 @@ +pr: 125922 +summary: Fix text structure NPE when fields in list have null value +area: Machine Learning +type: bug +issues: [] diff --git a/docs/changelog/126273.yaml b/docs/changelog/126273.yaml new file mode 100644 index 0000000000000..420c0eb317a03 --- /dev/null +++ b/docs/changelog/126273.yaml @@ -0,0 +1,5 @@ +pr: 126273 +summary: Fix LTR rescorer with model alias +area: Ranking +type: bug +issues: [] diff --git a/docs/changelog/126310.yaml b/docs/changelog/126310.yaml new file mode 100644 index 0000000000000..a419a1036bd67 --- /dev/null +++ b/docs/changelog/126310.yaml @@ -0,0 +1,6 @@ +pr: 126310 +summary: Add Issuer to failed SAML Signature validation logs when available +area: Security +type: enhancement +issues: + - 111022 diff --git a/docs/changelog/126342.yaml b/docs/changelog/126342.yaml new file mode 100644 index 0000000000000..b594deec97de5 --- /dev/null +++ b/docs/changelog/126342.yaml @@ -0,0 +1,5 @@ +pr: 126342 +summary: Enable sort optimization on float and `half_float` +area: Search +type: enhancement +issues: [] diff --git a/docs/changelog/126583.yaml b/docs/changelog/126583.yaml new file mode 100644 index 0000000000000..a6732b7936f8a --- /dev/null +++ b/docs/changelog/126583.yaml @@ -0,0 +1,5 @@ +pr: 126583 +summary: Cancel expired async search task when a remote returns its results +area: CCS +type: bug +issues: [] diff --git a/docs/changelog/126605.yaml b/docs/changelog/126605.yaml new file mode 100644 index 0000000000000..44031f5d51616 --- /dev/null +++ b/docs/changelog/126605.yaml @@ -0,0 +1,5 @@ +pr: 126605 +summary: Fix equality bug in `WaitForIndexColorStep` +area: ILM+SLM +type: bug +issues: [] diff --git a/docs/changelog/126612.yaml b/docs/changelog/126612.yaml deleted file mode 100644 index e8fd1825bfc2d..0000000000000 --- a/docs/changelog/126612.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 126612 -summary: Add Support for Providing a custom `ServiceAccountTokenStore` through `SecurityExtensions` -area: Authentication -type: enhancement -issues: [] diff --git a/docs/changelog/126614.yaml b/docs/changelog/126614.yaml new file mode 100644 index 0000000000000..e8424c8c78245 --- /dev/null +++ b/docs/changelog/126614.yaml @@ -0,0 +1,5 @@ +pr: 126614 +summary: Fix join masking eval +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/126637.yaml b/docs/changelog/126637.yaml new file mode 100644 index 0000000000000..6b51566457bfc --- /dev/null +++ 
b/docs/changelog/126637.yaml @@ -0,0 +1,5 @@ +pr: 126637 +summary: Improve resiliency of `UpdateTimeSeriesRangeService` +area: TSDB +type: bug +issues: [] diff --git a/docs/changelog/126686.yaml b/docs/changelog/126686.yaml new file mode 100644 index 0000000000000..802ec538e5c1e --- /dev/null +++ b/docs/changelog/126686.yaml @@ -0,0 +1,6 @@ +pr: 126686 +summary: Fix race condition in `RestCancellableNodeClient` +area: Task Management +type: bug +issues: + - 88201 diff --git a/docs/changelog/126729.yaml b/docs/changelog/126729.yaml new file mode 100644 index 0000000000000..0a5e296e4c250 --- /dev/null +++ b/docs/changelog/126729.yaml @@ -0,0 +1,6 @@ +pr: 126729 +summary: Use terminal reader in keystore add command +area: Infra/CLI +type: bug +issues: + - 98115 diff --git a/docs/changelog/126778.yaml b/docs/changelog/126778.yaml new file mode 100644 index 0000000000000..c695e24ba3c84 --- /dev/null +++ b/docs/changelog/126778.yaml @@ -0,0 +1,5 @@ +pr: 126778 +summary: Fix bbq quantization algorithm but for differently distributed components +area: Vector Search +type: bug +issues: [] diff --git a/docs/changelog/126783.yaml b/docs/changelog/126783.yaml new file mode 100644 index 0000000000000..ac91c7cfd412b --- /dev/null +++ b/docs/changelog/126783.yaml @@ -0,0 +1,6 @@ +pr: 126783 +summary: Fix shard size of initializing restored shard +area: Allocation +type: bug +issues: + - 105331 diff --git a/docs/changelog/126805.yaml b/docs/changelog/126805.yaml deleted file mode 100644 index 9051f775f698d..0000000000000 --- a/docs/changelog/126805.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 126805 -summary: Adding timeout to request for creating inference endpoint -area: Machine Learning -type: bug -issues: [] diff --git a/docs/changelog/126806.yaml b/docs/changelog/126806.yaml new file mode 100644 index 0000000000000..cdc9d97d750cc --- /dev/null +++ b/docs/changelog/126806.yaml @@ -0,0 +1,5 @@ +pr: 126806 +summary: Workaround max name limit imposed by Jackson 2.17 +area: Infra/Core +type: bug +issues: [] diff --git a/docs/changelog/126843.yaml b/docs/changelog/126843.yaml index 1497f75b2ea74..77d3916c31955 100644 --- a/docs/changelog/126843.yaml +++ b/docs/changelog/126843.yaml @@ -42,7 +42,8 @@ breaking: `com.amazonaws.sdk.ec2MetadataServiceEndpointOverride` system property. * AWS SDK v2 does not permit specifying a choice between HTTP and HTTPS so - the `s3.client.${CLIENT_NAME}.protocol` setting is deprecated. + the `s3.client.${CLIENT_NAME}.protocol` setting is deprecated and no longer + has any effect. * AWS SDK v2 does not permit control over throttling for retries, so the the `s3.client.${CLIENT_NAME}.use_throttle_retries` setting is deprecated @@ -80,9 +81,9 @@ breaking: * If applicable, discontinue use of the `com.amazonaws.sdk.ec2MetadataServiceEndpointOverride` system property. - * If applicable, specify the protocol to use to access the S3 API by - setting `s3.client.${CLIENT_NAME}.endpoint` to a URL which starts with - `http://` or `https://`. + * If applicable, specify that you wish to use the insecure HTTP protocol to + access the S3 API by setting `s3.client.${CLIENT_NAME}.endpoint` to a URL + which starts with `http://`. * If applicable, discontinue use of the `log-delivery-write` canned ACL. 
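For illustration, here is a minimal `elasticsearch.yml` sketch of the endpoint-based protocol selection described in the `docs/changelog/126843.yaml` hunk above; the client name `default` and the endpoint URLs are assumptions for the example, not part of this change:

```yaml
# The URL scheme of the endpoint now selects HTTP vs HTTPS;
# the deprecated s3.client.${CLIENT_NAME}.protocol setting no longer has any effect.
# "default" is a hypothetical client name - substitute your own.
s3.client.default.endpoint: "https://s3.us-east-1.amazonaws.com"

# Insecure HTTP, e.g. for a local S3-compatible test fixture (hypothetical address):
# s3.client.default.endpoint: "http://127.0.0.1:9000"
```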
diff --git a/docs/changelog/126850.yaml b/docs/changelog/126850.yaml new file mode 100644 index 0000000000000..852d4657b15e6 --- /dev/null +++ b/docs/changelog/126850.yaml @@ -0,0 +1,5 @@ +pr: 126850 +summary: "[otel-data] Bump plugin version to release _metric_names_hash changes" +area: Data streams +type: bug +issues: [] diff --git a/docs/changelog/126852.yaml b/docs/changelog/126852.yaml new file mode 100644 index 0000000000000..e2fe44b24ed69 --- /dev/null +++ b/docs/changelog/126852.yaml @@ -0,0 +1,5 @@ +pr: 126852 +summary: "Validation checks on paths allowed for 'files' entitlements. Restrict the paths we allow access to, forbidding plugins to specify/request entitlements for reading or writing to specific protected directories." +area: Infra/Core +type: enhancement +issues: [] diff --git a/docs/changelog/126858.yaml b/docs/changelog/126858.yaml new file mode 100644 index 0000000000000..d1ea2ebba73ef --- /dev/null +++ b/docs/changelog/126858.yaml @@ -0,0 +1,6 @@ +pr: 126858 +summary: Leverage threadpool schedule for inference api to avoid long running thread +area: Machine Learning +type: bug +issues: + - 126853 diff --git a/docs/changelog/126876.yaml b/docs/changelog/126876.yaml deleted file mode 100644 index 895af10840d84..0000000000000 --- a/docs/changelog/126876.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 126876 -summary: Improve HNSW filtered search speed through new heuristic -area: Vector Search -type: enhancement -issues: [] diff --git a/docs/changelog/126884.yaml b/docs/changelog/126884.yaml new file mode 100644 index 0000000000000..7ad905c93bf85 --- /dev/null +++ b/docs/changelog/126884.yaml @@ -0,0 +1,5 @@ +pr: 126884 +summary: Rare terms aggregation false **positive** fix +area: Aggregations +type: bug +issues: [] diff --git a/docs/changelog/126889.yaml b/docs/changelog/126889.yaml new file mode 100644 index 0000000000000..33d15e3f124ac --- /dev/null +++ b/docs/changelog/126889.yaml @@ -0,0 +1,6 @@ +pr: 126889 +summary: Rework uniquify to not use iterators +area: Infra/Core +type: bug +issues: + - 126883 diff --git a/docs/changelog/126911.yaml b/docs/changelog/126911.yaml new file mode 100644 index 0000000000000..040d1dff767bf --- /dev/null +++ b/docs/changelog/126911.yaml @@ -0,0 +1,6 @@ +pr: 126911 +summary: Fix `vec_caps` to test for OS support too (on x64) +area: Vector Search +type: bug +issues: + - 126809 diff --git a/docs/changelog/126930.yaml b/docs/changelog/126930.yaml new file mode 100644 index 0000000000000..1507cec38ee02 --- /dev/null +++ b/docs/changelog/126930.yaml @@ -0,0 +1,5 @@ +pr: 126930 +summary: Adding missing `onFailure` call for Inference API start model request +area: Machine Learning +type: bug +issues: [] diff --git a/docs/changelog/126990.yaml b/docs/changelog/126990.yaml new file mode 100644 index 0000000000000..a8b875cc6a221 --- /dev/null +++ b/docs/changelog/126990.yaml @@ -0,0 +1,7 @@ +pr: 126990 +summary: "Fix: consider case sensitiveness differences in Windows/Unix-like filesystems\ + \ for files entitlements" +area: Infra/Core +type: bug +issues: + - 127047 diff --git a/docs/changelog/127146.yaml b/docs/changelog/127146.yaml new file mode 100644 index 0000000000000..a36d837f0bebd --- /dev/null +++ b/docs/changelog/127146.yaml @@ -0,0 +1,5 @@ +pr: 127146 +summary: Fix sneaky bug in single value query +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/127225.yaml b/docs/changelog/127225.yaml new file mode 100644 index 0000000000000..b161a5c40bfdf --- /dev/null +++ b/docs/changelog/127225.yaml @@ -0,0 +1,6 @@ +pr: 127225 +summary: 
Fix count optimization with pushable union types +area: ES|QL +type: bug +issues: + - 127200 diff --git a/docs/changelog/127353.yaml b/docs/changelog/127353.yaml new file mode 100644 index 0000000000000..1fde8f97115fd --- /dev/null +++ b/docs/changelog/127353.yaml @@ -0,0 +1,5 @@ +pr: 127353 +summary: Updating tika to 2.9.3 +area: Ingest Node +type: upgrade +issues: [] diff --git a/docs/changelog/127414.yaml b/docs/changelog/127414.yaml new file mode 100644 index 0000000000000..37d7c11d901a3 --- /dev/null +++ b/docs/changelog/127414.yaml @@ -0,0 +1,5 @@ +pr: 127414 +summary: Fix npe when using source confirmed text query against missing field +area: Search +type: bug +issues: [] diff --git a/docs/changelog/127524.yaml b/docs/changelog/127524.yaml deleted file mode 100644 index d11599ddcde58..0000000000000 --- a/docs/changelog/127524.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 127524 -summary: Resolve groupings in aggregate before resolving references to groupings in - the aggregations -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/127527.yaml b/docs/changelog/127527.yaml new file mode 100644 index 0000000000000..6e1d3e363c3bb --- /dev/null +++ b/docs/changelog/127527.yaml @@ -0,0 +1,5 @@ +pr: 127527 +summary: "No, line noise isn't a valid ip" +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/127582.yaml b/docs/changelog/127582.yaml deleted file mode 100644 index 589c20e8f2fbc..0000000000000 --- a/docs/changelog/127582.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 127582 -summary: Specialize ags `AddInput` for each block type -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/127658.yaml b/docs/changelog/127658.yaml deleted file mode 100644 index 1a8d5ced7c8b6..0000000000000 --- a/docs/changelog/127658.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 127658 -summary: Append all data to Chat Completion buffer -area: Machine Learning -type: bug -issues: [] diff --git a/docs/changelog/127734.yaml b/docs/changelog/127734.yaml deleted file mode 100644 index d33b201744c46..0000000000000 --- a/docs/changelog/127734.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 127734 -summary: Run coordinating `can_match` in field-caps -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/127798.yaml b/docs/changelog/127798.yaml deleted file mode 100644 index f9f2ceb93b4f1..0000000000000 --- a/docs/changelog/127798.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 127798 -summary: Handle streaming request body in audit log -area: Audit -type: bug -issues: [] diff --git a/docs/reference/elasticsearch-plugins/index.md b/docs/reference/elasticsearch-plugins/index.md index a6572bb3d4d7a..d343003ffe7e1 100644 --- a/docs/reference/elasticsearch-plugins/index.md +++ b/docs/reference/elasticsearch-plugins/index.md @@ -4,26 +4,30 @@ mapped_pages: - https://www.elastic.co/guide/en/elasticsearch/plugins/current/intro.html --- -# {{es}} plugins [intro] +# Elasticsearch plugins [intro] -This section contains reference information for {{es}} plugins. +:::{note} +This section provides detailed **reference information** for Elasticsearch plugins. -Refer to [Add plugins and extensions](docs-content://deploy-manage/deploy/elastic-cloud/add-plugins-extensions.md) for an overview, setup instructions, and conceptual details. +Refer to [Add plugins and extensions](docs-content://deploy-manage/deploy/elastic-cloud/add-plugins-extensions.md) in the **Deploy and manage** section for overview, getting started and conceptual information. 
+::: -Plugins are a way to enhance the core {{es}} functionality in a custom manner. They range from adding custom mapping types, custom analyzers, native scripts, custom discovery and more. +Plugins are a way to enhance the core Elasticsearch functionality in a custom manner. They range from adding custom mapping types, custom analyzers, native scripts, custom discovery and more. Plugins contain JAR files, but may also contain scripts and config files, and must be installed on every node in the cluster. After installation, each node must be restarted before the plugin becomes visible. -There are two categories of plugins: +::::{note} +A full cluster restart is required for installing plugins that have custom cluster state metadata. It is still possible to upgrade such plugins with a rolling restart. +:::: + + +This documentation distinguishes two categories of plugins: Core Plugins -: This category identifies plugins that are part of {{es}} project. Delivered at the same time as Elasticsearch, their version number always matches the version number of Elasticsearch itself. These plugins are maintained by the Elastic team with the appreciated help of amazing community members (for open source plugins). Issues and bug reports can be reported on the [Github project page](https://github.com/elastic/elasticsearch). +: This category identifies plugins that are part of Elasticsearch project. Delivered at the same time as Elasticsearch, their version number always matches the version number of Elasticsearch itself. These plugins are maintained by the Elastic team with the appreciated help of amazing community members (for open source plugins). Issues and bug reports can be reported on the [Github project page](https://github.com/elastic/elasticsearch). Community contributed -: This category identifies plugins that are external to the {{es}} project. They are provided by individual developers or private companies and have their own licenses as well as their own versioning system. Issues and bug reports can usually be reported on the community plugin’s web site. +: This category identifies plugins that are external to the Elasticsearch project. They are provided by individual developers or private companies and have their own licenses as well as their own versioning system. Issues and bug reports can usually be reported on the community plugin’s web site. -If you want to write your own plugin, refer to [Creating an {{es}} plugin](/extend/index.md). +For advice on writing your own plugin, refer to [*Creating an {{es}} plugin*](/extend/index.md). -:::{note} -A full cluster restart is required for installing plugins that have custom cluster state metadata. It is still possible to upgrade such plugins with a rolling restart. -::: \ No newline at end of file diff --git a/docs/reference/elasticsearch/configuration-reference/security-settings.md b/docs/reference/elasticsearch/configuration-reference/security-settings.md index f198d36eee172..d83ee1ed25803 100644 --- a/docs/reference/elasticsearch/configuration-reference/security-settings.md +++ b/docs/reference/elasticsearch/configuration-reference/security-settings.md @@ -1486,15 +1486,6 @@ $$$jwt-claim-pattern-principal$$$ `client_authentication.rotation_grace_period` : ([Static](docs-content://deploy-manage/deploy/self-managed/configure-elasticsearch.md#static-cluster-setting)) Sets the grace period for how long after rotating the `client_authentication.shared_secret` is valid. 
`client_authentication.shared_secret` can be rotated by updating the keystore then calling the [reload API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-nodes-reload-secure-settings). Defaults to `1m`. -`http.proxy.host` -: ([Static](docs-content://deploy-manage/deploy/self-managed/configure-elasticsearch.md#static-cluster-setting)) Specifies the address of the proxy server for the HTTP client that is used for fetching the JSON Web Key Set from a remote URL. - -`http.proxy.scheme` -: ([Static](docs-content://deploy-manage/deploy/self-managed/configure-elasticsearch.md#static-cluster-setting)) Specifies the protocol to use to connect to the proxy server for the HTTP client that is used for fetching the JSON Web Key Set from a remote URL. Must be `http`. - -`http.proxy.port` -: ([Static](docs-content://deploy-manage/deploy/self-managed/configure-elasticsearch.md#static-cluster-setting)) Specifies the port of the proxy server for the HTTP client that is used for fetching the JSON Web Key Set from a remote URL. Defaults to `80`. - `http.connect_timeout` ![logo cloud](https://doc-icons.s3.us-east-2.amazonaws.com/logo_cloud.svg "Supported on Elastic Cloud Hosted") : ([Static](docs-content://deploy-manage/deploy/self-managed/configure-elasticsearch.md#static-cluster-setting)) Sets the timeout for the HTTP client that is used for fetching the JSON Web Key Set from a remote URL. A value of zero means the timeout is not used. Defaults to `5s`. diff --git a/docs/reference/elasticsearch/index-settings/index-modules.md b/docs/reference/elasticsearch/index-settings/index-modules.md index 4ab35b9d80a88..682a6fa2a39d4 100644 --- a/docs/reference/elasticsearch/index-settings/index-modules.md +++ b/docs/reference/elasticsearch/index-settings/index-modules.md @@ -249,12 +249,6 @@ $$$index-final-pipeline$$$ $$$index-hidden$$$ `index.hidden` : Indicates whether the index should be hidden by default. Hidden indices are not returned by default when using a wildcard expression. This behavior is controlled per request through the use of the `expand_wildcards` parameter. Possible values are `true` and `false` (default). -$$$index-dense-vector-hnsw-filter-heuristic$$$ `index.dense_vector.hnsw_filter_heuristic` -: The heuristic to utilize when executing a filtered search against vectors in an HNSW graph. This setting is in technical preview may be changed or removed in a future release. It can be set to: - -* `acorn` (default) - Only vectors that match the filter criteria are searched. This is the fastest option, and generally provides faster searches at similar recall to `fanout`, but `num_candidates` might need to be increased for exceptionally high recall requirements. -* `fanout` - All vectors are compared with the query vector, but only those passing the criteria are added to the search results. Can be slower than `acorn`, but may yield higher recall. - $$$index-esql-stored-fields-sequential-proportion$$$ `index.esql.stored_fields_sequential_proportion` diff --git a/docs/reference/elasticsearch/index.md b/docs/reference/elasticsearch/index.md index a2060ee9c384c..5e23cb0854291 100644 --- a/docs/reference/elasticsearch/index.md +++ b/docs/reference/elasticsearch/index.md @@ -1,14 +1,18 @@ # Elasticsearch and index management -This section contains reference information for {{es}} and index management features. +% TO-DO: Add links to "Elasticsearch basics"% -To learn more about {{es}} features and how to get started, refer to the [{{es}}](docs-content://solutions/search.md) documentation. 
+This section contains reference information for Elasticsearch and index management features, including: -For more details about query and scripting languages, check these sections: -* [Query languages](../query-languages/index.md) -* [Scripting languages](../scripting-languages/index.md) - -{{es}} also provides the following REST APIs: +* Settings +* Security roles and privileges +* Index lifecycle actions +* Mappings +* Command line tools +* Curator +* Clients -* [{{es}} API](https://www.elastic.co/docs/api/doc/elasticsearch) -* [{{es}} Serverless API](https://www.elastic.co/docs/api/doc/elasticsearch-serverless) \ No newline at end of file +% TO-DO: Add links to "query language and scripting language sections"% + +Elasticsearch also provides REST APIs that are used by the UI components and can be called directly to configure and access Elasticsearch features. +Refer to [Elasticsearch API](https://www.elastic.co/docs/api/doc/elasticsearch) and [Elasticsearch Serverless API](https://www.elastic.co/docs/api/doc/elasticsearch-serverless). \ No newline at end of file diff --git a/docs/reference/query-languages/esql/_snippets/commands/layout/lookup-join.md b/docs/reference/query-languages/esql/_snippets/commands/layout/lookup-join.md index 2d39b39b11e8e..e25a0c75d8b85 100644 --- a/docs/reference/query-languages/esql/_snippets/commands/layout/lookup-join.md +++ b/docs/reference/query-languages/esql/_snippets/commands/layout/lookup-join.md @@ -42,7 +42,7 @@ If multiple documents in the lookup index match a single row in your results, the output will contain one row for each matching combination. ::::{tip} -For important information about using `LOOKUP JOIN`, refer to [Usage notes](../../../../esql/esql-lookup-join.md#usage-notes). +In case of name collisions, the newly created columns will override existing columns. :::: **Examples** diff --git a/docs/reference/query-languages/esql/_snippets/commands/layout/mv_expand.md b/docs/reference/query-languages/esql/_snippets/commands/layout/mv_expand.md index 3e204a2a3d1be..9812a7d0c2335 100644 --- a/docs/reference/query-languages/esql/_snippets/commands/layout/mv_expand.md +++ b/docs/reference/query-languages/esql/_snippets/commands/layout/mv_expand.md @@ -22,12 +22,6 @@ MV_EXPAND column `column` : The multivalued column to expand. -::::{warning} -The output rows produced by `MV_EXPAND` can be in any order and may not respect -preceding `SORT`s. To guarantee a certain ordering, place a `SORT` after any -`MV_EXPAND`s. -:::: - **Example** :::{include} ../examples/mv_expand.csv-spec/simple.md diff --git a/docs/reference/query-languages/esql/esql-lookup-join.md b/docs/reference/query-languages/esql/esql-lookup-join.md index d57437833c1b2..302681a511e00 100644 --- a/docs/reference/query-languages/esql/esql-lookup-join.md +++ b/docs/reference/query-languages/esql/esql-lookup-join.md @@ -156,27 +156,9 @@ To obtain a join key with a compatible type, use a [conversion function](/refere For a complete list of supported data types and their internal representations, see the [Supported Field Types documentation](/reference/query-languages/esql/limitations.md#_supported_types). -## Usage notes - -This section covers important details about `LOOKUP JOIN` that impact query behavior and results. Review these details to ensure your queries work as expected and to troubleshoot unexpected results. - -### Handling name collisions - -When fields from the lookup index match existing column names, the new columns override the existing ones. 
-Before the `LOOKUP JOIN` command, preserve columns by either: - -* Using `RENAME` to assign non-conflicting names -* Using `EVAL` to create new columns with different names - -### Sorting behavior - -The output rows produced by `LOOKUP JOIN` can be in any order and may not -respect preceding `SORT`s. To guarantee a certain ordering, place a `SORT` after -any `LOOKUP JOIN`s. - ## Limitations -The following are the current limitations with `LOOKUP JOIN`: +The following are the current limitations with `LOOKUP JOIN` * Indices in [`lookup` mode](/reference/elasticsearch/index-settings/index-modules.md#index-mode-setting) are always single-sharded. * Cross cluster search is unsupported initially. Both source and lookup indices must be local. diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/categorize.json b/docs/reference/query-languages/esql/kibana/definition/functions/categorize.json index 4f6e1379275e3..088384a3fa1a8 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/categorize.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/categorize.json @@ -2,7 +2,6 @@ "comment" : "This is generated by ESQL’s AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", "type" : "grouping", "name" : "categorize", - "license" : "PLATINUM", "description" : "Groups text messages into categories of similarly formatted text values.", "signatures" : [ { @@ -14,7 +13,6 @@ "description" : "Expression to categorize" } ], - "license" : "PLATINUM", "variadic" : false, "returnType" : "keyword" }, @@ -27,7 +25,6 @@ "description" : "Expression to categorize" } ], - "license" : "PLATINUM", "variadic" : false, "returnType" : "keyword" } diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/st_extent_agg.json b/docs/reference/query-languages/esql/kibana/definition/functions/st_extent_agg.json index 9c05870b2cfd1..fa129eec29da2 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/st_extent_agg.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/st_extent_agg.json @@ -25,7 +25,6 @@ "description" : "" } ], - "license" : "PLATINUM", "variadic" : false, "returnType" : "cartesian_shape" }, @@ -50,7 +49,6 @@ "description" : "" } ], - "license" : "PLATINUM", "variadic" : false, "returnType" : "geo_shape" } diff --git a/docs/reference/query-languages/index.md b/docs/reference/query-languages/index.md index 2781d2afd2cb6..2a22d99e226ff 100644 --- a/docs/reference/query-languages/index.md +++ b/docs/reference/query-languages/index.md @@ -5,7 +5,13 @@ applies_to: --- # Query languages -This section contains reference information for Elastic query languages, including: +:::{note} +This section provides detailed **reference information** for query languages. + +Refer to [Query Languages](docs-content://explore-analyze/query-filter/languages.md) in the **Explore and analyze** section for overview, getting started and conceptual information. +::: + +This section contains reference information for Elastic query languages. * [Query DSL](querydsl.md) * [ES|QL](esql.md) @@ -13,4 +19,6 @@ This section contains reference information for Elastic query languages, includi * [EQL](eql.md) * [KQL](kql.md) -For more information about each language, refer to the [Explore and analyze](docs-content://explore-analyze/query-filter/languages.md) section. 
+:::{tip} +Refer to [query languages](docs-content://explore-analyze/query-filter/languages.md) in the **Explore and analyze** section for overview and conceptual information about each language. +::: \ No newline at end of file diff --git a/docs/reference/scripting-languages/index.md b/docs/reference/scripting-languages/index.md index 645d267f0b1e5..a0993f92f7a96 100644 --- a/docs/reference/scripting-languages/index.md +++ b/docs/reference/scripting-languages/index.md @@ -1,5 +1,7 @@ # Scripting languages -This section provides reference information about the Painless scripting language. +:::{note} +This section provides detailed **reference information** about the Painless scripting language. Refer to the [scripting languages overview](docs-content://explore-analyze/scripting.md) in the **Explore and analyze** section for an overview of available scripting languages in {{es}}. +::: \ No newline at end of file diff --git a/docs/reference/search-connectors/release-notes.md b/docs/reference/search-connectors/release-notes.md deleted file mode 100644 index 8bde830b2ce2a..0000000000000 --- a/docs/reference/search-connectors/release-notes.md +++ /dev/null @@ -1,36 +0,0 @@ ---- -navigation_title: "Release notes" -mapped_pages: - - https://www.elastic.co/guide/en/elasticsearch/reference/8.18/es-connectors-release-notes.html ---- - -# Connector release notes - -:::{admonition} Enterprise Search is discontinued in Elastic 9.0.0 -Please note that Enterprise Search is not available in Elastic 9.0+, including App Search, Workplace Search, the Elastic Web Crawler, and Elastic managed connectors. - -If you are an Enterprise Search user and want to upgrade to Elastic 9.0, refer to [our Enterprise Search FAQ](https://www.elastic.co/resources/search/enterprise-search-faq#what-features-are-impacted-by-this-announcement). -It includes detailed steps, tooling, and resources to help you transition to supported alternatives in 9.x, such as Elasticsearch, the Open Web Crawler, and self-managed connectors. -::: - -## 9.0.1 [connectors-9.0.1-release-notes] -No changes since 9.0.0 - -## 9.0.0 [connectors-9.0.0-release-notes] - -### Features and enhancements [connectors-9.0.0-features-enhancements] - -* Switched the default ingestion pipeline from `ent-search-generic-ingestion` to `search-default-ingestion`. The pipelines are functionally identical; only the name has changed to align with the deprecation of Enterprise Search. [#3049](https://github.com/elastic/connectors/pull/3049) -* Removed opinionated index mappings and settings from Connectors. Going forward, indices will use Elastic’s default mappings and settings, rather than legacy App Search–optimized ones. To retain the previous behavior, create the index manually before pointing a connector to it. [#3013](https://github.com/elastic/connectors/pull/3013) - -### Fixes [connectors-9.0.0-fixes] - -* Fixed an issue where full syncs could delete newly ingested documents if the document ID from the third-party source was numeric. [#3031](https://github.com/elastic/connectors/pull/3031) -* Fixed a bug where the Confluence connector failed to download some blog post documents due to unexpected response formats. [#2984](https://github.com/elastic/connectors/pull/2984) -* Fixed a bug in the Outlook connector where deactivated users could cause syncs to fail. [#2967](https://github.com/elastic/connectors/pull/2967) -* Resolved an issue where Network Drive connectors had trouble connecting to SMB 3.1.1 shares.
[#2852](https://github.com/elastic/connectors/pull/2852) - -% ## Breaking changes [connectors-9.0.0-breaking-changes] -% ## Deprications [connectorsch-9.0.0-deprecations] -% ## Known issues [connectors-9.0.0-known-issues] - diff --git a/docs/reference/search-connectors/toc.yml b/docs/reference/search-connectors/toc.yml index 3fb2578cfebfd..744e3750a5767 100644 --- a/docs/reference/search-connectors/toc.yml +++ b/docs/reference/search-connectors/toc.yml @@ -63,5 +63,4 @@ toc: - file: use-cases.md children: - file: es-connectors-overview-architecture.md -- file: es-connectors-known-issues.md -- file: release-notes.md +- file: es-connectors-known-issues.md \ No newline at end of file diff --git a/docs/release-notes/breaking-changes.md b/docs/release-notes/breaking-changes.md index 875211b52944b..cc8fac24dc8ce 100644 --- a/docs/release-notes/breaking-changes.md +++ b/docs/release-notes/breaking-changes.md @@ -12,9 +12,10 @@ If you are migrating from a version prior to version 9.0, you must first upgrade % ## Next version [elasticsearch-nextversion-breaking-changes] -## 9.0.1 [elasticsearch-9.0.1-breaking-changes] +## 9.1.0 [elasticsearch-910-breaking-changes] -No breaking changes in this version. +ES|QL +: * Allow partial results by default in ES|QL [#125060](https://github.com/elastic/elasticsearch/pull/125060) ## 9.0.0 [elasticsearch-900-breaking-changes] diff --git a/docs/release-notes/deprecations.md b/docs/release-notes/deprecations.md index a691d9e28a1e2..b835466a91452 100644 --- a/docs/release-notes/deprecations.md +++ b/docs/release-notes/deprecations.md @@ -16,10 +16,6 @@ To give you insight into what deprecated features you’re using, {{es}}: % ## Next version [elasticsearch-nextversion-deprecations] -## 9.0.1 [elasticsearch-9.0.1-deprecations] - -No deprecations in this version. - ## 9.0.0 [elasticsearch-900-deprecations] ES|QL: diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 4e841acf18ba3..c94ab727dcaea 100644 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -20,87 +20,19 @@ To check for security updates, go to [Security announcements for the Elastic sta % ### Fixes [elasticsearch-next-fixes] % * -## 9.0.1 [elasticsearch-9.0.1-release-notes] - -### Features and enhancements [elasticsearch-9.0.1-features-enhancements] - -Infra/Core: -* Validation checks on paths allowed for 'files' entitlements. Restrict the paths we allow access to, forbidding plugins to specify/request entitlements for reading or writing to specific protected directories. 
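The 9.1.0 breaking change above makes ES|QL return partial results by default instead of failing the whole query. A hedged sketch of opting back into all-or-nothing behavior follows; the `allow_partial_results` request parameter name is an assumption based on this change's description, so verify it against the 9.1 ES|QL docs:

```java
import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RestClient;

public class StrictEsqlQuery {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request request = new Request("POST", "/_query");
            // Assumed name of the per-request opt-out; treat as unverified.
            request.addParameter("allow_partial_results", "false");
            request.setJsonEntity("{\"query\": \"FROM my-index | LIMIT 10\"}");
            System.out.println(client.performRequest(request).getStatusLine());
        }
    }
}
```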
diff --git a/docs/release-notes/deprecations.md b/docs/release-notes/deprecations.md
index a691d9e28a1e2..b835466a91452 100644
--- a/docs/release-notes/deprecations.md
+++ b/docs/release-notes/deprecations.md
@@ -16,10 +16,6 @@ To give you insight into what deprecated features you’re using, {{es}}:
 
 % ## Next version [elasticsearch-nextversion-deprecations]
 
-## 9.0.1 [elasticsearch-9.0.1-deprecations]
-
-No deprecations in this version.
-
 ## 9.0.0 [elasticsearch-900-deprecations]
 
 ES|QL:
diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md
index 4e841acf18ba3..c94ab727dcaea 100644
--- a/docs/release-notes/index.md
+++ b/docs/release-notes/index.md
@@ -20,87 +20,19 @@ To check for security updates, go to [Security announcements for the Elastic sta
 % ### Fixes [elasticsearch-next-fixes]
 % *
 
-## 9.0.1 [elasticsearch-9.0.1-release-notes]
-
-### Features and enhancements [elasticsearch-9.0.1-features-enhancements]
-
-Infra/Core:
-* Validation checks on paths allowed for 'files' entitlements. Restrict the paths we allow access to, forbidding plugins to specify/request entitlements for reading or writing to specific protected directories. [#126852](https://github.com/elastic/elasticsearch/pull/126852)
-
-Ingest Node:
-* Updating tika to 2.9.3 [#127353](https://github.com/elastic/elasticsearch/pull/127353)
-
-Search:
-* Enable sort optimization on float and `half_float` [#126342](https://github.com/elastic/elasticsearch/pull/126342)
-
-Security:
-* Add Issuer to failed SAML Signature validation logs when available [#126310](https://github.com/elastic/elasticsearch/pull/126310) (issue: [#111022](https://github.com/elastic/elasticsearch/issues/111022))
-
-### Fixes [elasticsearch-9.0.1-fixes]
-
-Aggregations:
-* Rare terms aggregation false **positive** fix [#126884](https://github.com/elastic/elasticsearch/pull/126884)
-
-Allocation:
-* Fix shard size of initializing restored shard [#126783](https://github.com/elastic/elasticsearch/pull/126783) (issue: [#105331](https://github.com/elastic/elasticsearch/issues/105331))
-
-CCS:
-* Cancel expired async search task when a remote returns its results [#126583](https://github.com/elastic/elasticsearch/pull/126583)
-
-Data streams:
-* [otel-data] Bump plugin version to release _metric_names_hash changes [#126850](https://github.com/elastic/elasticsearch/pull/126850)
-
-ES|QL:
-* Fix count optimization with pushable union types [#127225](https://github.com/elastic/elasticsearch/pull/127225) (issue: [#127200](https://github.com/elastic/elasticsearch/issues/127200))
-* Fix join masking eval [#126614](https://github.com/elastic/elasticsearch/pull/126614)
-* Fix sneaky bug in single value query [#127146](https://github.com/elastic/elasticsearch/pull/127146)
-* No, line noise isn't a valid ip [#127527](https://github.com/elastic/elasticsearch/pull/127527)
-
-ILM+SLM:
-* Fix equality bug in `WaitForIndexColorStep` [#126605](https://github.com/elastic/elasticsearch/pull/126605)
-
-Infra/CLI:
-* Use terminal reader in keystore add command [#126729](https://github.com/elastic/elasticsearch/pull/126729) (issue: [#98115](https://github.com/elastic/elasticsearch/issues/98115))
-
-Infra/Core:
-* Fix: consider case sensitiveness differences in Windows/Unix-like filesystems for files entitlements [#126990](https://github.com/elastic/elasticsearch/pull/126990) (issue: [#127047](https://github.com/elastic/elasticsearch/issues/127047))
-* Rework uniquify to not use iterators [#126889](https://github.com/elastic/elasticsearch/pull/126889) (issue: [#126883](https://github.com/elastic/elasticsearch/issues/126883))
-* Workaround max name limit imposed by Jackson 2.17 [#126806](https://github.com/elastic/elasticsearch/pull/126806)
-
-Machine Learning:
-* Adding missing `onFailure` call for Inference API start model request [#126930](https://github.com/elastic/elasticsearch/pull/126930)
-* Fix text structure NPE when fields in list have null value [#125922](https://github.com/elastic/elasticsearch/pull/125922)
-* Leverage threadpool schedule for inference api to avoid long running thread [#126858](https://github.com/elastic/elasticsearch/pull/126858) (issue: [#126853](https://github.com/elastic/elasticsearch/issues/126853))
-
-Ranking:
-* Fix LTR rescorer with model alias [#126273](https://github.com/elastic/elasticsearch/pull/126273)
-* LTR score bounding [#125694](https://github.com/elastic/elasticsearch/pull/125694)
-
-Search:
-* Fix npe when using source confirmed text query against missing field [#127414](https://github.com/elastic/elasticsearch/pull/127414)
-
-TSDB:
-* Improve resiliency of `UpdateTimeSeriesRangeService` [#126637](https://github.com/elastic/elasticsearch/pull/126637)
-
-Task Management:
-* Fix race condition in `RestCancellableNodeClient` [#126686](https://github.com/elastic/elasticsearch/pull/126686) (issue: [#88201](https://github.com/elastic/elasticsearch/issues/88201))
-
-Vector Search:
-* Fix `vec_caps` to test for OS support too (on x64) [#126911](https://github.com/elastic/elasticsearch/pull/126911) (issue: [#126809](https://github.com/elastic/elasticsearch/issues/126809))
-* Fix bbq quantization algorithm but for differently distributed components [#126778](https://github.com/elastic/elasticsearch/pull/126778)
-
-
 ## 9.0.0 [elasticsearch-900-release-notes]
 
 ### Highlights [elasticsearch-900-highlights]
 
 ::::{dropdown} rank_vectors field type is now available for late-interaction ranking
+
 [`rank_vectors`](../reference/elasticsearch/mapping-reference/rank-vectors.md) is a new field type released as an experimental feature in Elasticsearch 9.0. It is designed to be used with dense vectors and allows for late-interaction second order ranking.
 
 Late-interaction models are powerful rerankers. While their size and overall cost doesn’t lend itself for HNSW indexing, utilizing them as second order reranking can provide excellent boosts in relevance. The new `rank_vectors` mapping allows for rescoring over new and novel multi-vector late-interaction models like ColBERT or ColPali.
 ::::
 
 ::::{dropdown} ES|QL LOOKUP JOIN is now available in technical preview
+
 [LOOKUP JOIN](../reference/query-languages/esql/esql-commands.md) is now available in technical preview. LOOKUP JOIN combines data from your ES|QL queries with matching records from a lookup index, enabling you to:
 
 - Enrich your search results with reference data
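The `rank_vectors` highlight above describes an experimental field type for late-interaction reranking. A hedged sketch of mapping such a field through the low-level Java REST client; the index name, field name, and `dims` value are illustrative assumptions rather than values from this changeset:

```java
import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RestClient;

public class RankVectorsMapping {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // "my-colbert-index", "token_vectors", and dims=128 are made up for illustration.
            Request request = new Request("PUT", "/my-colbert-index");
            request.setJsonEntity("""
                {
                  "mappings": {
                    "properties": {
                      "token_vectors": { "type": "rank_vectors", "dims": 128 }
                    }
                  }
                }""");
            System.out.println(client.performRequest(request).getStatusLine());
        }
    }
}
```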
diff --git a/docs/release-notes/known-issues.md b/docs/release-notes/known-issues.md
index c95f2a353bd06..56b135ea9ff69 100644
--- a/docs/release-notes/known-issues.md
+++ b/docs/release-notes/known-issues.md
@@ -27,10 +27,3 @@ This issue will be fixed in a future patch release (see [PR #126990](https://git
   ```
 
   For information about editing your JVM settings, refer to [JVM settings](https://www.elastic.co/docs/reference/elasticsearch/jvm-settings).
-
-* Users upgrading from an Elasticsearch cluster that had previously been on a version between 7.10.0 and 7.12.1 may see that Watcher will not start on 9.x. The solution is to run the following commands in Kibana Dev Tools (or the equivalent using curl):
-  ```
-  DELETE _index_template/.triggered_watches
-  DELETE _index_template/.watches
-  POST /_watcher/_start
-  ```
diff --git a/gradle/build.versions.toml b/gradle/build.versions.toml
index a5ebb38ff3a6b..9e470e210fe2b 100644
--- a/gradle/build.versions.toml
+++ b/gradle/build.versions.toml
@@ -48,4 +48,4 @@ wiremock = "com.github.tomakehurst:wiremock-jre8-standalone:2.23.2"
 xmlunit-core = "org.xmlunit:xmlunit-core:2.8.2"
 
 [plugins]
-ospackage = { id = "com.netflix.nebula.ospackage-base", version = "11.11.2" }
+ospackage = { id = "com.netflix.nebula.ospackage-base", version = "11.11.1" }
diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml
index 86a6b5cf75204..eb43190a68bac 100644
--- a/gradle/verification-metadata.xml
+++ b/gradle/verification-metadata.xml
@@ -991,9 +991,9 @@
-
-
-
+
+
+
diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar
index 1b33c55baabb5..9bbc975c742b2 100644
Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties
index f373f37ad8290..2a6e21b2ba89a 100644
--- a/gradle/wrapper/gradle-wrapper.properties
+++ b/gradle/wrapper/gradle-wrapper.properties
@@ -1,7 +1,7 @@
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionSha256Sum=efe9a3d147d948d7528a9887fa35abcf24ca1a43ad06439996490f77569b02d1
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-all.zip
+distributionSha256Sum=fba8464465835e74f7270bbf43d6d8a8d7709ab0a43ce1aa3323f73e9aa0c612
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-all.zip
 networkTimeout=10000
 validateDistributionUrl=true
 zipStoreBase=GRADLE_USER_HOME
diff --git a/gradlew b/gradlew
index 23d15a9367071..faf93008b77e7 100755
--- a/gradlew
+++ b/gradlew
@@ -114,7 +114,7 @@ case "$( uname )" in                #(
   NONSTOP* )          nonstop=true ;;
 esac
 
-CLASSPATH="\\\"\\\""
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
 
 # Determine the Java command to use to start the JVM.
@@ -213,7 +213,7 @@ DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
 set -- \
         "-Dorg.gradle.appname=$APP_BASE_NAME" \
         -classpath "$CLASSPATH" \
-        -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \
+        org.gradle.wrapper.GradleWrapperMain \
         "$@"
 
 # Stop when "xargs" is not available.
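The gradle-wrapper.properties hunk above pins `distributionUrl` and `distributionSha256Sum` together. The sketch below shows what that pairing buys: hashing the downloaded distribution and comparing it to the pinned value. It is a standalone illustration using plain JDK APIs (Java 17+ for `HexFormat`), not how the wrapper itself performs the check:

```java
import java.io.InputStream;
import java.net.URI;
import java.security.MessageDigest;
import java.util.HexFormat;

public class WrapperChecksum {
    public static void main(String[] args) throws Exception {
        // URL and expected hash copied from the gradle-wrapper.properties hunk above.
        String distributionUrl = "https://services.gradle.org/distributions/gradle-8.13-all.zip";
        String expected = "fba8464465835e74f7270bbf43d6d8a8d7709ab0a43ce1aa3323f73e9aa0c612";

        MessageDigest sha256 = MessageDigest.getInstance("SHA-256");
        try (InputStream in = URI.create(distributionUrl).toURL().openStream()) {
            byte[] buffer = new byte[8192];
            for (int n; (n = in.read(buffer)) != -1;) {
                sha256.update(buffer, 0, n);
            }
        }
        String actual = HexFormat.of().formatHex(sha256.digest());
        System.out.println(actual.equals(expected) ? "checksum OK" : "checksum MISMATCH: " + actual);
    }
}
```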
diff --git a/gradlew.bat b/gradlew.bat index 5eed7ee845284..9b42019c7915b 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -70,11 +70,11 @@ goto fail :execute @rem Setup the command line -set CLASSPATH= +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell diff --git a/libs/core/src/main/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoader.java b/libs/core/src/main/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoader.java index 5e5c82af4807e..751c5146b484a 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoader.java +++ b/libs/core/src/main/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoader.java @@ -23,8 +23,10 @@ import java.nio.file.FileSystems; import java.nio.file.Files; import java.nio.file.Path; +import java.security.AccessController; import java.security.CodeSigner; import java.security.CodeSource; +import java.security.PrivilegedAction; import java.security.SecureClassLoader; import java.util.ArrayList; import java.util.Collections; @@ -94,7 +96,8 @@ record JarMeta(String prefix, boolean isMultiRelease, Set packages, Map< private final ClassLoader parent; static EmbeddedImplClassLoader getInstance(ClassLoader parent, String providerName) { - return new EmbeddedImplClassLoader(parent, getProviderPrefixes(parent, providerName)); + PrivilegedAction pa = () -> new EmbeddedImplClassLoader(parent, getProviderPrefixes(parent, providerName)); + return AccessController.doPrivileged(pa); } private EmbeddedImplClassLoader(ClassLoader parent, Map prefixToCodeBase) { @@ -117,12 +120,14 @@ private EmbeddedImplClassLoader(ClassLoader parent, Map pre record Resource(InputStream inputStream, CodeSource codeSource) {} /** Searches for the named resource. Iterates over all prefixes. 
*/ - private Resource getResourceOrNull(JarMeta jarMeta, String pkg, String filepath) { - InputStream is = findResourceInLoaderPkgOrNull(jarMeta, pkg, filepath, parent::getResourceAsStream); - if (is != null) { - return new Resource(is, prefixToCodeBase.get(jarMeta.prefix())); - } - return null; + private Resource privilegedGetResourceOrNull(JarMeta jarMeta, String pkg, String filepath) { + return AccessController.doPrivileged((PrivilegedAction) () -> { + InputStream is = findResourceInLoaderPkgOrNull(jarMeta, pkg, filepath, parent::getResourceAsStream); + if (is != null) { + return new Resource(is, prefixToCodeBase.get(jarMeta.prefix())); + } + return null; + }); } @Override @@ -143,7 +148,7 @@ public Class findClass(String name) throws ClassNotFoundException { String pkg = toPackageName(filepath); JarMeta jarMeta = packageToJarMeta.get(pkg); if (jarMeta != null) { - Resource res = getResourceOrNull(jarMeta, pkg, filepath); + Resource res = privilegedGetResourceOrNull(jarMeta, pkg, filepath); if (res != null) { try (InputStream in = res.inputStream()) { byte[] bytes = in.readAllBytes(); diff --git a/libs/core/src/main/java/org/elasticsearch/core/internal/provider/ProviderLocator.java b/libs/core/src/main/java/org/elasticsearch/core/internal/provider/ProviderLocator.java index e3b36463d80ea..902c61402c058 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/internal/provider/ProviderLocator.java +++ b/libs/core/src/main/java/org/elasticsearch/core/internal/provider/ProviderLocator.java @@ -15,6 +15,9 @@ import java.io.UncheckedIOException; import java.lang.module.Configuration; import java.lang.module.ModuleFinder; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.Locale; import java.util.Objects; import java.util.ServiceConfigurationError; @@ -94,9 +97,10 @@ public ProviderLocator(String providerName, Class providerType, String provid @Override public T get() { try { - return load(); - } catch (IOException e) { - throw new UncheckedIOException(e); + PrivilegedExceptionAction pa = this::load; + return AccessController.doPrivileged(pa); + } catch (PrivilegedActionException e) { + throw new UncheckedIOException((IOException) e.getCause()); } } diff --git a/libs/core/src/test/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoaderTests.java b/libs/core/src/test/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoaderTests.java index d8a963b84bc6c..34c8ed1c6d851 100644 --- a/libs/core/src/test/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoaderTests.java +++ b/libs/core/src/test/java/org/elasticsearch/core/internal/provider/EmbeddedImplClassLoaderTests.java @@ -9,11 +9,11 @@ package org.elasticsearch.core.internal.provider; -import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Strings; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.internal.provider.EmbeddedImplClassLoader.CompoundEnumeration; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; @@ -195,10 +195,13 @@ private Object newFooBar(boolean enableMulti, int... 
versions) throws Exception Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntries(outerJar, jarEntries); URL[] urls = new URL[] { outerJar.toUri().toURL() }; - try (URLClassLoader parent = loader(urls)) { + URLClassLoader parent = URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); + try { EmbeddedImplClassLoader loader = EmbeddedImplClassLoader.getInstance(parent, "x-foo"); Class c = loader.loadClass("p.FooBar"); return c.getConstructor().newInstance(); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -242,7 +245,8 @@ public void testResourcesBasic() throws Exception { Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntriesUTF(outerJar, jarEntries); URL[] urls = new URL[] { outerJar.toUri().toURL() }; - try (URLClassLoader parent = loader(urls)) { + URLClassLoader parent = URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); + try { EmbeddedImplClassLoader loader = EmbeddedImplClassLoader.getInstance(parent, "res"); // resource in a valid java package dir URL url = loader.findResource("p/res.txt"); @@ -270,6 +274,8 @@ public void testResourcesBasic() throws Exception { hasToString(endsWith("impl.jar!/IMPL-JARS/res/zoo-impl.jar/A-C/res.txt")) ) ); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -320,7 +326,9 @@ private void testResourcesParent(String resourcePath) throws Exception { containsInAnyOrder("Parent Resource", "Embedded Resource") ); } finally { - IOUtils.close(closeables); + for (URLClassLoader closeable : closeables) { + PrivilegedOperations.closeURLClassLoader(closeable); + } } } @@ -455,7 +463,9 @@ private void testResourcesVersioned(String resourcePath, boolean enableMulti, in assertThat(new String(is.readAllBytes(), UTF_8), is("Hello World" + expectedVersion)); } } finally { - IOUtils.close(closeables); + for (URLClassLoader closeable : closeables) { + PrivilegedOperations.closeURLClassLoader(closeable); + } } } @@ -483,7 +493,8 @@ public void testIDontHaveIt() throws Exception { Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntriesUTF(outerJar, jarEntries); URL[] urls = new URL[] { outerJar.toUri().toURL() }; - try (URLClassLoader parent = loader(urls)) { + URLClassLoader parent = URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); + try { embedLoader = EmbeddedImplClassLoader.getInstance(parent, "res"); Class c = embedLoader.loadClass("java.lang.Object"); @@ -503,6 +514,8 @@ public void testIDontHaveIt() throws Exception { expectThrows(NPE, () -> embedLoader.getResourceAsStream(null)); expectThrows(NPE, () -> embedLoader.resources(null)); expectThrows(NPE, () -> embedLoader.loadClass(null)); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -529,7 +542,8 @@ public void testLoadWithJarDependencies() throws Exception { JarUtils.createJarWithEntries(outerJar, jarEntries); URL[] urls = new URL[] { outerJar.toUri().toURL() }; - try (URLClassLoader parent = loader(urls)) { + URLClassLoader parent = URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); + try { EmbeddedImplClassLoader loader = EmbeddedImplClassLoader.getInstance(parent, "blah"); Class c = loader.loadClass("p.Foo"); Object obj = c.getConstructor().newInstance(); @@ -541,6 +555,8 @@ public void testLoadWithJarDependencies() throws Exception { expectThrows(CNFE, () -> loader.loadClass("p.Unknown")); expectThrows(CNFE, () -> 
loader.loadClass("q.Unknown")); expectThrows(CNFE, () -> loader.loadClass("r.Unknown")); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -561,20 +577,18 @@ public void testResourcesWithMultipleJars() throws Exception { Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntriesUTF(outerJar, jarEntries); URL[] urls = new URL[] { outerJar.toUri().toURL() }; - - try (URLClassLoader parent = loader(urls)) { + URLClassLoader parent = URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); + try { EmbeddedImplClassLoader loader = EmbeddedImplClassLoader.getInstance(parent, "blah"); var res = Collections.list(loader.getResources("res.txt")); assertThat(res, hasSize(3)); List l = res.stream().map(EmbeddedImplClassLoaderTests::urlToString).toList(); assertThat(l, containsInAnyOrder("fooRes", "barRes", "bazRes")); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } - private static URLClassLoader loader(URL[] urls) { - return URLClassLoader.newInstance(urls, EmbeddedImplClassLoaderTests.class.getClassLoader()); - } - @SuppressForbidden(reason = "file urls") static String urlToString(URL url) { try { diff --git a/libs/core/src/test/java/org/elasticsearch/core/internal/provider/ProviderLocatorTests.java b/libs/core/src/test/java/org/elasticsearch/core/internal/provider/ProviderLocatorTests.java index 63553ef72d08c..61d4a7c6bd700 100644 --- a/libs/core/src/test/java/org/elasticsearch/core/internal/provider/ProviderLocatorTests.java +++ b/libs/core/src/test/java/org/elasticsearch/core/internal/provider/ProviderLocatorTests.java @@ -10,11 +10,11 @@ package org.elasticsearch.core.internal.provider; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; import java.lang.module.ModuleDescriptor; -import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.nio.file.Files; @@ -117,8 +117,12 @@ public class FooIntSupplier implements java.util.function.IntSupplier { Path topLevelDir = createTempDir(getTestName()); Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntries(outerJar, jarEntries); + URLClassLoader parent = URLClassLoader.newInstance( + new URL[] { outerJar.toUri().toURL() }, + ProviderLocatorTests.class.getClassLoader() + ); - try (URLClassLoader parent = loader(outerJar)) { + try { // test scenario ProviderLocator locator = new ProviderLocator<>("x-foo", IntSupplier.class, parent, "x.foo.impl", Set.of(), true); IntSupplier impl = locator.get(); @@ -135,6 +139,8 @@ public class FooIntSupplier implements java.util.function.IntSupplier { assertThat(md.exports(), containsInAnyOrder(exportsOf("p"))); assertThat(md.opens(), containsInAnyOrder(opensOf("q"))); assertThat(md.packages(), containsInAnyOrder(equalTo("p"), equalTo("q"), equalTo("r"))); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -166,8 +172,12 @@ public class FooLongSupplier implements java.util.function.LongSupplier { Path topLevelDir = createTempDir(getTestName()); Path outerJar = topLevelDir.resolve("impl.jar"); JarUtils.createJarWithEntries(outerJar, jarEntries); + URLClassLoader parent = URLClassLoader.newInstance( + new URL[] { outerJar.toUri().toURL() }, + ProviderLocatorTests.class.getClassLoader() + ); - try (URLClassLoader parent = loader(outerJar)) { + try { // test scenario ProviderLocator locator = new 
ProviderLocator<>("x-foo", LongSupplier.class, parent, "", Set.of(), false); LongSupplier impl = locator.get(); @@ -175,6 +185,8 @@ public class FooLongSupplier implements java.util.function.LongSupplier { assertThat(impl.toString(), equalTo("Hello from FooLongSupplier - non-modular!")); assertThat(impl.getClass().getName(), equalTo("p.FooLongSupplier")); assertThat(impl.getClass().getModule().isNamed(), is(false)); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } @@ -203,7 +215,12 @@ public class BarIntSupplier implements java.util.function.IntSupplier { Path pb = Files.createDirectories(barRoot.resolve("pb")); Files.write(pb.resolve("BarIntSupplier.class"), classToBytes.get("pb.BarIntSupplier")); - try (URLClassLoader parent = loader(topLevelDir)) { + URLClassLoader parent = URLClassLoader.newInstance( + new URL[] { topLevelDir.toUri().toURL() }, + ProviderLocatorTests.class.getClassLoader() + ); + + try { // test scenario ProviderLocator locator = new ProviderLocator<>("y-bar", IntSupplier.class, parent, "", Set.of(), false); IntSupplier impl = locator.get(); @@ -211,10 +228,8 @@ public class BarIntSupplier implements java.util.function.IntSupplier { assertThat(impl.toString(), equalTo("Hello from BarIntSupplier - exploded non-modular!")); assertThat(impl.getClass().getName(), equalTo("pb.BarIntSupplier")); assertThat(impl.getClass().getModule().isNamed(), is(false)); + } finally { + PrivilegedOperations.closeURLClassLoader(parent); } } - - private static URLClassLoader loader(Path jar) throws MalformedURLException { - return URLClassLoader.newInstance(new URL[] { jar.toUri().toURL() }, ProviderLocatorTests.class.getClassLoader()); - } } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementCheckerUtils.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementCheckerUtils.java deleted file mode 100644 index 684f20ae4b0bc..0000000000000 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementCheckerUtils.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.initialization; - -class EntitlementCheckerUtils { - - /** - * Returns the "most recent" checker class compatible with the provided runtime Java version. - * For checkers, we have (optionally) version specific classes, each with a prefix (e.g. Java23). - * The mapping cannot be automatic, as it depends on the actual presence of these classes in the final Jar (see - * the various mainXX source sets). 
- */ - static Class getVersionSpecificCheckerClass(Class baseClass, int javaVersion) { - String packageName = baseClass.getPackageName(); - String baseClassName = baseClass.getSimpleName(); - - final String classNamePrefix; - if (javaVersion >= 23) { - // All Java version from 23 onwards will be able to use che checks in the Java23EntitlementChecker interface and implementation - classNamePrefix = "Java23"; - } else { - // For any other Java version, the basic EntitlementChecker interface and implementation contains all the supported checks - classNamePrefix = ""; - } - final String className = packageName + "." + classNamePrefix + baseClassName; - Class clazz; - try { - clazz = Class.forName(className); - } catch (ClassNotFoundException e) { - throw new AssertionError("entitlement lib cannot find entitlement class " + className, e); - } - return clazz; - } -} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java index bd7c946fc1640..2e1c6c8f753bc 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java @@ -10,19 +10,51 @@ package org.elasticsearch.entitlement.initialization; import org.elasticsearch.core.Booleans; +import org.elasticsearch.core.Strings; import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap; import org.elasticsearch.entitlement.bridge.EntitlementChecker; import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker; +import org.elasticsearch.entitlement.runtime.policy.FileAccessTree; import org.elasticsearch.entitlement.runtime.policy.PathLookup; import org.elasticsearch.entitlement.runtime.policy.Policy; import org.elasticsearch.entitlement.runtime.policy.PolicyManager; +import org.elasticsearch.entitlement.runtime.policy.PolicyUtils; +import org.elasticsearch.entitlement.runtime.policy.Scope; +import org.elasticsearch.entitlement.runtime.policy.entitlements.CreateClassLoaderEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.Entitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.ExitVMEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.FileData; +import org.elasticsearch.entitlement.runtime.policy.entitlements.InboundNetworkEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.LoadNativeLibrariesEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.ManageThreadsEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.OutboundNetworkEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.ReadStoreAttributesEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.SetHttpsConnectionPropertiesEntitlement; +import org.elasticsearch.entitlement.runtime.policy.entitlements.WriteSystemPropertiesEntitlement; import java.lang.instrument.Instrumentation; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; import java.util.Map; import 
java.util.Set;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.CONFIG;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.DATA;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.LIB;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.LOGS;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.MODULES;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.PLUGINS;
+import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.SHARED_REPO;
+import static org.elasticsearch.entitlement.runtime.policy.Platform.LINUX;
+import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ;
+import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ_WRITE;
+
 /**
  * Called by the agent during {@code agentmain} to configure the entitlement system,
  * instantiate and configure an {@link EntitlementChecker},
@@ -68,11 +100,7 @@ public static void initialize(Instrumentation inst) throws Exception {
             ensureClassesSensitiveToVerificationAreInitialized();
         }
 
-        DynamicInstrumentation.initialize(
-            inst,
-            EntitlementCheckerUtils.getVersionSpecificCheckerClass(EntitlementChecker.class, Runtime.version().feature()),
-            verifyBytecode
-        );
+        DynamicInstrumentation.initialize(inst, getVersionSpecificCheckerClass(EntitlementChecker.class), verifyBytecode);
     }
 
     private static PolicyManager createPolicyManager() {
@@ -80,11 +108,151 @@ private static PolicyManager createPolicyManager() {
         Map<String, Policy> pluginPolicies = bootstrapArgs.pluginPolicies();
         PathLookup pathLookup = bootstrapArgs.pathLookup();
 
-        FilesEntitlementsValidation.validate(pluginPolicies, pathLookup);
+        List<Scope> serverScopes = new ArrayList<>();
+        List<FileData> serverModuleFileDatas = new ArrayList<>();
+        Collections.addAll(
+            serverModuleFileDatas,
+            // Base ES directories
+            FileData.ofBaseDirPath(PLUGINS, READ),
+            FileData.ofBaseDirPath(MODULES, READ),
+            FileData.ofBaseDirPath(CONFIG, READ),
+            FileData.ofBaseDirPath(LOGS, READ_WRITE),
+            FileData.ofBaseDirPath(LIB, READ),
+            FileData.ofBaseDirPath(DATA, READ_WRITE),
+            FileData.ofBaseDirPath(SHARED_REPO, READ_WRITE),
+            // exclusive settings file
+            FileData.ofRelativePath(Path.of("operator/settings.json"), CONFIG, READ_WRITE).withExclusive(true),
+            // OS release on Linux
+            FileData.ofPath(Path.of("/etc/os-release"), READ).withPlatform(LINUX),
+            FileData.ofPath(Path.of("/etc/system-release"), READ).withPlatform(LINUX),
+            FileData.ofPath(Path.of("/usr/lib/os-release"), READ).withPlatform(LINUX),
+            // read max virtual memory areas
+            FileData.ofPath(Path.of("/proc/sys/vm/max_map_count"), READ).withPlatform(LINUX),
+            FileData.ofPath(Path.of("/proc/meminfo"), READ).withPlatform(LINUX),
+            // load averages on Linux
+            FileData.ofPath(Path.of("/proc/loadavg"), READ).withPlatform(LINUX),
+            // control group stats on Linux. cgroup v2 stats are in an unpredictable
+            // location under `/sys/fs/cgroup`, so unfortunately we have to allow
+            // read access to the entire directory hierarchy.
+ FileData.ofPath(Path.of("/proc/self/cgroup"), READ).withPlatform(LINUX), + FileData.ofPath(Path.of("/sys/fs/cgroup/"), READ).withPlatform(LINUX), + // // io stats on Linux + FileData.ofPath(Path.of("/proc/self/mountinfo"), READ).withPlatform(LINUX), + FileData.ofPath(Path.of("/proc/diskstats"), READ).withPlatform(LINUX) + ); + if (pathLookup.pidFile() != null) { + serverModuleFileDatas.add(FileData.ofPath(pathLookup.pidFile(), READ_WRITE)); + } + + Collections.addAll( + serverScopes, + new Scope( + "org.elasticsearch.base", + List.of( + new CreateClassLoaderEntitlement(), + new FilesEntitlement( + List.of( + // TODO: what in es.base is accessing shared repo? + FileData.ofBaseDirPath(SHARED_REPO, READ_WRITE), + FileData.ofBaseDirPath(DATA, READ_WRITE) + ) + ) + ) + ), + new Scope("org.elasticsearch.xcontent", List.of(new CreateClassLoaderEntitlement())), + new Scope( + "org.elasticsearch.server", + List.of( + new ExitVMEntitlement(), + new ReadStoreAttributesEntitlement(), + new CreateClassLoaderEntitlement(), + new InboundNetworkEntitlement(), + new LoadNativeLibrariesEntitlement(), + new ManageThreadsEntitlement(), + new FilesEntitlement(serverModuleFileDatas) + ) + ), + new Scope("java.desktop", List.of(new LoadNativeLibrariesEntitlement())), + new Scope("org.apache.httpcomponents.httpclient", List.of(new OutboundNetworkEntitlement())), + new Scope( + "org.apache.lucene.core", + List.of( + new LoadNativeLibrariesEntitlement(), + new ManageThreadsEntitlement(), + new FilesEntitlement(List.of(FileData.ofBaseDirPath(CONFIG, READ), FileData.ofBaseDirPath(DATA, READ_WRITE))) + ) + ), + new Scope( + "org.apache.lucene.misc", + List.of(new FilesEntitlement(List.of(FileData.ofBaseDirPath(DATA, READ_WRITE))), new ReadStoreAttributesEntitlement()) + ), + new Scope( + "org.apache.logging.log4j.core", + List.of(new ManageThreadsEntitlement(), new FilesEntitlement(List.of(FileData.ofBaseDirPath(LOGS, READ_WRITE)))) + ), + new Scope( + "org.elasticsearch.nativeaccess", + List.of(new LoadNativeLibrariesEntitlement(), new FilesEntitlement(List.of(FileData.ofBaseDirPath(DATA, READ_WRITE)))) + ) + ); + + // conditionally add FIPS entitlements if FIPS only functionality is enforced + if (Booleans.parseBoolean(System.getProperty("org.bouncycastle.fips.approved_only"), false)) { + // if custom trust store is set, grant read access to its location, otherwise use the default JDK trust store + String trustStore = System.getProperty("javax.net.ssl.trustStore"); + Path trustStorePath = trustStore != null + ? Path.of(trustStore) + : Path.of(System.getProperty("java.home")).resolve("lib/security/jssecacerts"); + + Collections.addAll( + serverScopes, + new Scope( + "org.bouncycastle.fips.tls", + List.of( + new FilesEntitlement(List.of(FileData.ofPath(trustStorePath, READ))), + new ManageThreadsEntitlement(), + new OutboundNetworkEntitlement() + ) + ), + new Scope( + "org.bouncycastle.fips.core", + // read to lib dir is required for checksum validation + List.of(new FilesEntitlement(List.of(FileData.ofBaseDirPath(LIB, READ))), new ManageThreadsEntitlement()) + ) + ); + } + + var serverPolicy = new Policy( + "server", + bootstrapArgs.serverPolicyPatch() == null + ? 
serverScopes + : PolicyUtils.mergeScopes(serverScopes, bootstrapArgs.serverPolicyPatch().scopes()) + ); + + // agents run without a module, so this is a special hack for the apm agent + // this should be removed once https://github.com/elastic/elasticsearch/issues/109335 is completed + // See also modules/apm/src/main/plugin-metadata/entitlement-policy.yaml + List agentEntitlements = List.of( + new CreateClassLoaderEntitlement(), + new ManageThreadsEntitlement(), + new SetHttpsConnectionPropertiesEntitlement(), + new OutboundNetworkEntitlement(), + new WriteSystemPropertiesEntitlement(Set.of("AsyncProfiler.safemode")), + new LoadNativeLibrariesEntitlement(), + new FilesEntitlement( + List.of( + FileData.ofBaseDirPath(LOGS, READ_WRITE), + FileData.ofPath(Path.of("/proc/meminfo"), READ), + FileData.ofPath(Path.of("/sys/fs/cgroup/"), READ) + ) + ) + ); + + validateFilesEntitlements(pluginPolicies, pathLookup); return new PolicyManager( - HardcodedEntitlements.serverPolicy(pathLookup.pidFile(), bootstrapArgs.serverPolicyPatch()), - HardcodedEntitlements.agentEntitlements(), + serverPolicy, + agentEntitlements, pluginPolicies, EntitlementBootstrap.bootstrapArgs().scopeResolver(), EntitlementBootstrap.bootstrapArgs().sourcePaths(), @@ -94,6 +262,74 @@ private static PolicyManager createPolicyManager() { ); } + // package visible for tests + static void validateFilesEntitlements(Map pluginPolicies, PathLookup pathLookup) { + Set readAccessForbidden = new HashSet<>(); + pathLookup.getBaseDirPaths(PLUGINS).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); + pathLookup.getBaseDirPaths(MODULES).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); + pathLookup.getBaseDirPaths(LIB).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); + Set writeAccessForbidden = new HashSet<>(); + pathLookup.getBaseDirPaths(CONFIG).forEach(p -> writeAccessForbidden.add(p.toAbsolutePath().normalize())); + for (var pluginPolicy : pluginPolicies.entrySet()) { + for (var scope : pluginPolicy.getValue().scopes()) { + var filesEntitlement = scope.entitlements() + .stream() + .filter(x -> x instanceof FilesEntitlement) + .map(x -> ((FilesEntitlement) x)) + .findFirst(); + if (filesEntitlement.isPresent()) { + var fileAccessTree = FileAccessTree.withoutExclusivePaths(filesEntitlement.get(), pathLookup, null); + validateReadFilesEntitlements(pluginPolicy.getKey(), scope.moduleName(), fileAccessTree, readAccessForbidden); + validateWriteFilesEntitlements(pluginPolicy.getKey(), scope.moduleName(), fileAccessTree, writeAccessForbidden); + } + } + } + } + + private static IllegalArgumentException buildValidationException( + String componentName, + String moduleName, + Path forbiddenPath, + FilesEntitlement.Mode mode + ) { + return new IllegalArgumentException( + Strings.format( + "policy for module [%s] in [%s] has an invalid file entitlement. 
Any path under [%s] is forbidden for mode [%s].",
+                moduleName,
+                componentName,
+                forbiddenPath,
+                mode
+            )
+        );
+    }
+
+    private static void validateReadFilesEntitlements(
+        String componentName,
+        String moduleName,
+        FileAccessTree fileAccessTree,
+        Set<Path> readForbiddenPaths
+    ) {
+
+        for (Path forbiddenPath : readForbiddenPaths) {
+            if (fileAccessTree.canRead(forbiddenPath)) {
+                throw buildValidationException(componentName, moduleName, forbiddenPath, READ);
+            }
+        }
+    }
+
+    private static void validateWriteFilesEntitlements(
+        String componentName,
+        String moduleName,
+        FileAccessTree fileAccessTree,
+        Set<Path> writeForbiddenPaths
+    ) {
+        for (Path forbiddenPath : writeForbiddenPaths) {
+            if (fileAccessTree.canWrite(forbiddenPath)) {
+                throw buildValidationException(componentName, moduleName, forbiddenPath, READ_WRITE);
+            }
+        }
+    }
+
     /**
      * If bytecode verification is enabled, ensure these classes get loaded before transforming/retransforming them.
      * For these classes, the order in which we transform and verify them matters. Verification during class transformation is at least an
@@ -112,13 +348,39 @@ private static void ensureClassesSensitiveToVerificationAreInitialized() {
         }
     }
 
+    /**
+     * Returns the "most recent" checker class compatible with the current runtime Java version.
+     * For checkers, we have (optionally) version specific classes, each with a prefix (e.g. Java23).
+     * The mapping cannot be automatic, as it depends on the actual presence of these classes in the final Jar (see
+     * the various mainXX source sets).
+     */
+    private static Class<?> getVersionSpecificCheckerClass(Class<?> baseClass) {
+        String packageName = baseClass.getPackageName();
+        String baseClassName = baseClass.getSimpleName();
+        int javaVersion = Runtime.version().feature();
+
+        final String classNamePrefix;
+        if (javaVersion >= 23) {
+            // All Java versions from 23 onwards will be able to use the checks in the Java23EntitlementChecker interface and implementation
+            classNamePrefix = "Java23";
+        } else {
+            // For any other Java version, the basic EntitlementChecker interface and implementation contains all the supported checks
+            classNamePrefix = "";
+        }
+        final String className = packageName + "." + classNamePrefix + baseClassName;
+        Class<?> clazz;
+        try {
+            clazz = Class.forName(className);
+        } catch (ClassNotFoundException e) {
+            throw new AssertionError("entitlement lib cannot find entitlement class " + className, e);
+        }
+        return clazz;
+    }
+
     private static ElasticsearchEntitlementChecker initChecker() {
         final PolicyManager policyManager = createPolicyManager();
 
-        final Class<?> clazz = EntitlementCheckerUtils.getVersionSpecificCheckerClass(
-            ElasticsearchEntitlementChecker.class,
-            Runtime.version().feature()
-        );
+        final Class<?> clazz = getVersionSpecificCheckerClass(ElasticsearchEntitlementChecker.class);
 
         Constructor<?> constructor;
         try {
diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidation.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidation.java
deleted file mode 100644
index 4e0cc8f3a0a8a..0000000000000
--- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidation.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. 
Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.initialization; - -import org.elasticsearch.core.Strings; -import org.elasticsearch.entitlement.runtime.policy.FileAccessTree; -import org.elasticsearch.entitlement.runtime.policy.PathLookup; -import org.elasticsearch.entitlement.runtime.policy.Policy; -import org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement; - -import java.nio.file.Path; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.CONFIG; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.LIB; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.MODULES; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.PLUGINS; -import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ; -import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ_WRITE; - -class FilesEntitlementsValidation { - - static void validate(Map pluginPolicies, PathLookup pathLookup) { - Set readAccessForbidden = new HashSet<>(); - pathLookup.getBaseDirPaths(PLUGINS).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); - pathLookup.getBaseDirPaths(MODULES).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); - pathLookup.getBaseDirPaths(LIB).forEach(p -> readAccessForbidden.add(p.toAbsolutePath().normalize())); - Set writeAccessForbidden = new HashSet<>(); - pathLookup.getBaseDirPaths(CONFIG).forEach(p -> writeAccessForbidden.add(p.toAbsolutePath().normalize())); - for (var pluginPolicy : pluginPolicies.entrySet()) { - for (var scope : pluginPolicy.getValue().scopes()) { - var filesEntitlement = scope.entitlements() - .stream() - .filter(x -> x instanceof FilesEntitlement) - .map(x -> ((FilesEntitlement) x)) - .findFirst(); - if (filesEntitlement.isPresent()) { - var fileAccessTree = FileAccessTree.withoutExclusivePaths(filesEntitlement.get(), pathLookup, null); - validateReadFilesEntitlements(pluginPolicy.getKey(), scope.moduleName(), fileAccessTree, readAccessForbidden); - validateWriteFilesEntitlements(pluginPolicy.getKey(), scope.moduleName(), fileAccessTree, writeAccessForbidden); - } - } - } - } - - private static IllegalArgumentException buildValidationException( - String componentName, - String moduleName, - Path forbiddenPath, - FilesEntitlement.Mode mode - ) { - return new IllegalArgumentException( - Strings.format( - "policy for module [%s] in [%s] has an invalid file entitlement. 
Any path under [%s] is forbidden for mode [%s].", - moduleName, - componentName, - forbiddenPath, - mode - ) - ); - } - - private static void validateReadFilesEntitlements( - String componentName, - String moduleName, - FileAccessTree fileAccessTree, - Set readForbiddenPaths - ) { - - for (Path forbiddenPath : readForbiddenPaths) { - if (fileAccessTree.canRead(forbiddenPath)) { - throw buildValidationException(componentName, moduleName, forbiddenPath, READ); - } - } - } - - private static void validateWriteFilesEntitlements( - String componentName, - String moduleName, - FileAccessTree fileAccessTree, - Set writeForbiddenPaths - ) { - for (Path forbiddenPath : writeForbiddenPaths) { - if (fileAccessTree.canWrite(forbiddenPath)) { - throw buildValidationException(componentName, moduleName, forbiddenPath, READ_WRITE); - } - } - } -} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/HardcodedEntitlements.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/HardcodedEntitlements.java deleted file mode 100644 index 33f197b0a63d9..0000000000000 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/HardcodedEntitlements.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.initialization; - -import org.elasticsearch.core.Booleans; -import org.elasticsearch.entitlement.runtime.policy.Policy; -import org.elasticsearch.entitlement.runtime.policy.PolicyUtils; -import org.elasticsearch.entitlement.runtime.policy.Scope; -import org.elasticsearch.entitlement.runtime.policy.entitlements.CreateClassLoaderEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.Entitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.ExitVMEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.InboundNetworkEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.LoadNativeLibrariesEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.ManageThreadsEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.OutboundNetworkEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.ReadStoreAttributesEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.SetHttpsConnectionPropertiesEntitlement; -import org.elasticsearch.entitlement.runtime.policy.entitlements.WriteSystemPropertiesEntitlement; - -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Set; - -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.CONFIG; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.DATA; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.LIB; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.LOGS; -import static 
org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.MODULES; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.PLUGINS; -import static org.elasticsearch.entitlement.runtime.policy.PathLookup.BaseDir.SHARED_REPO; -import static org.elasticsearch.entitlement.runtime.policy.Platform.LINUX; -import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ; -import static org.elasticsearch.entitlement.runtime.policy.entitlements.FilesEntitlement.Mode.READ_WRITE; - -class HardcodedEntitlements { - - private static List createServerEntitlements(Path pidFile) { - - List serverScopes = new ArrayList<>(); - List serverModuleFileDatas = new ArrayList<>(); - Collections.addAll( - serverModuleFileDatas, - // Base ES directories - FilesEntitlement.FileData.ofBaseDirPath(PLUGINS, READ), - FilesEntitlement.FileData.ofBaseDirPath(MODULES, READ), - FilesEntitlement.FileData.ofBaseDirPath(CONFIG, READ), - FilesEntitlement.FileData.ofBaseDirPath(LOGS, READ_WRITE), - FilesEntitlement.FileData.ofBaseDirPath(LIB, READ), - FilesEntitlement.FileData.ofBaseDirPath(DATA, READ_WRITE), - FilesEntitlement.FileData.ofBaseDirPath(SHARED_REPO, READ_WRITE), - // exclusive settings file - FilesEntitlement.FileData.ofRelativePath(Path.of("operator/settings.json"), CONFIG, READ_WRITE).withExclusive(true), - // OS release on Linux - FilesEntitlement.FileData.ofPath(Path.of("/etc/os-release"), READ).withPlatform(LINUX), - FilesEntitlement.FileData.ofPath(Path.of("/etc/system-release"), READ).withPlatform(LINUX), - FilesEntitlement.FileData.ofPath(Path.of("/usr/lib/os-release"), READ).withPlatform(LINUX), - // read max virtual memory areas - FilesEntitlement.FileData.ofPath(Path.of("/proc/sys/vm/max_map_count"), READ).withPlatform(LINUX), - FilesEntitlement.FileData.ofPath(Path.of("/proc/meminfo"), READ).withPlatform(LINUX), - // load averages on Linux - FilesEntitlement.FileData.ofPath(Path.of("/proc/loadavg"), READ).withPlatform(LINUX), - // control group stats on Linux. cgroup v2 stats are in an unpredicable - // location under `/sys/fs/cgroup`, so unfortunately we have to allow - // read access to the entire directory hierarchy. - FilesEntitlement.FileData.ofPath(Path.of("/proc/self/cgroup"), READ).withPlatform(LINUX), - FilesEntitlement.FileData.ofPath(Path.of("/sys/fs/cgroup/"), READ).withPlatform(LINUX), - // // io stats on Linux - FilesEntitlement.FileData.ofPath(Path.of("/proc/self/mountinfo"), READ).withPlatform(LINUX), - FilesEntitlement.FileData.ofPath(Path.of("/proc/diskstats"), READ).withPlatform(LINUX) - ); - if (pidFile != null) { - serverModuleFileDatas.add(FilesEntitlement.FileData.ofPath(pidFile, READ_WRITE)); - } - - Collections.addAll( - serverScopes, - new Scope( - "org.elasticsearch.base", - List.of( - new CreateClassLoaderEntitlement(), - new FilesEntitlement( - List.of( - // TODO: what in es.base is accessing shared repo? 
- FilesEntitlement.FileData.ofBaseDirPath(SHARED_REPO, READ_WRITE), - FilesEntitlement.FileData.ofBaseDirPath(DATA, READ_WRITE) - ) - ) - ) - ), - new Scope("org.elasticsearch.xcontent", List.of(new CreateClassLoaderEntitlement())), - new Scope( - "org.elasticsearch.server", - List.of( - new ExitVMEntitlement(), - new ReadStoreAttributesEntitlement(), - new CreateClassLoaderEntitlement(), - new InboundNetworkEntitlement(), - new LoadNativeLibrariesEntitlement(), - new ManageThreadsEntitlement(), - new FilesEntitlement(serverModuleFileDatas) - ) - ), - new Scope("java.desktop", List.of(new LoadNativeLibrariesEntitlement())), - new Scope("org.apache.httpcomponents.httpclient", List.of(new OutboundNetworkEntitlement())), - new Scope( - "org.apache.lucene.core", - List.of( - new LoadNativeLibrariesEntitlement(), - new ManageThreadsEntitlement(), - new FilesEntitlement( - List.of( - FilesEntitlement.FileData.ofBaseDirPath(CONFIG, READ), - FilesEntitlement.FileData.ofBaseDirPath(DATA, READ_WRITE) - ) - ) - ) - ), - new Scope( - "org.apache.lucene.misc", - List.of( - new FilesEntitlement(List.of(FilesEntitlement.FileData.ofBaseDirPath(DATA, READ_WRITE))), - new ReadStoreAttributesEntitlement() - ) - ), - new Scope( - "org.apache.logging.log4j.core", - List.of( - new ManageThreadsEntitlement(), - new FilesEntitlement(List.of(FilesEntitlement.FileData.ofBaseDirPath(LOGS, READ_WRITE))) - ) - ), - new Scope( - "org.elasticsearch.nativeaccess", - List.of( - new LoadNativeLibrariesEntitlement(), - new FilesEntitlement(List.of(FilesEntitlement.FileData.ofBaseDirPath(DATA, READ_WRITE))) - ) - ) - ); - - // conditionally add FIPS entitlements if FIPS only functionality is enforced - if (Booleans.parseBoolean(System.getProperty("org.bouncycastle.fips.approved_only"), false)) { - // if custom trust store is set, grant read access to its location, otherwise use the default JDK trust store - String trustStore = System.getProperty("javax.net.ssl.trustStore"); - Path trustStorePath = trustStore != null - ? Path.of(trustStore) - : Path.of(System.getProperty("java.home")).resolve("lib/security/jssecacerts"); - - Collections.addAll( - serverScopes, - new Scope( - "org.bouncycastle.fips.tls", - List.of( - new FilesEntitlement(List.of(FilesEntitlement.FileData.ofPath(trustStorePath, READ))), - new ManageThreadsEntitlement(), - new OutboundNetworkEntitlement() - ) - ), - new Scope( - "org.bouncycastle.fips.core", - // read to lib dir is required for checksum validation - List.of( - new FilesEntitlement(List.of(FilesEntitlement.FileData.ofBaseDirPath(LIB, READ))), - new ManageThreadsEntitlement() - ) - ) - ); - } - return serverScopes; - } - - static Policy serverPolicy(Path pidFile, Policy serverPolicyPatch) { - var serverScopes = createServerEntitlements(pidFile); - return new Policy( - "server", - serverPolicyPatch == null ? 
serverScopes : PolicyUtils.mergeScopes(serverScopes, serverPolicyPatch.scopes()) - ); - } - - // agents run without a module, so this is a special hack for the apm agent - // this should be removed once https://github.com/elastic/elasticsearch/issues/109335 is completed - // See also modules/apm/src/main/plugin-metadata/entitlement-policy.yaml - static List agentEntitlements() { - return List.of( - new CreateClassLoaderEntitlement(), - new ManageThreadsEntitlement(), - new SetHttpsConnectionPropertiesEntitlement(), - new OutboundNetworkEntitlement(), - new WriteSystemPropertiesEntitlement(Set.of("AsyncProfiler.safemode")), - new LoadNativeLibrariesEntitlement(), - new FilesEntitlement( - List.of( - FilesEntitlement.FileData.ofBaseDirPath(LOGS, READ_WRITE), - FilesEntitlement.FileData.ofPath(Path.of("/proc/meminfo"), READ), - FilesEntitlement.FileData.ofPath(Path.of("/sys/fs/cgroup/"), READ) - ) - ) - ); - } -} diff --git a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidationTests.java b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/EntitlementInitializationTests.java similarity index 91% rename from libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidationTests.java rename to libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/EntitlementInitializationTests.java index 4ca57a99e0a32..6bbcec9cc400a 100644 --- a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/FilesEntitlementsValidationTests.java +++ b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/initialization/EntitlementInitializationTests.java @@ -27,7 +27,7 @@ import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.startsWith; -public class FilesEntitlementsValidationTests extends ESTestCase { +public class EntitlementInitializationTests extends ESTestCase { private static PathLookup TEST_PATH_LOOKUP; @@ -75,7 +75,7 @@ public void testValidationPass() { ) ) ); - FilesEntitlementsValidation.validate(Map.of("plugin", policy), TEST_PATH_LOOKUP); + EntitlementInitialization.validateFilesEntitlements(Map.of("plugin", policy), TEST_PATH_LOOKUP); } public void testValidationFailForRead() { @@ -94,7 +94,7 @@ public void testValidationFailForRead() { var ex = expectThrows( IllegalArgumentException.class, - () -> FilesEntitlementsValidation.validate(Map.of("plugin", policy), TEST_PATH_LOOKUP) + () -> EntitlementInitialization.validateFilesEntitlements(Map.of("plugin", policy), TEST_PATH_LOOKUP) ); assertThat( ex.getMessage(), @@ -119,7 +119,7 @@ public void testValidationFailForRead() { ex = expectThrows( IllegalArgumentException.class, - () -> FilesEntitlementsValidation.validate(Map.of("plugin2", policy2), TEST_PATH_LOOKUP) + () -> EntitlementInitialization.validateFilesEntitlements(Map.of("plugin2", policy2), TEST_PATH_LOOKUP) ); assertThat( ex.getMessage(), @@ -145,7 +145,7 @@ public void testValidationFailForWrite() { var ex = expectThrows( IllegalArgumentException.class, - () -> FilesEntitlementsValidation.validate(Map.of("plugin", policy), TEST_PATH_LOOKUP) + () -> EntitlementInitialization.validateFilesEntitlements(Map.of("plugin", policy), TEST_PATH_LOOKUP) ); assertThat( ex.getMessage(), diff --git a/libs/secure-sm/build.gradle b/libs/secure-sm/build.gradle new file mode 100644 index 0000000000000..d93afcf84afed --- /dev/null +++ b/libs/secure-sm/build.gradle @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +apply plugin: 'elasticsearch.publish' + +dependencies { + // do not add non-test compile dependencies to secure-sm without a good reason to do so + + testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}" + testImplementation "junit:junit:${versions.junit}" + testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}" + + testImplementation(project(":test:framework")) { + exclude group: 'org.elasticsearch', module: 'secure-sm' + } +} + +tasks.named('forbiddenApisMain').configure { + replaceSignatureFiles 'jdk-signatures' +} + +// JAR hell is part of core which we do not want to add as a dependency +tasks.named("jarHell").configure { enabled = false } +tasks.named("testTestingConventions").configure { + baseClass 'junit.framework.TestCase' + baseClass 'org.junit.Assert' +} diff --git a/libs/secure-sm/src/main/java/module-info.java b/libs/secure-sm/src/main/java/module-info.java new file mode 100644 index 0000000000000..fb3b6e357a1e5 --- /dev/null +++ b/libs/secure-sm/src/main/java/module-info.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +module org.elasticsearch.securesm { + exports org.elasticsearch.secure_sm; +} diff --git a/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java new file mode 100644 index 0000000000000..02d0491118dc7 --- /dev/null +++ b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java @@ -0,0 +1,275 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.secure_sm; + +import java.security.AccessController; +import java.security.Permission; +import java.security.PrivilegedAction; +import java.util.Objects; + +/** + * Extension of SecurityManager that works around a few design flaws in Java Security. + *
+ * <p>
+ * There are a few major problems that require custom {@code SecurityManager} logic to fix:
+ * <ul>
+ *   <li>{@code exitVM} permission is implicitly granted to all code by the default
+ *       Policy implementation. For a server app, this is not wanted.</li>
+ *   <li>ThreadGroups are not enforced by default, instead only system threads are
+ *       protected out of box by {@code modifyThread/modifyThreadGroup}. Applications
+ *       are encouraged to override the logic here to implement a stricter policy.</li>
+ *   <li>System threads are not even really protected, because if the system uses
+ *       ThreadPools, {@code modifyThread} is abused by its {@code shutdown} checks. This means
+ *       a thread must have {@code modifyThread} to even terminate its own pool, leaving
+ *       system threads unprotected.</li>
+ * </ul>
+ * This class throws exception on {@code exitVM} calls, and provides a whitelist where calls
+ * from exit are allowed.
+ * <p>
+ * Additionally it enforces threadgroup security with the following rules:
+ * <ul>
+ *   <li>{@code modifyThread} and {@code modifyThreadGroup} are required for any thread access
+ *       checks: with these permissions, access is granted as long as the thread group is
+ *       the same or an ancestor ({@code sourceGroup.parentOf(targetGroup) == true}).</li>
+ *   <li>code without these permissions can do very little, except to interrupt itself. It may
+ *       not even create new threads.</li>
+ *   <li>very special cases (like test runners) that have {@link ThreadPermission} can violate
+ *       threadgroup security rules.</li>
+ * </ul>
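+ * <p>
+ * A minimal installation sketch (an editor's illustration mirroring the test setup added in
+ * this change, not the production bootstrap):
+ * <pre>{@code
+ * // deny exitVM/halt to everything except whitelisted test-runner classes
+ * System.setSecurityManager(SecureSM.createTestSecureSM());
+ * }</pre>
+ * <p>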
+ * If java security debugging ({@code java.security.debug}) is enabled, and this SecurityManager
+ * is installed, it will emit additional debugging information when threadgroup access checks fail.
+ *
+ * @see SecurityManager#checkAccess(Thread)
+ * @see SecurityManager#checkAccess(ThreadGroup)
+ * @see <a href="http://cs.oswego.edu/pipermail/concurrency-interest/2009-August/006508.html">
+ *      http://cs.oswego.edu/pipermail/concurrency-interest/2009-August/006508.html</a>
+ */
+public class SecureSM extends SecurityManager {
+
+    private final String[] classesThatCanExit;
+
+    /**
+     * Creates a new security manager where no packages can exit nor halt the virtual machine.
+     */
+    public SecureSM() {
+        this(new String[0]);
+    }
+
+    /**
+     * Creates a new security manager with the specified list of regular expressions against which class names will be tested to
+     * check whether or not a class can exit or halt the virtual machine.
+     *
+     * @param classesThatCanExit the list of classes that can exit or halt the virtual machine
+     */
+    public SecureSM(final String[] classesThatCanExit) {
+        this.classesThatCanExit = classesThatCanExit;
+    }
+
+    /**
+     * Creates a new security manager with a standard set of test packages being the only packages that can exit or halt the virtual
+     * machine. The packages that can exit are:
+     * <ul>
+     *   <li>org.apache.maven.surefire.booter.</li>
+     *   <li>com.carrotsearch.ant.tasks.junit4.</li>
+     *   <li>org.eclipse.jdt.internal.junit.runner.</li>
+     *   <li>com.intellij.rt.execution.junit.</li>
+     * </ul>
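+     * <p>
+     * Callers may also supply their own exit whitelist of regular expressions; an illustrative
+     * (hypothetical) pattern, using the constructor above:
+     * <pre>{@code
+     * SecurityManager sm = new SecureSM(new String[] { "org\\.example\\.Launcher" });
+     * }</pre>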
+     *
+     * @return an instance of SecureSM where test packages can halt or exit the virtual machine
+     */
+    public static SecureSM createTestSecureSM() {
+        return new SecureSM(TEST_RUNNER_PACKAGES);
+    }
+
+    static final String[] TEST_RUNNER_PACKAGES = new String[] {
+        // surefire test runner
+        "org\\.apache\\.maven\\.surefire\\.booter\\..*",
+        // junit4 test runner
+        "com\\.carrotsearch\\.ant\\.tasks\\.junit4\\.slave\\..*",
+        // eclipse test runner
+        "org\\.eclipse.jdt\\.internal\\.junit\\.runner\\..*",
+        // intellij test runner (before IDEA version 2019.3)
+        "com\\.intellij\\.rt\\.execution\\.junit\\..*",
+        // intellij test runner (since IDEA version 2019.3)
+        "com\\.intellij\\.rt\\.junit\\..*" };
+
+    // java.security.debug support
+    private static final boolean DEBUG = AccessController.doPrivileged(new PrivilegedAction<Boolean>() {
+        @Override
+        public Boolean run() {
+            try {
+                String v = System.getProperty("java.security.debug");
+                // simple check that they are trying to debug
+                return v != null && v.length() > 0;
+            } catch (SecurityException e) {
+                return false;
+            }
+        }
+    });
+
+    @Override
+    @SuppressForbidden(reason = "java.security.debug messages go to standard error")
+    public void checkAccess(Thread t) {
+        try {
+            checkThreadAccess(t);
+        } catch (SecurityException e) {
+            if (DEBUG) {
+                System.err.println("access: caller thread=" + Thread.currentThread());
+                System.err.println("access: target thread=" + t);
+                debugThreadGroups(Thread.currentThread().getThreadGroup(), t.getThreadGroup());
+            }
+            throw e;
+        }
+    }
+
+    @Override
+    @SuppressForbidden(reason = "java.security.debug messages go to standard error")
+    public void checkAccess(ThreadGroup g) {
+        try {
+            checkThreadGroupAccess(g);
+        } catch (SecurityException e) {
+            if (DEBUG) {
+                System.err.println("access: caller thread=" + Thread.currentThread());
+                debugThreadGroups(Thread.currentThread().getThreadGroup(), g);
+            }
+            throw e;
+        }
+    }
+
+    @SuppressForbidden(reason = "java.security.debug messages go to standard error")
+    private static void debugThreadGroups(final ThreadGroup caller, final ThreadGroup target) {
+        System.err.println("access: caller group=" + caller);
+        System.err.println("access: target group=" + target);
+    }
+
+    // thread permission logic
+
+    private static final Permission MODIFY_THREAD_PERMISSION = new RuntimePermission("modifyThread");
+    private static final Permission MODIFY_ARBITRARY_THREAD_PERMISSION = new ThreadPermission("modifyArbitraryThread");
+
+    // Returns true if the given thread is an instance of the JDK's InnocuousThread.
+    private static boolean isInnocuousThread(Thread t) {
+        final Class<?> c = t.getClass();
+        return c.getModule() == Object.class.getModule()
+            && (c.getName().equals("jdk.internal.misc.InnocuousThread")
+                || c.getName().equals("java.util.concurrent.ForkJoinWorkerThread$InnocuousForkJoinWorkerThread"));
+    }
+
+    protected void checkThreadAccess(Thread t) {
+        Objects.requireNonNull(t);
+
+        boolean targetThreadIsInnocuous = isInnocuousThread(t);
+        // we don't need to check if an innocuous thread is modifying itself (e.g. changing its name)
+        if (Thread.currentThread() != t || targetThreadIsInnocuous == false) {
+            // first, check if we can modify threads at all.
+            checkPermission(MODIFY_THREAD_PERMISSION);
+        }
+
+        // check the threadgroup: if it's our thread group or an ancestor, it's fine.
+        final ThreadGroup source = Thread.currentThread().getThreadGroup();
+        final ThreadGroup target = t.getThreadGroup();
+
+        if (target == null) {
+            return; // it's a dead thread, do nothing.
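+            // illustrative note (editor's comment): with groups nested as main -> worker, a
+            // caller in "main" may modify a thread in "worker" because main.parentOf(worker)
+            // is true, while the reverse direction needs ThreadPermission("modifyArbitraryThread"),
+            // which is checked by the branch below.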
+        } else if (source.parentOf(target) == false && targetThreadIsInnocuous == false) {
+            checkPermission(MODIFY_ARBITRARY_THREAD_PERMISSION);
+        }
+    }
+
+    private static final Permission MODIFY_THREADGROUP_PERMISSION = new RuntimePermission("modifyThreadGroup");
+    private static final Permission MODIFY_ARBITRARY_THREADGROUP_PERMISSION = new ThreadPermission("modifyArbitraryThreadGroup");
+
+    // Returns true if the given thread group is the JDK's InnocuousForkJoinWorkerThreadGroup.
+    private static boolean isInnocuousThreadGroup(ThreadGroup t) {
+        final Class<?> c = t.getClass();
+        return c.getModule() == Object.class.getModule() && t.getName().equals("InnocuousForkJoinWorkerThreadGroup");
+    }
+
+    protected void checkThreadGroupAccess(ThreadGroup g) {
+        Objects.requireNonNull(g);
+
+        boolean targetThreadGroupIsInnocuous = isInnocuousThreadGroup(g);
+
+        // first, check if we can modify thread groups at all.
+        if (targetThreadGroupIsInnocuous == false) {
+            checkPermission(MODIFY_THREADGROUP_PERMISSION);
+        }
+
+        // check the threadgroup: if it's our thread group or an ancestor, it's fine.
+        final ThreadGroup source = Thread.currentThread().getThreadGroup();
+        final ThreadGroup target = g;
+
+        if (source == null) {
+            return; // we are a dead thread, do nothing
+        } else if (source.parentOf(target) == false && targetThreadGroupIsInnocuous == false) {
+            checkPermission(MODIFY_ARBITRARY_THREADGROUP_PERMISSION);
+        }
+    }
+
+    // exit permission logic
+    @Override
+    public void checkExit(int status) {
+        innerCheckExit(status);
+    }
+
+    /**
+     * The "Uwe Schindler" algorithm: walk the stack to find the exit/halt call site and only
+     * allow it if a whitelisted class appears immediately below it.
+     *
+     * @param status the exit status
+     */
+    protected void innerCheckExit(final int status) {
+        AccessController.doPrivileged(new PrivilegedAction<Void>() {
+            @Override
+            public Void run() {
+                final String systemClassName = System.class.getName(), runtimeClassName = Runtime.class.getName();
+                String exitMethodHit = null;
+                for (final StackTraceElement se : Thread.currentThread().getStackTrace()) {
+                    final String className = se.getClassName(), methodName = se.getMethodName();
+                    if (("exit".equals(methodName) || "halt".equals(methodName))
+                        && (systemClassName.equals(className) || runtimeClassName.equals(className))) {
+                        exitMethodHit = className + '#' + methodName + '(' + status + ')';
+                        continue;
+                    }
+
+                    if (exitMethodHit != null) {
+                        if (classesThatCanExit == null) {
+                            break;
+                        }
+                        if (classCanExit(className, classesThatCanExit)) {
+                            // this exit point is allowed, we return normally from closure:
+                            return null;
+                        }
+                        // anything else in stack trace is not allowed, break and throw SecurityException below:
+                        break;
+                    }
+                }
+
+                if (exitMethodHit == null) {
+                    // should never happen, only if JVM hides stack trace - replace by generic:
+                    exitMethodHit = "JVM exit method";
+                }
+                throw new SecurityException(exitMethodHit + " calls are not allowed");
+            }
+        });
+
+        // we passed the stack check, delegate to super, so default policy can still deny permission:
+        super.checkExit(status);
+    }
+
+    static boolean classCanExit(final String className, final String[] classesThatCanExit) {
+        for (final String classThatCanExit : classesThatCanExit) {
+            if (className.matches(classThatCanExit)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+}
diff --git a/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SuppressForbidden.java b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SuppressForbidden.java
new file mode 100644
index 0000000000000..6586097b5ddf4
--- /dev/null
+++ b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SuppressForbidden.java
@@ -0,0
+1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.secure_sm; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to suppress forbidden-apis errors inside a whole class, a method, or a field. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ ElementType.CONSTRUCTOR, ElementType.FIELD, ElementType.METHOD, ElementType.TYPE }) +@interface SuppressForbidden { + String reason(); +} diff --git a/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/ThreadPermission.java b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/ThreadPermission.java new file mode 100644 index 0000000000000..caae4acd888ef --- /dev/null +++ b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/ThreadPermission.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.secure_sm; + +import java.security.BasicPermission; + +/** + * Permission to modify threads or thread groups normally not accessible + * to the current thread. + *
+ * <p>
+ * {@link SecureSM} enforces ThreadGroup security: threads with
+ * {@code RuntimePermission("modifyThread")} or {@code RuntimePermission("modifyThreadGroup")}
+ * are only allowed to modify their current thread group or an ancestor of that group.
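+ * <p>
+ * For example (an illustrative policy entry, mirroring the APM agent grant added in this
+ * change), a policy may give a specific codebase blanket thread-group access:
+ * <pre>{@code
+ * grant codeBase "${codebase.elastic-apm-agent}" {
+ *   permission org.elasticsearch.secure_sm.ThreadPermission "modifyArbitraryThreadGroup";
+ * };
+ * }</pre>
+ * <p>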
+ * In some cases (e.g. test runners), code needs to manipulate arbitrary threads, + * so this Permission provides for that: the targets {@code modifyArbitraryThread} and + * {@code modifyArbitraryThreadGroup} allow a thread blanket access to any group. + * + * @see ThreadGroup + * @see SecureSM + */ +public final class ThreadPermission extends BasicPermission { + + /** + * Creates a new ThreadPermission object. + * + * @param name target name + */ + public ThreadPermission(String name) { + super(name); + } + + /** + * Creates a new ThreadPermission object. + * This constructor exists for use by the {@code Policy} object to instantiate new Permission objects. + * + * @param name target name + * @param actions ignored + */ + public ThreadPermission(String name, String actions) { + super(name, actions); + } +} diff --git a/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java new file mode 100644 index 0000000000000..965696d13613f --- /dev/null +++ b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.secure_sm; + +import com.carrotsearch.randomizedtesting.JUnit3MethodProvider; +import com.carrotsearch.randomizedtesting.RandomizedRunner; +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.TestMethodProviders; + +import org.elasticsearch.jdk.RuntimeVersionFeature; +import org.junit.BeforeClass; +import org.junit.runner.RunWith; + +import java.security.Permission; +import java.security.Policy; +import java.security.ProtectionDomain; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +/** Simple tests for SecureSM */ +@TestMethodProviders({ JUnit3MethodProvider.class }) +@RunWith(RandomizedRunner.class) +public class SecureSMTests extends org.junit.Assert { + + @BeforeClass + public static void initialize() { + RandomizedTest.assumeFalse( + "SecurityManager has been permanently removed in JDK 24", + RuntimeVersionFeature.isSecurityManagerAvailable() == false + ); + // install a mock security policy: + // AllPermission to source code + // ThreadPermission not granted anywhere else + final var sourceCode = Set.of(SecureSM.class.getProtectionDomain(), RandomizedRunner.class.getProtectionDomain()); + Policy.setPolicy(new Policy() { + @Override + public boolean implies(ProtectionDomain domain, Permission permission) { + if (sourceCode.contains(domain)) { + return true; + } else if (permission instanceof ThreadPermission) { + return false; + } + return true; + } + }); + System.setSecurityManager(SecureSM.createTestSecureSM()); + } + + @SuppressForbidden(reason = "testing that System#exit is blocked") + public void testTryToExit() { + try { + System.exit(1); + fail("did not hit expected exception"); + } catch (SecurityException expected) {} + } + + public void testClassCanExit() { + 
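// editor's note (added comment): each TEST_RUNNER_PACKAGES entry is a regular expression,
+        // so the surefire pattern "org\\.apache\\.maven\\.surefire\\.booter\\..*" matches the
+        // CommandReader class asserted below via String#matches (see SecureSM.classCanExit).
+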
assertTrue(SecureSM.classCanExit("org.apache.maven.surefire.booter.CommandReader", SecureSM.TEST_RUNNER_PACKAGES)); + assertTrue(SecureSM.classCanExit("com.carrotsearch.ant.tasks.junit4.slave.JvmExit", SecureSM.TEST_RUNNER_PACKAGES)); + assertTrue(SecureSM.classCanExit("org.eclipse.jdt.internal.junit.runner.RemoteTestRunner", SecureSM.TEST_RUNNER_PACKAGES)); + assertTrue(SecureSM.classCanExit("com.intellij.rt.execution.junit.JUnitStarter", SecureSM.TEST_RUNNER_PACKAGES)); + assertTrue(SecureSM.classCanExit("org.elasticsearch.Foo", new String[] { "org.elasticsearch.Foo" })); + assertFalse(SecureSM.classCanExit("org.elasticsearch.Foo", new String[] { "org.elasticsearch.Bar" })); + } + + public void testCreateThread() throws Exception { + Thread t = new Thread(); + t.start(); + t.join(); + // no exception + } + + public void testCreateThreadGroup() throws Exception { + Thread t = new Thread(new ThreadGroup("childgroup"), "child"); + t.start(); + t.join(); + // no exception + } + + public void testModifyChild() throws Exception { + final AtomicBoolean interrupted = new AtomicBoolean(false); + Thread t = new Thread(new ThreadGroup("childgroup"), "child") { + @Override + public void run() { + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException expected) { + interrupted.set(true); + } + } + }; + t.start(); + t.interrupt(); + t.join(); + // no exception + assertTrue(interrupted.get()); + } + + public void testNoModifySibling() throws Exception { + final AtomicBoolean interrupted1 = new AtomicBoolean(false); + final AtomicBoolean interrupted2 = new AtomicBoolean(false); + + final Thread t1 = new Thread(new ThreadGroup("childgroup"), "child") { + @Override + public void run() { + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException expected) { + interrupted1.set(true); + } + } + }; + t1.start(); + + Thread t2 = new Thread(new ThreadGroup("anothergroup"), "another child") { + @Override + public void run() { + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException expected) { + interrupted2.set(true); + try { + t1.interrupt(); // try to bogusly interrupt our sibling + fail("did not hit expected exception"); + } catch (SecurityException expected2) {} + } + } + }; + t2.start(); + t2.interrupt(); + t2.join(); + // sibling attempted to but was not able to muck with its other sibling + assertTrue(interrupted2.get()); + assertFalse(interrupted1.get()); + // but we are the parent and can terminate + t1.interrupt(); + t1.join(); + assertTrue(interrupted1.get()); + } + + public void testParallelStreamThreadGroup() throws Exception { + List list = new ArrayList<>(); + for (int i = 0; i < 100; ++i) { + list.add(i); + } + list.parallelStream().collect(Collectors.toSet()); + } +} diff --git a/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/ThreadPermissionTests.java b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/ThreadPermissionTests.java new file mode 100644 index 0000000000000..3a398e324fcc6 --- /dev/null +++ b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/ThreadPermissionTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.secure_sm; + +import junit.framework.TestCase; + +import java.security.AllPermission; + +/** + * Simple tests for ThreadPermission + */ +public class ThreadPermissionTests extends TestCase { + + public void testEquals() { + assertEquals(new ThreadPermission("modifyArbitraryThread"), new ThreadPermission("modifyArbitraryThread")); + assertFalse(new ThreadPermission("modifyArbitraryThread").equals(new AllPermission())); + assertFalse(new ThreadPermission("modifyArbitraryThread").equals(new ThreadPermission("modifyArbitraryThreadGroup"))); + } + + public void testImplies() { + assertTrue(new ThreadPermission("modifyArbitraryThread").implies(new ThreadPermission("modifyArbitraryThread"))); + assertTrue(new ThreadPermission("modifyArbitraryThreadGroup").implies(new ThreadPermission("modifyArbitraryThreadGroup"))); + assertFalse(new ThreadPermission("modifyArbitraryThread").implies(new ThreadPermission("modifyArbitraryThreadGroup"))); + assertFalse(new ThreadPermission("modifyArbitraryThreadGroup").implies(new ThreadPermission("modifyArbitraryThread"))); + assertFalse(new ThreadPermission("modifyArbitraryThread").implies(new AllPermission())); + assertFalse(new ThreadPermission("modifyArbitraryThreadGroup").implies(new AllPermission())); + assertTrue(new ThreadPermission("*").implies(new ThreadPermission("modifyArbitraryThread"))); + assertTrue(new ThreadPermission("*").implies(new ThreadPermission("modifyArbitraryThreadGroup"))); + assertFalse(new ThreadPermission("*").implies(new AllPermission())); + } +} diff --git a/modules/apm/src/main/plugin-metadata/plugin-security.policy b/modules/apm/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..763ae7f582d38 --- /dev/null +++ b/modules/apm/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+
+grant {
+  permission java.lang.RuntimePermission "accessSystemModules";
+  permission java.lang.RuntimePermission "createClassLoader";
+  permission java.lang.RuntimePermission "getClassLoader";
+  permission java.util.PropertyPermission "elastic.apm.*", "write";
+  permission java.util.PropertyPermission "*", "read,write";
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+};
+
+grant codeBase "${codebase.elastic-apm-agent}" {
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.lang.RuntimePermission "setFactory";
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+  permission java.net.SocketPermission "*", "connect,resolve";
+  // profiling function in APM agent
+  permission java.util.PropertyPermission "AsyncProfiler.safemode", "write";
+  permission java.lang.RuntimePermission "accessUserInformation";
+  permission java.lang.RuntimePermission "loadLibrary.*";
+  permission java.lang.RuntimePermission "getClassLoader";
+  permission java.io.FilePermission "<<ALL FILES>>", "read,write";
+  permission org.elasticsearch.secure_sm.ThreadPermission "modifyArbitraryThreadGroup";
+  permission java.net.NetPermission "getProxySelector";
+};
diff --git a/modules/ingest-geoip/src/main/plugin-metadata/plugin-security.policy b/modules/ingest-geoip/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..bfd77bc296124
--- /dev/null
+++ b/modules/ingest-geoip/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,12 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+grant {
+  permission java.net.SocketPermission "*", "connect";
+};
diff --git a/modules/lang-expression/src/main/plugin-metadata/plugin-security.policy b/modules/lang-expression/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..344ff666ebdb4
--- /dev/null
+++ b/modules/lang-expression/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +grant { + // needed to generate runtime classes + permission java.lang.RuntimePermission "createClassLoader"; + + // expression runtime + permission org.elasticsearch.script.ClassPermission "java.lang.String"; + permission org.elasticsearch.script.ClassPermission "org.apache.lucene.expressions.Expression"; + permission org.elasticsearch.script.ClassPermission "org.apache.lucene.search.DoubleValues"; + // available functions + permission org.elasticsearch.script.ClassPermission "java.lang.Math"; + permission org.elasticsearch.script.ClassPermission "org.apache.lucene.util.MathUtil"; + permission org.elasticsearch.script.ClassPermission "org.apache.lucene.util.SloppyMath"; + permission org.elasticsearch.script.ClassPermission "org.apache.lucene.expressions.js.ExpressionMath"; +}; diff --git a/modules/lang-painless/spi/src/test/java/org/elasticsearch/painless/WhitelistLoaderTests.java b/modules/lang-painless/spi/src/test/java/org/elasticsearch/painless/WhitelistLoaderTests.java index 7260a43e2feab..b46bc118e0913 100644 --- a/modules/lang-painless/spi/src/test/java/org/elasticsearch/painless/WhitelistLoaderTests.java +++ b/modules/lang-painless/spi/src/test/java/org/elasticsearch/painless/WhitelistLoaderTests.java @@ -132,7 +132,7 @@ public void testMissingWhitelistResourceInModule() throws Exception { JarUtils.createJarWithEntries(jar, jarEntries); try (var loader = JarUtils.loadJar(jar)) { - Controller controller = JarUtils.loadModule(jar, loader, "m"); + Controller controller = JarUtils.loadModule(jar, loader.classloader(), "m"); Module module = controller.layer().findModule("m").orElseThrow(); Class ownerClass = module.getClassLoader().loadClass("p.TestOwner"); diff --git a/modules/lang-painless/src/main/plugin-metadata/plugin-security.policy b/modules/lang-painless/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..794044a2669c4 --- /dev/null +++ b/modules/lang-painless/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant { + // needed to generate runtime classes + permission java.lang.RuntimePermission "createClassLoader"; + + // needed to find the classloader to load whitelisted classes from + permission java.lang.RuntimePermission "getClassLoader"; +}; diff --git a/modules/reindex/build.gradle b/modules/reindex/build.gradle index adfd6921445a4..93333579e6f8a 100644 --- a/modules/reindex/build.gradle +++ b/modules/reindex/build.gradle @@ -168,11 +168,11 @@ if (OS.current() == OS.WINDOWS) { tasks.named("javaRestTest").configure { dependsOn fixture - systemProperty "tests.fromOld", "true" - /* Use a closure on the string to delay evaluation until right before we - * run the integration tests so that we can be sure that the file is - * ready. */ - nonInputProperties.systemProperty "es${version}.port", fixture.map(f->f.addressAndPort) + systemProperty "tests.fromOld", "true" + /* Use a closure on the string to delay evaluation until right before we + * run the integration tests so that we can be sure that the file is + * ready. 
*/ + nonInputProperties.systemProperty "es${version}.port", fixture.map(f->f.addressAndPort) } } } diff --git a/modules/reindex/src/main/plugin-metadata/plugin-security.codebases b/modules/reindex/src/main/plugin-metadata/plugin-security.codebases new file mode 100644 index 0000000000000..0f1fbba4b76c2 --- /dev/null +++ b/modules/reindex/src/main/plugin-metadata/plugin-security.codebases @@ -0,0 +1,2 @@ +elasticsearch-rest-client: org.elasticsearch.client.RestClient +httpasyncclient: org.apache.http.nio.client.HttpAsyncClient diff --git a/modules/reindex/src/main/plugin-metadata/plugin-security.policy b/modules/reindex/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..016cc6365b6ee --- /dev/null +++ b/modules/reindex/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant { + // reindex opens socket connections using the rest client + permission java.net.SocketPermission "*", "connect"; +}; + +grant codeBase "${codebase.elasticsearch-rest-client}" { + // rest client uses system properties which gets the default proxy + permission java.net.NetPermission "getProxySelector"; +}; + +grant codeBase "${codebase.httpasyncclient}" { + // rest client uses system properties which gets the default proxy + permission java.net.NetPermission "getProxySelector"; +}; diff --git a/modules/repository-azure/src/main/plugin-metadata/plugin-security.policy b/modules/repository-azure/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..3aeeb6bde3914 --- /dev/null +++ b/modules/repository-azure/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant { + // azure client opens socket connections for to access repository + permission java.net.SocketPermission "*", "connect"; + // io.netty.util.concurrent.GlobalEventExecutor.startThread + permission java.lang.RuntimePermission "setContextClassLoader"; + // io.netty.util.concurrent.GlobalEventExecutor.startThread + permission java.lang.RuntimePermission "getClassLoader"; + // Used by jackson bean deserialization + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; +}; diff --git a/modules/repository-gcs/src/main/plugin-metadata/plugin-security.policy b/modules/repository-gcs/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..36149b5d4ecd5 --- /dev/null +++ b/modules/repository-gcs/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+grant {
+  // required by: com.google.api.client.json.JsonParser#parseValue
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+  // required by: com.google.api.client.json.GenericJson#<init>
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+  // required to add google certs to the gcs client truststore
+  permission java.lang.RuntimePermission "setFactory";
+
+  // gcs client opens socket connections to access the repository
+  permission java.net.SocketPermission "*", "connect";
+};
diff --git a/modules/repository-s3/qa/insecure-credentials/src/test/resources/plugin-security.policy b/modules/repository-s3/qa/insecure-credentials/src/test/resources/plugin-security.policy
new file mode 100644
index 0000000000000..4b3e89e3f60e3
--- /dev/null
+++ b/modules/repository-s3/qa/insecure-credentials/src/test/resources/plugin-security.policy
@@ -0,0 +1,3 @@
+grant {
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+};
diff --git a/modules/repository-s3/qa/web-identity-token/src/test/resources/plugin-security.policy b/modules/repository-s3/qa/web-identity-token/src/test/resources/plugin-security.policy
new file mode 100644
index 0000000000000..4b3e89e3f60e3
--- /dev/null
+++ b/modules/repository-s3/qa/web-identity-token/src/test/resources/plugin-security.policy
@@ -0,0 +1,3 @@
+grant {
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+};
diff --git a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/RepositoryS3ExplicitProtocolRestIT.java b/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/RepositoryS3ExplicitProtocolRestIT.java
deleted file mode 100644
index bc4bac1209cd1..0000000000000
--- a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/RepositoryS3ExplicitProtocolRestIT.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */ - -package org.elasticsearch.repositories.s3; - -import fixture.aws.DynamicRegionSupplier; -import fixture.s3.S3HttpFixture; - -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - -import org.elasticsearch.test.cluster.ElasticsearchCluster; -import org.elasticsearch.test.fixtures.testcontainers.TestContainersThreadFilter; -import org.junit.ClassRule; -import org.junit.rules.RuleChain; -import org.junit.rules.TestRule; - -import java.util.function.Supplier; - -import static fixture.aws.AwsCredentialsUtils.fixedAccessKey; -import static org.hamcrest.Matchers.startsWith; - -@ThreadLeakFilters(filters = { TestContainersThreadFilter.class }) -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) // https://github.com/elastic/elasticsearch/issues/102482 -public class RepositoryS3ExplicitProtocolRestIT extends AbstractRepositoryS3RestTestCase { - - private static final String PREFIX = getIdentifierPrefix("RepositoryS3ExplicitProtocolRestIT"); - private static final String BUCKET = PREFIX + "bucket"; - private static final String BASE_PATH = PREFIX + "base_path"; - private static final String ACCESS_KEY = PREFIX + "access-key"; - private static final String SECRET_KEY = PREFIX + "secret-key"; - private static final String CLIENT = "explicit_protocol_client"; - - private static final Supplier regionSupplier = new DynamicRegionSupplier(); - private static final S3HttpFixture s3Fixture = new S3HttpFixture( - true, - BUCKET, - BASE_PATH, - fixedAccessKey(ACCESS_KEY, regionSupplier, "s3") - ); - - private static String getEndpoint() { - final var s3FixtureAddress = s3Fixture.getAddress(); - assertThat(s3FixtureAddress, startsWith("http://")); - return s3FixtureAddress.substring("http://".length()); - } - - public static ElasticsearchCluster cluster = ElasticsearchCluster.local() - .module("repository-s3") - .systemProperty("aws.region", regionSupplier) - .keystore("s3.client." + CLIENT + ".access_key", ACCESS_KEY) - .keystore("s3.client." + CLIENT + ".secret_key", SECRET_KEY) - .setting("s3.client." + CLIENT + ".endpoint", RepositoryS3ExplicitProtocolRestIT::getEndpoint) - .setting("s3.client." + CLIENT + ".protocol", () -> "http") - .build(); - - @ClassRule - public static TestRule ruleChain = RuleChain.outerRule(s3Fixture).around(cluster); - - @Override - protected String getTestRestCluster() { - return cluster.getHttpAddresses(); - } - - @Override - protected String getBucketName() { - return BUCKET; - } - - @Override - protected String getBasePath() { - return BASE_PATH; - } - - @Override - protected String getClientName() { - return CLIENT; - } -} diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3ClientSettings.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3ClientSettings.java index 797a16240f338..48ec20992ec20 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3ClientSettings.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3ClientSettings.java @@ -77,10 +77,9 @@ final class S3ClientSettings { key -> new Setting<>(key, "", s -> s.toLowerCase(Locale.ROOT), Property.NodeScope) ); - /** The protocol to use to connect to s3, now only used if {@link #endpoint} is not a proper URI that starts with {@code http://} or - * {@code https://}. */ + /** Formerly the protocol to use to connect to s3, now unused. V2 AWS SDK can infer the protocol from {@link #endpoint}. 
*/ @UpdateForV10(owner = UpdateForV10.Owner.DISTRIBUTED_COORDINATION) // no longer used, should be removed in v10 - static final Setting.AffixSetting PROTOCOL_SETTING = Setting.affixKeySetting( + static final Setting.AffixSetting UNUSED_PROTOCOL_SETTING = Setting.affixKeySetting( PREFIX, "protocol", key -> new Setting<>(key, "https", s -> HttpScheme.valueOf(s.toUpperCase(Locale.ROOT)), Property.NodeScope, Property.Deprecated) @@ -182,9 +181,6 @@ final class S3ClientSettings { /** Credentials to authenticate with s3. */ final AwsCredentials credentials; - /** The scheme (HTTP or HTTPS) for talking to the endpoint, for use only if the endpoint doesn't contain an explicit scheme */ - final HttpScheme protocol; - /** The s3 endpoint the client should talk to, or empty string to use the default. */ final String endpoint; @@ -225,7 +221,6 @@ final class S3ClientSettings { private S3ClientSettings( AwsCredentials credentials, - HttpScheme protocol, String endpoint, String proxyHost, int proxyPort, @@ -240,7 +235,6 @@ private S3ClientSettings( String region ) { this.credentials = credentials; - this.protocol = protocol; this.endpoint = endpoint; this.proxyHost = proxyHost; this.proxyPort = proxyPort; @@ -267,7 +261,6 @@ S3ClientSettings refine(Settings repositorySettings) { .put(repositorySettings) .normalizePrefix(PREFIX + PLACEHOLDER_CLIENT + '.') .build(); - final HttpScheme newProtocol = getRepoSettingOrDefault(PROTOCOL_SETTING, normalizedSettings, protocol); final String newEndpoint = getRepoSettingOrDefault(ENDPOINT_SETTING, normalizedSettings, endpoint); final String newProxyHost = getRepoSettingOrDefault(PROXY_HOST_SETTING, normalizedSettings, proxyHost); @@ -291,8 +284,7 @@ S3ClientSettings refine(Settings repositorySettings) { newCredentials = credentials; } final String newRegion = getRepoSettingOrDefault(REGION, normalizedSettings, region); - if (Objects.equals(protocol, newProtocol) - && Objects.equals(endpoint, newEndpoint) + if (Objects.equals(endpoint, newEndpoint) && Objects.equals(proxyHost, newProxyHost) && proxyPort == newProxyPort && proxyScheme == newProxyScheme @@ -307,7 +299,6 @@ S3ClientSettings refine(Settings repositorySettings) { } return new S3ClientSettings( newCredentials, - newProtocol, newEndpoint, newProxyHost, newProxyPort, @@ -414,7 +405,6 @@ static S3ClientSettings getClientSettings(final Settings settings, final String ) { return new S3ClientSettings( S3ClientSettings.loadCredentials(settings, clientName), - getConfigValue(settings, clientName, PROTOCOL_SETTING), getConfigValue(settings, clientName, ENDPOINT_SETTING), getConfigValue(settings, clientName, PROXY_HOST_SETTING), getConfigValue(settings, clientName, PROXY_PORT_SETTING), @@ -445,7 +435,6 @@ public boolean equals(final Object o) { && maxConnections == that.maxConnections && maxRetries == that.maxRetries && Objects.equals(credentials, that.credentials) - && Objects.equals(protocol, that.protocol) && Objects.equals(endpoint, that.endpoint) && Objects.equals(proxyHost, that.proxyHost) && proxyScheme == that.proxyScheme @@ -459,7 +448,6 @@ public boolean equals(final Object o) { public int hashCode() { return Objects.hash( credentials, - protocol, endpoint, proxyHost, proxyPort, diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3RepositoryPlugin.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3RepositoryPlugin.java index 43520bb123647..64295a7249ed6 100644 --- 
a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3RepositoryPlugin.java
+++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3RepositoryPlugin.java
@@ -131,7 +131,7 @@ public List<Setting<?>> getSettings() {
             S3ClientSettings.SECRET_KEY_SETTING,
             S3ClientSettings.SESSION_TOKEN_SETTING,
             S3ClientSettings.ENDPOINT_SETTING,
-            S3ClientSettings.PROTOCOL_SETTING,
+            S3ClientSettings.UNUSED_PROTOCOL_SETTING,
             S3ClientSettings.PROXY_HOST_SETTING,
             S3ClientSettings.PROXY_PORT_SETTING,
             S3ClientSettings.PROXY_SCHEME_SETTING,
diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java
index 82f0ea5964963..07dfb332cbf8c 100644
--- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java
+++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java
@@ -257,18 +257,14 @@ protected S3ClientBuilder buildClientBuilder(S3ClientSettings clientSettings, Sd
         String endpoint = clientSettings.endpoint;
         if ((endpoint.startsWith("http://") || endpoint.startsWith("https://")) == false) {
             // The SDK does not know how to interpret endpoints without a scheme prefix and will error. Therefore, when the scheme is
-            // absent, we'll look at the deprecated .protocol setting
+            // absent, we'll supply HTTPS as a default to avoid errors.
             // See https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/client-configuration.html#client-config-other-diffs
-            endpoint = switch (clientSettings.protocol) {
-                case HTTP -> "http://" + endpoint;
-                case HTTPS -> "https://" + endpoint;
-            };
+            endpoint = "https://" + endpoint;
             LOGGER.warn(
                 """
-                    found S3 client with endpoint [{}] that is missing a scheme, guessing it should be [{}]; \
+                    found S3 client with endpoint [{}] that is missing a scheme, guessing it should use 'https://'; \
                     to suppress this warning, add a scheme prefix to the [{}] setting on this node""",
                 clientSettings.endpoint,
-                endpoint,
                 S3ClientSettings.ENDPOINT_SETTING.getConcreteSettingForNamespace("CLIENT_NAME").getKey()
             );
         }
diff --git a/modules/repository-s3/src/main/plugin-metadata/plugin-security.policy b/modules/repository-s3/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..9c8495aa9423d
--- /dev/null
+++ b/modules/repository-s3/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +grant { + + // needed because of problems in ClientConfiguration + // TODO: get these fixed in aws sdk + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "getClassLoader"; + // Needed because of problems in AmazonS3Client: + // When no region is set on a AmazonS3Client instance, the + // AWS SDK loads all known partitions from a JSON file and + // uses a Jackson's ObjectMapper for that: this one, in + // version 2.5.3 with the default binding options, tries + // to suppress access checks of ctor/field/method and thus + // requires this special permission. AWS must be fixed to + // uses Jackson correctly and have the correct modifiers + // on binded classes. + // TODO: get these fixed in aws sdk + // See https://github.com/aws/aws-sdk-java/issues/766 + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + + // s3 client opens socket connections for to access repository + permission java.net.SocketPermission "*", "connect"; + + // only for tests : org.elasticsearch.repositories.s3.S3RepositoryPlugin + permission java.util.PropertyPermission "es.allow_insecure_settings", "read,write"; +}; diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ClientSettingsTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ClientSettingsTests.java index 5e7f083de8eb2..3de070934bdbd 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ClientSettingsTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ClientSettingsTests.java @@ -34,7 +34,6 @@ public void testThereIsADefaultClientByDefault() { final S3ClientSettings defaultSettings = settings.get("default"); assertThat(defaultSettings.credentials, nullValue()); - assertThat(defaultSettings.protocol, is(HttpScheme.HTTPS)); assertThat(defaultSettings.endpoint, is(emptyString())); assertThat(defaultSettings.proxyHost, is(emptyString())); assertThat(defaultSettings.proxyPort, is(80)); diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ServiceTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ServiceTests.java index dd18932cea7d3..d4cb72dd04257 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ServiceTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3ServiceTests.java @@ -14,6 +14,7 @@ import software.amazon.awssdk.core.retry.conditions.RetryCondition; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider; import software.amazon.awssdk.services.s3.model.S3Exception; @@ -222,57 +223,20 @@ public void testGetClientRegionFallbackToUsEast1() { } public void testEndpointOverrideSchemeDefaultsToHttpsWhenNotSpecified() { - final var endpointWithoutScheme = randomIdentifier() + ".ignore"; - final var clientName = randomIdentifier(); - assertThat( - getEndpointUri(Settings.builder().put("s3.client." 
+ clientName + ".endpoint", endpointWithoutScheme), clientName), - equalTo(URI.create("https://" + endpointWithoutScheme)) - ); - } - - public void testEndpointOverrideSchemeUsesHttpsIfHttpsProtocolSpecified() { - final var endpointWithoutScheme = randomIdentifier() + ".ignore"; - final var clientName = randomIdentifier(); - assertThat( - getEndpointUri( - Settings.builder() - .put("s3.client." + clientName + ".endpoint", endpointWithoutScheme) - .put("s3.client." + clientName + ".protocol", "https"), - clientName - ), - equalTo(URI.create("https://" + endpointWithoutScheme)) - ); - assertWarnings(Strings.format(""" - [s3.client.%s.protocol] setting was deprecated in Elasticsearch and will be removed in a future release. \ - See the breaking changes documentation for the next major version.""", clientName)); - } - - public void testEndpointOverrideSchemeUsesHttpIfHttpProtocolSpecified() { - final var endpointWithoutScheme = randomIdentifier() + ".ignore"; - final var clientName = randomIdentifier(); - assertThat( - getEndpointUri( - Settings.builder() - .put("s3.client." + clientName + ".endpoint", endpointWithoutScheme) - .put("s3.client." + clientName + ".protocol", "http"), - clientName - ), - equalTo(URI.create("http://" + endpointWithoutScheme)) - ); - assertWarnings(Strings.format(""" - [s3.client.%s.protocol] setting was deprecated in Elasticsearch and will be removed in a future release. \ - See the breaking changes documentation for the next major version.""", clientName)); - } - - private static URI getEndpointUri(Settings.Builder settings, String clientName) { - return new S3Service( + final S3Service s3Service = new S3Service( mock(Environment.class), Settings.EMPTY, mock(ResourceWatcherService.class), - () -> Region.of(randomIdentifier()) - ).buildClient(S3ClientSettings.getClientSettings(settings.build(), clientName), mock(SdkHttpClient.class)) - .serviceClientConfiguration() - .endpointOverride() - .get(); + () -> Region.of("es-test-region") + ); + final String endpointWithoutScheme = randomIdentifier() + ".ignore"; + S3Client s3Client = s3Service.buildClient( + S3ClientSettings.getClientSettings( + Settings.builder().put("s3.client.test-client.endpoint", endpointWithoutScheme).build(), + "test-client" + ), + mock(SdkHttpClient.class) + ); + assertThat(s3Client.serviceClientConfiguration().endpointOverride().get(), equalTo(URI.create("https://" + endpointWithoutScheme))); } } diff --git a/modules/repository-url/src/main/plugin-metadata/plugin-security.policy b/modules/repository-url/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..bfd77bc296124 --- /dev/null +++ b/modules/repository-url/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +grant { + permission java.net.SocketPermission "*", "connect"; +}; diff --git a/modules/systemd/src/main/plugin-metadata/plugin-security.codebases b/modules/systemd/src/main/plugin-metadata/plugin-security.codebases new file mode 100644 index 0000000000000..a2ab9277ab27b --- /dev/null +++ b/modules/systemd/src/main/plugin-metadata/plugin-security.codebases @@ -0,0 +1 @@ +systemd: org.elasticsearch.systemd.SystemdPlugin diff --git a/modules/systemd/src/main/plugin-metadata/plugin-security.policy b/modules/systemd/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..c8f6a798fdc2f --- /dev/null +++ b/modules/systemd/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant codeBase "${codebase.systemd}" { + // for registering native methods + permission java.lang.RuntimePermission "accessDeclaredMembers"; +}; diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java index ef9cd8fb5ced9..3a788ba24879d 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java @@ -61,10 +61,10 @@ import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.http.HttpBodyTracer; import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpTransportSettings; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; @@ -107,8 +107,6 @@ import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; import static io.netty.handler.codec.http.HttpMethod.POST; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_CLIENT_STATS_MAX_CLOSED_CHANNEL_AGE; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.greaterThan; @@ -122,13 +120,12 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase { @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { - return Settings.builder() - .put(super.nodeSettings(nodeOrdinal, otherSettings)) - // reduce max content length just to cut down test duration - .put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), ByteSizeValue.of(MAX_CONTENT_LENGTH, ByteSizeUnit.BYTES)) - // disable time-based expiry of channel stats since we assert that the total request size 
accumulates - .put(SETTING_HTTP_CLIENT_STATS_MAX_CLOSED_CHANNEL_AGE.getKey(), TimeValue.MAX_VALUE) - .build(); + Settings.Builder builder = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); + builder.put( + HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), + ByteSizeValue.of(MAX_CONTENT_LENGTH, ByteSizeUnit.BYTES) + ); + return builder.build(); } // ensure empty http content has single 0 size chunk @@ -433,7 +430,7 @@ public void testHttpClientStats() throws Exception { clientContext.channel().writeAndFlush(httpRequest(opaqueId, contentSize)); clientContext.channel().writeAndFlush(randomContent(contentSize, true)); final var handler = clientContext.awaitRestChannelAccepted(opaqueId); - assertEquals(contentSize, handler.readAllBytes()); + handler.readAllBytes(); handler.sendResponse(new RestResponse(RestStatus.OK, "")); assertEquals(totalBytesSent, clientContext.transportStatsRequestBytesSize()); } diff --git a/modules/transport-netty4/src/main/plugin-metadata/plugin-security.codebases b/modules/transport-netty4/src/main/plugin-metadata/plugin-security.codebases new file mode 100644 index 0000000000000..8bef817663502 --- /dev/null +++ b/modules/transport-netty4/src/main/plugin-metadata/plugin-security.codebases @@ -0,0 +1 @@ +netty-transport: io.netty.channel.Channel diff --git a/modules/transport-netty4/src/main/plugin-metadata/plugin-security.policy b/modules/transport-netty4/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..dbf8e728c1606 --- /dev/null +++ b/modules/transport-netty4/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant codeBase "${codebase.netty-common}" { + // for reading the system-wide configuration for the backlog of established sockets + permission java.io.FilePermission "/proc/sys/net/core/somaxconn", "read"; + + // netty makes and accepts socket connections + permission java.net.SocketPermission "*", "accept,connect"; + + // Netty gets and sets classloaders for some of its internal threads + permission java.lang.RuntimePermission "setContextClassLoader"; + permission java.lang.RuntimePermission "getClassLoader"; +}; + +grant codeBase "${codebase.netty-transport}" { + // Netty NioEventLoop wants to change this, because of https://bugs.openjdk.java.net/browse/JDK-6427854 + // the bug says it only happened rarely, and that it's fixed, but apparently it still happens rarely! + permission java.util.PropertyPermission "sun.nio.ch.bugLevel", "write"; +};
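The ${codebase.netty-transport} placeholder in the grant above is resolved before the policy file is parsed: the plugin-security.codebases file maps the name to a jar (identified by a class that jar contains), and the resolved jar URL is published as a system property, which the stock JDK policy parser then expands. A minimal sketch of that binding using only JDK APIs (the file paths here are placeholders, and java.security.Policy is deprecated for removal in recent JDKs):

import java.net.URI;
import java.security.Policy;
import java.security.URIParameter;

public class CodebasePlaceholderSketch {
    public static void main(String[] args) throws Exception {
        // assumption: this is where the plugin's netty-transport jar really lives
        System.setProperty("codebase.netty-transport", "file:///plugins/transport-netty4/netty-transport.jar");
        // the default "JavaPolicy" provider expands ${...} references from system properties,
        // so grants under codeBase "${codebase.netty-transport}" apply to that jar only
        Policy policy = Policy.getInstance(
            "JavaPolicy",
            new URIParameter(URI.create("file:///plugins/transport-netty4/plugin-security.policy"))
        );
        System.out.println(policy.getClass());
    }
}

PolicyUtilTests further down in this patch uses the same trick: its testPolicyPermissions sets a jarUrl system property before handing the policy file to Policy.getInstance.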
diff --git a/plugins/discovery-azure-classic/src/main/plugin-metadata/plugin-security.policy b/plugins/discovery-azure-classic/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..d766397d7b1b4 --- /dev/null +++ b/plugins/discovery-azure-classic/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant { + // azure client opens socket connections for discovery + permission java.net.SocketPermission "*", "connect"; +}; diff --git a/plugins/discovery-ec2/src/main/plugin-metadata/plugin-security.policy b/plugins/discovery-ec2/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..7827c4b9bb987 --- /dev/null +++ b/plugins/discovery-ec2/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +grant { + // needed because of problems in ClientConfiguration + // TODO: get these fixed in aws sdk + permission java.lang.RuntimePermission "accessDeclaredMembers"; + // NOTE: no tests fail without this, but we know the problem + // exists in AWS sdk, and tests here are not thorough + permission java.lang.RuntimePermission "getClassLoader"; + + // ec2 client opens socket connections for discovery + permission java.net.SocketPermission "*", "connect"; + + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + permission java.util.PropertyPermission "http.proxyHost", "read"; +}; diff --git a/plugins/discovery-gce/src/main/plugin-metadata/plugin-security.policy b/plugins/discovery-gce/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..467d6d4502869 --- /dev/null +++ b/plugins/discovery-gce/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +grant { + // needed because of problems in gce + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "setFactory"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + + // gce client opens socket connections for discovery + permission java.net.SocketPermission "*", "connect"; +}; diff --git a/plugins/examples/gradle/wrapper/gradle-wrapper.properties b/plugins/examples/gradle/wrapper/gradle-wrapper.properties index f373f37ad8290..2a6e21b2ba89a 100644 --- a/plugins/examples/gradle/wrapper/gradle-wrapper.properties +++ b/plugins/examples/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=efe9a3d147d948d7528a9887fa35abcf24ca1a43ad06439996490f77569b02d1 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-all.zip +distributionSha256Sum=fba8464465835e74f7270bbf43d6d8a8d7709ab0a43ce1aa3323f73e9aa0c612 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-all.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/plugins/repository-hdfs/src/main/plugin-metadata/plugin-security.policy b/plugins/repository-hdfs/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..db5a511267626 --- /dev/null +++ b/plugins/repository-hdfs/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +grant { + // Hadoop UserGroupInformation, HdfsConstants, PipelineAck clinit + permission java.lang.RuntimePermission "getClassLoader"; + + // UserGroupInformation (UGI) Metrics clinit + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + + // Needed so that Hadoop can load the correct classes for SPI and JAAS + // org.apache.hadoop.security.SecurityUtil clinit + // org.apache.hadoop.security.UserGroupInformation.newLoginContext() + permission java.lang.RuntimePermission "setContextClassLoader"; + + // org.apache.hadoop.util.StringUtils clinit + permission java.util.PropertyPermission "*", "read,write"; + + // JAAS is used by Hadoop for authentication purposes + // The Hadoop Login JAAS module modifies a Subject's private credentials and principals + // The Hadoop RPC Layer must be able to read these credentials, and initiate Kerberos connections + + // org.apache.hadoop.security.UserGroupInformation.getCurrentUser() + permission javax.security.auth.AuthPermission "getSubject"; + + // org.apache.hadoop.security.UserGroupInformation.doAs() + permission javax.security.auth.AuthPermission "doAs"; + + // org.apache.hadoop.security.UserGroupInformation.getCredentialsInternal() + permission javax.security.auth.PrivateCredentialPermission "org.apache.hadoop.security.Credentials * \"*\"", "read"; + + // Hadoop depends on the Kerberos login module for kerberos authentication + // com.sun.security.auth.module.Krb5LoginModule.login() + permission java.lang.RuntimePermission "accessClassInPackage.sun.security.krb5"; + + // com.sun.security.auth.module.Krb5LoginModule.commit() + permission javax.security.auth.AuthPermission "modifyPrivateCredentials"; + permission javax.security.auth.AuthPermission "modifyPrincipals"; + permission javax.security.auth.PrivateCredentialPermission "javax.security.auth.kerberos.KeyTab * \"*\"", "read"; + permission javax.security.auth.PrivateCredentialPermission "javax.security.auth.kerberos.KerberosTicket * \"*\"", "read"; + + // Hadoop depends on OS level user information for simple authentication + // Unix: UnixLoginModule: com.sun.security.auth.module.UnixSystem.UnixSystem init + permission java.lang.RuntimePermission "loadLibrary.jaas"; + permission java.lang.RuntimePermission "loadLibrary.jaas_unix"; + // Windows: NTLoginModule: com.sun.security.auth.module.NTSystem.loadNative + permission java.lang.RuntimePermission "loadLibrary.jaas_nt"; + permission javax.security.auth.AuthPermission "modifyPublicCredentials"; + + // org.apache.hadoop.security.SaslRpcServer.init() + permission java.security.SecurityPermission "putProviderProperty.SaslPlainServer"; + + // org.apache.hadoop.security.SaslPlainServer.SecurityProvider.SecurityProvider init + permission java.security.SecurityPermission "insertProvider"; + + // org.apache.hadoop.security.SaslRpcClient.getServerPrincipal -> KerberosPrincipal init + permission javax.security.auth.kerberos.ServicePermission "*", "initiate"; + + // hdfs client opens socket connections to access the repository + permission java.net.SocketPermission "*", "connect"; + + // client binds to the address returned from the host name of any principal set up as a service principal + // org.apache.hadoop.ipc.Client.Connection.setupConnection + permission java.net.SocketPermission "localhost:0", "listen,resolve"; +};
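All of the policy files above lean on Permission implication rather than exhaustive enumeration: one broad grant covers every concrete check made at runtime. A small self-contained illustration of how SocketPermission and FilePermission wildcards match (plain JDK, nothing Elasticsearch-specific; the host name is made up):

import java.io.FilePermission;
import java.net.SocketPermission;

public class ImpliesSketch {
    public static void main(String[] args) {
        // the "*" host with no port spec covers any outbound connection, as in the grants above
        SocketPermission granted = new SocketPermission("*", "connect");
        System.out.println(granted.implies(new SocketPermission("namenode.example.com:8020", "connect"))); // true
        System.out.println(granted.implies(new SocketPermission("localhost:0", "listen"))); // false: different action

        // FilePermission "/-" matches recursively, which is what data-path grants rely on
        FilePermission data = new FilePermission("/home/elasticsearch/data/-", "read");
        System.out.println(data.implies(new FilePermission("/home/elasticsearch/data/index/file.si", "read"))); // true
    }
}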
diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/ESPolicyUnitTests.java b/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/ESPolicyUnitTests.java new file mode 100644 index 0000000000000..34d15fa4aebba --- /dev/null +++ b/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/ESPolicyUnitTests.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.jdk.RuntimeVersionFeature; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; + +import java.io.FilePermission; +import java.net.SocketPermission; +import java.net.URL; +import java.security.AllPermission; +import java.security.CodeSource; +import java.security.Permission; +import java.security.PermissionCollection; +import java.security.Permissions; +import java.security.Policy; +import java.security.ProtectionDomain; +import java.security.cert.Certificate; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Map.entry; +import static org.elasticsearch.bootstrap.ESPolicy.POLICY_RESOURCE; + +/** + * Unit tests for ESPolicy: these cannot run with security manager, + * we don't allow messing with the policy + */ +public class ESPolicyUnitTests extends ESTestCase { + + static final Map<String, URL> TEST_CODEBASES = BootstrapForTesting.getCodebases(); + static Policy DEFAULT_POLICY; + + @BeforeClass + public static void setupPolicy() { + assumeTrue("test requires security manager to be supported", RuntimeVersionFeature.isSecurityManagerAvailable()); + assumeTrue("test cannot run with security manager", System.getSecurityManager() == null); + DEFAULT_POLICY = PolicyUtil.readPolicy(ESPolicy.class.getResource(POLICY_RESOURCE), TEST_CODEBASES); + } + + /** + * Test policy with null codesource. + *
<p>
+ * This can happen when restricting privileges with doPrivileged, + * even though ProtectionDomain's ctor javadocs might make you think + * that the policy won't be consulted. + */ + @SuppressForbidden(reason = "to create FilePermission object") + public void testNullCodeSource() throws Exception { + // create a policy with AllPermission + Permission all = new AllPermission(); + PermissionCollection allCollection = all.newPermissionCollection(); + allCollection.add(all); + ESPolicy policy = new ESPolicy(DEFAULT_POLICY, allCollection, Map.of(), true, List.of(), Map.of()); + // restrict ourselves to NoPermission + PermissionCollection noPermissions = new Permissions(); + assertFalse(policy.implies(new ProtectionDomain(null, noPermissions), new FilePermission("foo", "read"))); + } + + /** + * As of JDK 9, {@link CodeSource#getLocation} is documented to potentially return {@code null} + */ + @SuppressForbidden(reason = "to create FilePermission object") + public void testNullLocation() throws Exception { + PermissionCollection noPermissions = new Permissions(); + ESPolicy policy = new ESPolicy(DEFAULT_POLICY, noPermissions, Map.of(), true, List.of(), Map.of()); + assertFalse( + policy.implies( + new ProtectionDomain(new CodeSource(null, (Certificate[]) null), noPermissions), + new FilePermission("foo", "read") + ) + ); + } + + public void testListen() { + final PermissionCollection noPermissions = new Permissions(); + final ESPolicy policy = new ESPolicy(DEFAULT_POLICY, noPermissions, Map.of(), true, List.of(), Map.of()); + assertFalse( + policy.implies( + new ProtectionDomain(ESPolicyUnitTests.class.getProtectionDomain().getCodeSource(), noPermissions), + new SocketPermission("localhost:" + randomFrom(0, randomIntBetween(49152, 65535)), "listen") + ) + ); + } + + @SuppressForbidden(reason = "to create FilePermission object") + public void testDataPathPermissionIsChecked() { + final ESPolicy policy = new ESPolicy( + DEFAULT_POLICY, + new Permissions(), + Map.of(), + true, + List.of(new FilePermission("/home/elasticsearch/data/-", "read")), + Map.of() + ); + assertTrue( + policy.implies( + new ProtectionDomain(new CodeSource(null, (Certificate[]) null), new Permissions()), + new FilePermission("/home/elasticsearch/data/index/file.si", "read") + ) + ); + } + + @SuppressForbidden(reason = "to create FilePermission object") + public void testSecuredAccess() { + String file1 = "/home/elasticsearch/config/pluginFile1.yml"; + URL codebase1 = randomFrom(TEST_CODEBASES.values()); + String file2 = "/home/elasticsearch/config/pluginFile2.yml"; + URL codebase2 = randomValueOtherThan(codebase1, () -> randomFrom(TEST_CODEBASES.values())); + String dir1 = "/home/elasticsearch/config/pluginDir/"; + URL codebase3 = randomValueOtherThanMany(Set.of(codebase1, codebase2)::contains, () -> randomFrom(TEST_CODEBASES.values())); + URL otherCodebase = randomValueOtherThanMany( + Set.of(codebase1, codebase2, codebase3)::contains, + () -> randomFrom(TEST_CODEBASES.values()) + ); + + ESPolicy policy = new ESPolicy( + DEFAULT_POLICY, + new Permissions(), + Map.of(), + true, + List.of(), + Map.ofEntries(entry(file1, Set.of(codebase1)), entry(file2, Set.of(codebase1, codebase2)), entry(dir1 + "*", Set.of(codebase3))) + ); + + ProtectionDomain nullDomain = new ProtectionDomain(new CodeSource(null, (Certificate[]) null), new Permissions()); + ProtectionDomain codebase1Domain = new ProtectionDomain(new CodeSource(codebase1, (Certificate[]) null), new Permissions()); + ProtectionDomain codebase2Domain = new 
ProtectionDomain(new CodeSource(codebase2, (Certificate[]) null), new Permissions()); + ProtectionDomain codebase3Domain = new ProtectionDomain(new CodeSource(codebase3, (Certificate[]) null), new Permissions()); + ProtectionDomain otherCodebaseDomain = new ProtectionDomain(new CodeSource(otherCodebase, (Certificate[]) null), new Permissions()); + + Set<String> actions = Set.of("read", "write", "read,write", "delete", "read,write,execute,readlink,delete"); + + assertFalse(policy.implies(nullDomain, new FilePermission(file1, randomFrom(actions)))); + assertFalse(policy.implies(otherCodebaseDomain, new FilePermission(file1, randomFrom(actions)))); + assertTrue(policy.implies(codebase1Domain, new FilePermission(file1, randomFrom(actions)))); + assertFalse(policy.implies(codebase2Domain, new FilePermission(file1, randomFrom(actions)))); + assertFalse(policy.implies(codebase3Domain, new FilePermission(file1, randomFrom(actions)))); + + assertFalse(policy.implies(nullDomain, new FilePermission(file2, randomFrom(actions)))); + assertFalse(policy.implies(otherCodebaseDomain, new FilePermission(file2, randomFrom(actions)))); + assertTrue(policy.implies(codebase1Domain, new FilePermission(file2, randomFrom(actions)))); + assertTrue(policy.implies(codebase2Domain, new FilePermission(file2, randomFrom(actions)))); + assertFalse(policy.implies(codebase3Domain, new FilePermission(file2, randomFrom(actions)))); + + String dirFile = dir1 + "file.yml"; + assertFalse(policy.implies(nullDomain, new FilePermission(dirFile, randomFrom(actions)))); + assertFalse(policy.implies(otherCodebaseDomain, new FilePermission(dirFile, randomFrom(actions)))); + assertFalse(policy.implies(codebase1Domain, new FilePermission(dirFile, randomFrom(actions)))); + assertFalse(policy.implies(codebase2Domain, new FilePermission(dirFile, randomFrom(actions)))); + assertTrue(policy.implies(codebase3Domain, new FilePermission(dirFile, randomFrom(actions)))); + } +}
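The testSecuredAccess case above pins individual config files to the specific codebases that may touch them. A stripped-down sketch of that idea (ToyPolicy is hypothetical, not the real ESPolicy, and java.security.Policy is deprecated in recent JDKs; a real policy would fall through to its ordinary grants instead of denying outright):

import java.io.FilePermission;
import java.net.URL;
import java.security.CodeSource;
import java.security.Permission;
import java.security.Policy;
import java.security.ProtectionDomain;
import java.util.Map;
import java.util.Set;

final class ToyPolicy extends Policy {
    // file path -> codebases that are allowed to access it
    private final Map<String, Set<URL>> securedFiles;

    ToyPolicy(Map<String, Set<URL>> securedFiles) {
        this.securedFiles = securedFiles;
    }

    @Override
    public boolean implies(ProtectionDomain domain, Permission permission) {
        if (permission instanceof FilePermission == false) {
            return false; // the toy models file checks only
        }
        Set<URL> owners = securedFiles.get(permission.getName());
        if (owners == null) {
            return false; // not a secured file; the real ESPolicy consults its regular grants here
        }
        CodeSource cs = domain.getCodeSource();
        return cs != null && owners.contains(cs.getLocation());
    }
}

This mirrors the shape of the assertions above: a domain with a null CodeSource or a foreign codebase never matches, and a directory entry such as dir1 + "*" would additionally need the wildcard matching that FilePermission itself provides.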
diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/PolicyUtilTests.java b/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/PolicyUtilTests.java new file mode 100644 index 0000000000000..5a21cda02e7f3 --- /dev/null +++ b/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/PolicyUtilTests.java @@ -0,0 +1,391 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.jdk.RuntimeVersionFeature; +import org.elasticsearch.plugins.PluginDescriptor; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.Permission; +import java.security.Policy; +import java.security.URIParameter; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.hamcrest.collection.IsMapContaining.hasKey; + +public class PolicyUtilTests extends ESTestCase { + + @Before + public void assumeSecurityManagerDisabled() { + assumeTrue("test requires security manager to be supported", RuntimeVersionFeature.isSecurityManagerAvailable()); + assumeTrue("test cannot run with security manager enabled", System.getSecurityManager() == null); + } + + URL makeUrl(String s) { + try { + return new URL(s); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + } + + Path makeDummyPlugin(String policy, String... files) throws IOException { + Path plugin = createTempDir(); + Files.copy(this.getDataPath(policy), plugin.resolve(PluginDescriptor.ES_PLUGIN_POLICY)); + for (String file : files) { + Files.createFile(plugin.resolve(file)); + } + return plugin; + } + + @SuppressForbidden(reason = "set for test") + void setProperty(String key, String value) { + System.setProperty(key, value); + } + + @SuppressForbidden(reason = "cleanup test") + void clearProperty(String key) { + System.clearProperty(key); + } + + public void testCodebaseJarMap() throws Exception { + Set<URL> urls = new LinkedHashSet<>(List.of(makeUrl("file:///foo.jar"), makeUrl("file:///bar.txt"), makeUrl("file:///a/bar.jar"))); + + Map<String, URL> jarMap = PolicyUtil.getCodebaseJarMap(urls); + assertThat(jarMap, hasKey("foo.jar")); + assertThat(jarMap, hasKey("bar.jar")); + // only jars are grabbed + assertThat(jarMap, not(hasKey("bar.txt"))); + + // order matters + assertThat(jarMap.keySet(), contains("foo.jar", "bar.jar")); + } + + public void testPluginPolicyInfoEmpty() throws Exception { + assertThat(PolicyUtil.readPolicyInfo(createTempDir()), is(nullValue())); + } + + public void testPluginPolicyInfoNoJars() throws Exception { + Path noJarsPlugin = makeDummyPlugin("dummy.policy"); + PluginPolicyInfo info = PolicyUtil.readPolicyInfo(noJarsPlugin); + assertThat(info.policy(), is(not(nullValue()))); + assertThat(info.jars(), emptyIterable()); + } + + public void testPluginPolicyInfo() throws Exception { + Path plugin = makeDummyPlugin("dummy.policy", "foo.jar", "foo.txt", "bar.jar"); + PluginPolicyInfo info = PolicyUtil.readPolicyInfo(plugin); + assertThat(info.policy(), is(not(nullValue()))); + assertThat(info.jars(), containsInAnyOrder(plugin.resolve("foo.jar").toUri().toURL(), plugin.resolve("bar.jar").toUri().toURL())); + } + + public void testPolicyMissingCodebaseProperty() throws Exception { + Path plugin = makeDummyPlugin("missing-codebase.policy", "foo.jar"); + URL policyFile =
plugin.resolve(PluginDescriptor.ES_PLUGIN_POLICY).toUri().toURL(); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> PolicyUtil.readPolicy(policyFile, Map.of())); + assertThat(e.getMessage(), containsString("Unknown codebases [codebase.doesnotexist] in policy file")); + } + + public void testPolicyPermissions() throws Exception { + Path plugin = makeDummyPlugin("global-and-jar.policy", "foo.jar", "bar.jar"); + Path tmpDir = createTempDir(); + try { + URL jarUrl = plugin.resolve("foo.jar").toUri().toURL(); + setProperty("jarUrl", jarUrl.toString()); + URL policyFile = plugin.resolve(PluginDescriptor.ES_PLUGIN_POLICY).toUri().toURL(); + Policy policy = Policy.getInstance("JavaPolicy", new URIParameter(policyFile.toURI())); + + Set<Permission> globalPermissions = PolicyUtil.getPolicyPermissions(null, policy, tmpDir); + assertThat(globalPermissions, contains(new RuntimePermission("queuePrintJob"))); + + Set<Permission> jarPermissions = PolicyUtil.getPolicyPermissions(jarUrl, policy, tmpDir); + assertThat(jarPermissions, containsInAnyOrder(new RuntimePermission("getClassLoader"), new RuntimePermission("queuePrintJob"))); + } finally { + clearProperty("jarUrl"); + } + } + + private Path makeSinglePermissionPlugin(String jarUrl, String clazz, String name, String actions) throws IOException { + Path plugin = createTempDir(); + StringBuilder policyString = new StringBuilder("grant"); + if (jarUrl != null) { + Path jar = plugin.resolve(jarUrl); + Files.createFile(jar); + policyString.append(" codeBase \"" + jar.toUri().toURL().toString() + "\""); + } + policyString.append(" {\n permission "); + policyString.append(clazz); + // wildcard + policyString.append(" \"" + name + "\""); + if (actions != null) { + policyString.append(", \"" + actions + "\""); + } + policyString.append(";\n};"); + + logger.info(policyString.toString()); + Files.writeString(plugin.resolve(PluginDescriptor.ES_PLUGIN_POLICY), policyString.toString()); + + return plugin; + } + + interface PolicyParser { + PluginPolicyInfo parse(Path pluginRoot, Path tmpDir) throws IOException; + } + + void assertAllowedPermission(String clazz, String name, String actions, Path tmpDir, PolicyParser parser) throws Exception { + // global policy + Path plugin = makeSinglePermissionPlugin(null, clazz, name, actions); + assertNotNull(parser.parse(plugin, tmpDir)); // no error + + // specific jar policy + plugin = makeSinglePermissionPlugin("foobar.jar", clazz, name, actions); + assertNotNull(parser.parse(plugin, tmpDir)); // no error + } + + void assertAllowedPermissions(List<String> allowedPermissions, PolicyParser parser) throws Exception { + Path tmpDir = createTempDir(); + for (String rawPermission : allowedPermissions) { + String[] elements = rawPermission.split(" "); + assert elements.length <= 3; + assert elements.length >= 2; + + assertAllowedPermission(elements[0], elements[1], elements.length == 3 ? elements[2] : null, tmpDir, parser); + } + } + + void assertIllegalPermission(String clazz, String name, String actions, Path tmpDir, PolicyParser parser) throws Exception { + // global policy + final Path globalPlugin = makeSinglePermissionPlugin(null, clazz, name, actions); + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + "Permission (" + clazz + " " + name + (actions == null ?
"" : (" " + actions)) + ") should be illegal", + () -> parser.parse(globalPlugin, tmpDir) + ); + assertThat(e.getMessage(), containsString("contains illegal permission")); + assertThat(e.getMessage(), containsString("in global grant")); + + // specific jar policy + final Path jarPlugin = makeSinglePermissionPlugin("foobar.jar", clazz, name, actions); + e = expectThrows(IllegalArgumentException.class, () -> parser.parse(jarPlugin, tmpDir)); + assertThat(e.getMessage(), containsString("contains illegal permission")); + assertThat(e.getMessage(), containsString("for jar")); + } + + void assertIllegalPermissions(List<String> illegalPermissions, PolicyParser parser) throws Exception { + Path tmpDir = createTempDir(); + for (String rawPermission : illegalPermissions) { + String[] elements = rawPermission.split(" "); + assert elements.length <= 3; + assert elements.length >= 2; + + assertIllegalPermission(elements[0], elements[1], elements.length == 3 ? elements[2] : null, tmpDir, parser); + } + } + + static final List<String> PLUGIN_TEST_PERMISSIONS = List.of( + // TODO: move this back to module test permissions, see https://github.com/elastic/elasticsearch/issues/69464 + "java.io.FilePermission /foo/bar read", + + "java.lang.reflect.ReflectPermission suppressAccessChecks", + "java.lang.RuntimePermission getClassLoader", + "java.lang.RuntimePermission setContextClassLoader", + "java.lang.RuntimePermission setFactory", + "java.lang.RuntimePermission loadLibrary.*", + "java.lang.RuntimePermission accessClassInPackage.*", + "java.lang.RuntimePermission accessDeclaredMembers", + "java.net.NetPermission requestPasswordAuthentication", + "java.net.NetPermission getProxySelector", + "java.net.NetPermission getCookieHandler", + "java.net.NetPermission getResponseCache", + "java.net.SocketPermission * accept,connect,listen,resolve", + "java.net.SocketPermission www.elastic.co accept,connect,listen,resolve", + "java.net.URLPermission https://elastic.co", + "java.net.URLPermission http://elastic.co", + "java.security.SecurityPermission createAccessControlContext", + "java.security.SecurityPermission insertProvider", + "java.security.SecurityPermission putProviderProperty.*", + "java.security.SecurityPermission putProviderProperty.foo", + "java.sql.SQLPermission callAbort", + "java.sql.SQLPermission setNetworkTimeout", + "java.util.PropertyPermission * read", + "java.util.PropertyPermission someProperty read", + "java.util.PropertyPermission * write", + "java.util.PropertyPermission foo.bar write", + "javax.management.MBeanPermission * addNotificationListener", + "javax.management.MBeanPermission * getAttribute", + "javax.management.MBeanPermission * getDomains", + "javax.management.MBeanPermission * getMBeanInfo", + "javax.management.MBeanPermission * getObjectInstance", + "javax.management.MBeanPermission * instantiate", + "javax.management.MBeanPermission * invoke", + "javax.management.MBeanPermission * isInstanceOf", + "javax.management.MBeanPermission * queryMBeans", + "javax.management.MBeanPermission * queryNames", + "javax.management.MBeanPermission * registerMBean", + "javax.management.MBeanPermission * removeNotificationListener", + "javax.management.MBeanPermission * setAttribute", + "javax.management.MBeanPermission * unregisterMBean", + "javax.management.MBeanServerPermission *", + "javax.management.MBeanTrustPermission register", + "javax.security.auth.AuthPermission doAs", + "javax.security.auth.AuthPermission doAsPrivileged", + "javax.security.auth.AuthPermission getSubject",
"javax.security.auth.AuthPermission getSubjectFromDomainCombiner", + "javax.security.auth.AuthPermission setReadOnly", + "javax.security.auth.AuthPermission modifyPrincipals", + "javax.security.auth.AuthPermission modifyPublicCredentials", + "javax.security.auth.AuthPermission modifyPrivateCredentials", + "javax.security.auth.AuthPermission refreshCredential", + "javax.security.auth.AuthPermission destroyCredential", + "javax.security.auth.AuthPermission createLoginContext.*", + "javax.security.auth.AuthPermission getLoginConfiguration", + "javax.security.auth.AuthPermission setLoginConfiguration", + "javax.security.auth.AuthPermission createLoginConfiguration.*", + "javax.security.auth.AuthPermission refreshLoginConfiguration", + "javax.security.auth.kerberos.DelegationPermission host/www.elastic.co@ELASTIC.CO krbtgt/ELASTIC.CO@ELASTIC.CO", + "javax.security.auth.kerberos.ServicePermission host/www.elastic.co@ELASTIC.CO accept" + ); + + public void testPluginPolicyAllowedPermissions() throws Exception { + assertAllowedPermissions(PLUGIN_TEST_PERMISSIONS, PolicyUtil::getPluginPolicyInfo); + assertIllegalPermissions(MODULE_TEST_PERMISSIONS, PolicyUtil::getPluginPolicyInfo); + } + + public void testPrivateCredentialPermissionAllowed() throws Exception { + // the test permission list relies on name values not containing spaces, so this + // exists to also check PrivateCredentialPermission which requires a space in the name + String clazz = "javax.security.auth.PrivateCredentialPermission"; + String name = "com.sun.PrivateCredential com.sun.Principal \\\"duke\\\""; + + assertAllowedPermission(clazz, name, "read", createTempDir(), PolicyUtil::getPluginPolicyInfo); + } + + static final List<String> MODULE_TEST_PERMISSIONS = List.of( + "java.io.FilePermission /foo/bar write", + "java.lang.RuntimePermission createClassLoader", + "java.lang.RuntimePermission getFileStoreAttributes", + "java.lang.RuntimePermission accessUserInformation", + "org.elasticsearch.secure_sm.ThreadPermission modifyArbitraryThreadGroup" + ); + + public void testModulePolicyAllowedPermissions() throws Exception { + assertAllowedPermissions(MODULE_TEST_PERMISSIONS, PolicyUtil::getModulePolicyInfo); + } + + static final List<String> ILLEGAL_TEST_PERMISSIONS = List.of( + "java.awt.AWTPermission *", + "java.io.FilePermission /foo/bar execute", + "java.io.FilePermission /foo/bar delete", + "java.io.FilePermission /foo/bar readLink", + "java.io.SerializablePermission enableSubclassImplementation", + "java.io.SerializablePermission enableSubstitution", + "java.lang.management.ManagementPermission control", + "java.lang.management.ManagementPermission monitor", + "java.lang.reflect.ReflectPermission newProxyInPackage.*", + "java.lang.RuntimePermission enableContextClassLoaderOverride", + "java.lang.RuntimePermission closeClassLoader", + "java.lang.RuntimePermission setSecurityManager", + "java.lang.RuntimePermission createSecurityManager", + "java.lang.RuntimePermission getenv.*", + "java.lang.RuntimePermission getenv.FOOBAR", + "java.lang.RuntimePermission shutdownHooks", + "java.lang.RuntimePermission setIO", + "java.lang.RuntimePermission modifyThread", + "java.lang.RuntimePermission stopThread", + "java.lang.RuntimePermission modifyThreadGroup", + "java.lang.RuntimePermission getProtectionDomain", + "java.lang.RuntimePermission readFileDescriptor", + "java.lang.RuntimePermission writeFileDescriptor", + "java.lang.RuntimePermission defineClassInPackage.*", + "java.lang.RuntimePermission defineClassInPackage.foobar",
"java.lang.RuntimePermission queuePrintJob", + "java.lang.RuntimePermission getStackTrace", + "java.lang.RuntimePermission setDefaultUncaughtExceptionHandler", + "java.lang.RuntimePermission preferences", + "java.lang.RuntimePermission usePolicy", + // blanket runtime permission not allowed + "java.lang.RuntimePermission *", + "java.net.NetPermission setDefaultAuthenticator", + "java.net.NetPermission specifyStreamHandler", + "java.net.NetPermission setProxySelector", + "java.net.NetPermission setCookieHandler", + "java.net.NetPermission setResponseCache", + "java.nio.file.LinkPermission hard", + "java.nio.file.LinkPermission symbolic", + "java.security.SecurityPermission getDomainCombiner", + "java.security.SecurityPermission getPolicy", + "java.security.SecurityPermission setPolicy", + "java.security.SecurityPermission getProperty.*", + "java.security.SecurityPermission getProperty.foobar", + "java.security.SecurityPermission setProperty.*", + "java.security.SecurityPermission setProperty.foobar", + "java.security.SecurityPermission removeProvider.*", + "java.security.SecurityPermission removeProvider.foobar", + "java.security.SecurityPermission clearProviderProperties.*", + "java.security.SecurityPermission clearProviderProperties.foobar", + "java.security.SecurityPermission removeProviderProperty.*", + "java.security.SecurityPermission removeProviderProperty.foobar", + "java.security.SecurityPermission insertProvider.*", + "java.security.SecurityPermission insertProvider.foobar", + "java.security.SecurityPermission setSystemScope", + "java.security.SecurityPermission setIdentityPublicKey", + "java.security.SecurityPermission setIdentityInfo", + "java.security.SecurityPermission addIdentityCertificate", + "java.security.SecurityPermission removeIdentityCertificate", + "java.security.SecurityPermission printIdentity", + "java.security.SecurityPermission getSignerPrivateKey", + "java.security.SecurityPermission getSignerKeyPair", + "java.sql.SQLPermission setLog", + "java.sql.SQLPermission setSyncFactory", + "java.sql.SQLPermission deregisterDriver", + "java.util.logging.LoggingPermission control", + "javax.management.MBeanPermission * getClassLoader", + "javax.management.MBeanPermission * getClassLoaderFor", + "javax.management.MBeanPermission * getClassLoaderRepository", + "javax.management.MBeanTrustPermission *", + "javax.management.remote.SubjectDelegationPermission *", + "javax.net.ssl.SSLPermission setHostnameVerifier", + "javax.net.ssl.SSLPermission getSSLSessionContext", + "javax.net.ssl.SSLPermission setDefaultSSLContext", + "javax.sound.sampled.AudioPermission play", + "javax.sound.sampled.AudioPermission record", + "javax.xml.bind.JAXBPermission setDatatypeConverter", + "javax.xml.ws.WebServicePermission publishEndpoint" + ); + + public void testIllegalPermissions() throws Exception { + assertIllegalPermissions(ILLEGAL_TEST_PERMISSIONS, PolicyUtil::getPluginPolicyInfo); + assertIllegalPermissions(ILLEGAL_TEST_PERMISSIONS, PolicyUtil::getModulePolicyInfo); + } + + public void testAllPermission() throws Exception { + // AllPermission has no name element, so doesn't work with the format above + Path tmpDir = createTempDir(); + assertIllegalPermission("java.security.AllPermission", null, null, tmpDir, PolicyUtil::getPluginPolicyInfo); + assertIllegalPermission("java.security.AllPermission", null, null, tmpDir, PolicyUtil::getModulePolicyInfo); + } +} diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/common/logging/EvilLoggerTests.java 
b/qa/evil-tests/src/test/java/org/elasticsearch/common/logging/EvilLoggerTests.java index 2bc983e77283d..992bebe57e561 100644 --- a/qa/evil-tests/src/test/java/org/elasticsearch/common/logging/EvilLoggerTests.java +++ b/qa/evil-tests/src/test/java/org/elasticsearch/common/logging/EvilLoggerTests.java @@ -174,7 +174,7 @@ public void testConcurrentDeprecationLogger() throws IOException, BrokenBarrierE assertLogLine( deprecationEvents.get(i), DeprecationLogger.CRITICAL, - "org.elasticsearch.common.logging.DeprecationLogger.logDeprecation", + "org.elasticsearch.common.logging.DeprecationLogger.lambda\\$doPrivilegedLog\\$0", ".*This is a maybe logged deprecation message" + i + ".*" ); } @@ -207,7 +207,7 @@ public void testDeprecatedSettings() throws IOException { assertLogLine( deprecationEvents.get(0), DeprecationLogger.CRITICAL, - "org.elasticsearch.common.logging.DeprecationLogger.logDeprecation", + "org.elasticsearch.common.logging.DeprecationLogger.lambda\\$doPrivilegedLog\\$0", ".*\\[deprecated.foo\\] setting was deprecated in Elasticsearch and will be removed in a future release..*" ); } diff --git a/qa/logging-config/src/javaRestTest/resources/plugin-security.policy b/qa/logging-config/src/javaRestTest/resources/plugin-security.policy new file mode 100644 index 0000000000000..1dd8051b7ff37 --- /dev/null +++ b/qa/logging-config/src/javaRestTest/resources/plugin-security.policy @@ -0,0 +1,5 @@ +grant { + // Needed to read the log file + permission java.io.FilePermission "${tests.logfile}", "read"; + permission java.io.FilePermission "${tests.jsonLogfile}", "read"; +}; diff --git a/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/FullClusterRestartSystemIndexCompatibilityIT.java b/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/FullClusterRestartSystemIndexCompatibilityIT.java index 3bc2fde3e396b..985a073bd6034 100644 --- a/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/FullClusterRestartSystemIndexCompatibilityIT.java +++ b/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/FullClusterRestartSystemIndexCompatibilityIT.java @@ -22,7 +22,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.equalTo; @@ -96,7 +95,7 @@ public void testAsyncSearchIndexMigration() throws Exception { } catch (IOException e) { throw new AssertionError("System feature migration failed", e); } - }, 30, TimeUnit.SECONDS); + }); // check search results from n-2 search are still readable assertAsyncSearchHitCount(async_search_ids.get("n-2_id"), numDocs); diff --git a/qa/unconfigured-node-name/src/javaRestTest/resources/plugin-security.policy b/qa/unconfigured-node-name/src/javaRestTest/resources/plugin-security.policy new file mode 100644 index 0000000000000..d0d865c4ede16 --- /dev/null +++ b/qa/unconfigured-node-name/src/javaRestTest/resources/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // Needed to read the log file + permission java.io.FilePermission "${tests.logfile}", "read"; +}; diff --git a/server/build.gradle b/server/build.gradle index 1b86cd639e4a6..6b406bb1d1082 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -28,6 +28,7 @@ dependencies { api project(':libs:core') api project(':libs:logging') + api project(':libs:secure-sm') api project(':libs:x-content') api project(":libs:geo") api project(":libs:lz4") diff --git 
a/server/src/internalClusterTest/java/org/elasticsearch/action/fieldcaps/FieldCapsWithFilterIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/fieldcaps/FieldCapsWithFilterIT.java deleted file mode 100644 index d53020702ba8b..0000000000000 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/fieldcaps/FieldCapsWithFilterIT.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.action.fieldcaps; - -import org.apache.lucene.document.LongPoint; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.PointValues; -import org.elasticsearch.action.NoShardAvailableActionException; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.engine.EngineConfig; -import org.elasticsearch.index.engine.EngineFactory; -import org.elasticsearch.index.engine.InternalEngine; -import org.elasticsearch.index.engine.InternalEngineFactory; -import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.index.shard.IndexLongFieldRange; -import org.elasticsearch.index.shard.ShardLongFieldRange; -import org.elasticsearch.plugins.EnginePlugin; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESIntegTestCase; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Optional; - -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; - -public class FieldCapsWithFilterIT extends ESIntegTestCase { - @Override - protected boolean addMockInternalEngine() { - return false; - } - - private static class EngineWithExposingTimestamp extends InternalEngine { - EngineWithExposingTimestamp(EngineConfig engineConfig) { - super(engineConfig); - assert IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.get(config().getIndexSettings().getSettings()) : "require read-only index"; - } - - @Override - public ShardLongFieldRange getRawFieldRange(String field) { - try (Searcher searcher = acquireSearcher("test")) { - final DirectoryReader directoryReader = searcher.getDirectoryReader(); - - final byte[] minPackedValue = PointValues.getMinPackedValue(directoryReader, field); - final byte[] maxPackedValue = PointValues.getMaxPackedValue(directoryReader, field); - if (minPackedValue == null || maxPackedValue == null) { - assert minPackedValue == null && maxPackedValue == null - : Arrays.toString(minPackedValue) + "-" + Arrays.toString(maxPackedValue); - return ShardLongFieldRange.EMPTY; - } - - return ShardLongFieldRange.of(LongPoint.decodeDimension(minPackedValue, 0), 
LongPoint.decodeDimension(maxPackedValue, 0)); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - } - - public static class ExposingTimestampEnginePlugin extends Plugin implements EnginePlugin { - @Override - public Optional<EngineFactory> getEngineFactory(IndexSettings indexSettings) { - if (IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.get(indexSettings.getSettings())) { - return Optional.of(EngineWithExposingTimestamp::new); - } else { - return Optional.of(new InternalEngineFactory()); - } - } - } - - @Override - protected Collection<Class<? extends Plugin>> nodePlugins() { - return CollectionUtils.appendToCopy(super.nodePlugins(), ExposingTimestampEnginePlugin.class); - } - - void createIndexAndIndexDocs(String index, Settings.Builder indexSettings, long timestamp, boolean exposeTimestamp) throws Exception { - Client client = client(); - assertAcked( - client.admin() - .indices() - .prepareCreate(index) - .setSettings(indexSettings) - .setMapping("@timestamp", "type=date", "position", "type=long") - ); - int numDocs = between(100, 500); - for (int i = 0; i < numDocs; i++) { - client.prepareIndex(index).setSource("position", i, "@timestamp", timestamp + i).get(); - } - if (exposeTimestamp) { - client.admin().indices().prepareClose(index).get(); - client.admin() - .indices() - .prepareUpdateSettings(index) - .setSettings(Settings.builder().put(IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.getKey(), true).build()) - .get(); - client.admin().indices().prepareOpen(index).get(); - assertBusy(() -> { - IndexLongFieldRange timestampRange = clusterService().state().metadata().getProject().index(index).getTimestampRange(); - assertTrue(Strings.toString(timestampRange), timestampRange.containsAllShardRanges()); - }); - } else { - client.admin().indices().prepareRefresh(index).get(); - } - } - - public void testSkipUnmatchedShards() throws Exception { - long oldTimestamp = randomLongBetween(10_000_000, 20_000_000); - long newTimestamp = randomLongBetween(30_000_000, 50_000_000); - String redNode = internalCluster().startDataOnlyNode(); - String blueNode = internalCluster().startDataOnlyNode(); - createIndexAndIndexDocs( - "index_old", - indexSettings(between(1, 5), 0).put("index.routing.allocation.include._name", redNode), - oldTimestamp, - true - ); - internalCluster().stopNode(redNode); - createIndexAndIndexDocs( - "index_new", - indexSettings(between(1, 5), 0).put("index.routing.allocation.include._name", blueNode), - newTimestamp, - false - ); - // fails without index filter - { - FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); - request.indices("index_*"); - request.fields("*"); - request.setMergeResults(false); - if (randomBoolean()) { - request.indexFilter(new RangeQueryBuilder("@timestamp").from(oldTimestamp)); - } - var response = safeGet(client().execute(TransportFieldCapabilitiesAction.TYPE, request)); - assertThat(response.getIndexResponses(), hasSize(1)); - assertThat(response.getIndexResponses().get(0).getIndexName(), equalTo("index_new")); - assertThat(response.getFailures(), hasSize(1)); - assertThat(response.getFailures().get(0).getIndices(), equalTo(new String[] { "index_old" })); - assertThat(response.getFailures().get(0).getException(), instanceOf(NoShardAvailableActionException.class)); - } - // skip unavailable shards with index filter - { - FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); - request.indices("index_*"); - request.fields("*"); - request.indexFilter(new RangeQueryBuilder("@timestamp").from(newTimestamp)); - request.setMergeResults(false); - var response = safeGet(client().execute(TransportFieldCapabilitiesAction.TYPE, request)); - assertThat(response.getIndexResponses(), hasSize(1)); - assertThat(response.getIndexResponses().get(0).getIndexName(), equalTo("index_new")); - assertThat(response.getFailures(), empty()); - } - // skip both indices, one on the coordinator, one on the data nodes - { - FieldCapabilitiesRequest request = new FieldCapabilitiesRequest(); - request.indices("index_*"); - request.fields("*"); - request.indexFilter(new RangeQueryBuilder("@timestamp").from(newTimestamp * 2L)); - request.setMergeResults(false); - var response = safeGet(client().execute(TransportFieldCapabilitiesAction.TYPE, request)); - assertThat(response.getIndexResponses(), empty()); - assertThat(response.getFailures(), empty()); - } - } -}
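The deleted EngineWithExposingTimestamp above reads a shard's @timestamp bounds straight out of Lucene's point index instead of running a search. The core of that trick as a standalone helper (Lucene APIs only; the class and record names are made up):

import java.io.IOException;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.PointValues;

final class PointRanges {
    record LongRange(long min, long max) {}

    /** Returns null when the field has no indexed points in this reader. */
    static LongRange longRange(DirectoryReader reader, String field) throws IOException {
        byte[] min = PointValues.getMinPackedValue(reader, field);
        byte[] max = PointValues.getMaxPackedValue(reader, field);
        if (min == null || max == null) {
            return null; // mirrors the ShardLongFieldRange.EMPTY branch above
        }
        return new LongRange(LongPoint.decodeDimension(min, 0), LongPoint.decodeDimension(max, 0));
    }
}

The test only installed this engine on indices with the write block set, so the min and max could never move after being read, which is what made exposing them safe.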
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java index cdafab819e091..6450b964f1f5b 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java @@ -906,10 +906,6 @@ static void unblockOnRewrite() { @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - // skip rewriting on the coordinator - if (queryRewriteContext.convertToCoordinatorRewriteContext() != null) { - return this; - } try { blockingLatch.await(); } catch (InterruptedException e) { diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 8da4f403c29bd..38e3be0d3b13f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -23,6 +23,7 @@ requires org.elasticsearch.nativeaccess; requires org.elasticsearch.geo; requires org.elasticsearch.lz4; + requires org.elasticsearch.securesm; requires org.elasticsearch.xcontent; requires org.elasticsearch.logging; requires org.elasticsearch.plugin;
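The module-info hunk above puts the securesm module back on the server's module graph. For readers less familiar with JPMS, a requires directive is all it takes for one module to read another module's exported packages; a minimal illustration (the module name example.app is made up):

// module-info.java
module example.app {
    requires org.elasticsearch.securesm; // the same directive the server module regains above
}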
diff --git a/server/src/main/java/org/elasticsearch/ElasticsearchException.java b/server/src/main/java/org/elasticsearch/ElasticsearchException.java index 7b229c1f979ae..f28938ae2ea76 100644 --- a/server/src/main/java/org/elasticsearch/ElasticsearchException.java +++ b/server/src/main/java/org/elasticsearch/ElasticsearchException.java @@ -294,7 +294,7 @@ public Throwable unwrapCause() { public String getDetailedMessage() { if (getCause() != null) { StringBuilder sb = new StringBuilder(); - sb.append(this).append("; "); + sb.append(toString()).append("; "); if (getCause() instanceof ElasticsearchException) { sb.append(((ElasticsearchException) getCause()).getDetailedMessage()); } else { @@ -384,7 +384,7 @@ protected XContentBuilder toXContent(XContentBuilder builder, Params params, int if (ex != this) { generateThrowableXContent(builder, params, this, nestedLevel); } else { - innerToXContent(builder, params, this, headers, metadata, getCause(), nestedLevel); + innerToXContent(builder, params, this, getExceptionName(), getMessage(), headers, metadata, getCause(), nestedLevel); } return builder; } @@ -393,6 +393,8 @@ protected static void innerToXContent( XContentBuilder builder, Params params, Throwable throwable, + String type, + String message, Map<String, List<String>> headers, Map<String, List<String>> metadata, Throwable cause, @@ -406,12 +408,16 @@ return; } - builder.field(TYPE, throwable instanceof ElasticsearchException e ? e.getExceptionName() : getExceptionName(throwable)); - builder.field(REASON, throwable.getMessage()); + builder.field(TYPE, type); + builder.field(REASON, message); - // TODO: we could walk the exception chain to see if _any_ causes are timeouts? - if (throwable instanceof ElasticsearchException exception && exception.isTimeout()) { - builder.field(TIMED_OUT, true); + boolean timedOut = false; + if (throwable instanceof ElasticsearchException exception) { + // TODO: we could walk the exception chain to see if _any_ causes are timeouts? + timedOut = exception.isTimeout(); + } + if (timedOut) { + builder.field(TIMED_OUT, timedOut); } for (Map.Entry<String, List<String>> entry : metadata.entrySet()) { @@ -422,10 +428,13 @@ exception.metadataToXContent(builder, params); } - if (cause != null && params.paramAsBoolean(REST_EXCEPTION_SKIP_CAUSE, REST_EXCEPTION_SKIP_CAUSE_DEFAULT) == false) { - builder.startObject(CAUSED_BY); - generateThrowableXContent(builder, params, cause, nestedLevel + 1); - builder.endObject(); + if (params.paramAsBoolean(REST_EXCEPTION_SKIP_CAUSE, REST_EXCEPTION_SKIP_CAUSE_DEFAULT) == false) { + if (cause != null) { + builder.field(CAUSED_BY); + builder.startObject(); + generateThrowableXContent(builder, params, cause, nestedLevel + 1); + builder.endObject(); + } } if (headers.isEmpty() == false) { @@ -598,7 +607,7 @@ public static ElasticsearchException innerFromXContent(XContentParser parser, bo /** * Static toXContent helper method that renders {@link org.elasticsearch.ElasticsearchException} or {@link Throwable} instances * as XContent, delegating the rendering to {@link #toXContent(XContentBuilder, Params)} - * or {@link #innerToXContent(XContentBuilder, Params, Throwable, Map, Map, Throwable, int)}. + * or {@link #innerToXContent(XContentBuilder, Params, Throwable, String, String, Map, Map, Throwable, int)}. * * This method is usually used when the {@link Throwable} is rendered as a part of another XContent object, and its result can * be parsed back using the {@link #fromXContent(XContentParser)} method. @@ -618,7 +627,7 @@ protected static void generateThrowableXContent(XContentBuilder builder, Params if (t instanceof ElasticsearchException) { ((ElasticsearchException) t).toXContent(builder, params, nestedLevel); } else { - innerToXContent(builder, params, t, emptyMap(), emptyMap(), t.getCause(), nestedLevel); + innerToXContent(builder, params, t, getExceptionName(t), t.getMessage(), emptyMap(), emptyMap(), t.getCause(), nestedLevel); } } @@ -714,8 +723,8 @@ public static ElasticsearchException failureFromXContent(XContentParser parser) */ public ElasticsearchException[] guessRootCauses() { final Throwable cause = getCause(); - if (cause instanceof ElasticsearchException ese) { - return ese.guessRootCauses(); + if (cause != null && cause instanceof ElasticsearchException) { + return ((ElasticsearchException) cause).guessRootCauses(); } return new ElasticsearchException[] { this }; } @@ -764,28 +773,35 @@ protected String getExceptionName() { */ public static String getExceptionName(Throwable ex) { String simpleName = ex.getClass().getSimpleName(); + if (simpleName.startsWith("Elasticsearch")) { + simpleName = simpleName.substring("Elasticsearch".length()); + } // TODO: do we really need to make the exception name in underscore casing? - return toUnderscoreCase(simpleName, simpleName.startsWith("Elasticsearch") ?
"Elasticsearch".length() : 0); + return toUnderscoreCase(simpleName); } static String buildMessage(String type, String reason, String stack) { - return "Elasticsearch exception [" - + TYPE - + "=" - + type - + ", " - + REASON - + "=" - + reason - + (stack == null ? "" : (", " + STACK_TRACE + "=" + stack)) - + "]"; + StringBuilder message = new StringBuilder("Elasticsearch exception ["); + message.append(TYPE).append('=').append(type); + message.append(", ").append(REASON).append('=').append(reason); + if (stack != null) { + message.append(", ").append(STACK_TRACE).append('=').append(stack); + } + message.append(']'); + return message.toString(); } @Override public String toString() { - return (metadata.containsKey(INDEX_METADATA_KEY) - ? (getIndex() + (metadata.containsKey(SHARD_METADATA_KEY) ? "[" + getShardId() + "] " : " ")) - : "") + super.toString().trim(); + StringBuilder builder = new StringBuilder(); + if (metadata.containsKey(INDEX_METADATA_KEY)) { + builder.append(getIndex()); + if (metadata.containsKey(SHARD_METADATA_KEY)) { + builder.append('[').append(getShardId()).append(']'); + } + builder.append(' '); + } + return builder.append(super.toString().trim()).toString(); } /** @@ -2090,17 +2106,19 @@ public String getResourceType() { } // lower cases and adds underscores to transitions in a name - private static String toUnderscoreCase(String value, final int offset) { + private static String toUnderscoreCase(String value) { StringBuilder sb = new StringBuilder(); boolean changed = false; - for (int i = offset; i < value.length(); i++) { + for (int i = 0; i < value.length(); i++) { char c = value.charAt(i); if (Character.isUpperCase(c)) { if (changed == false) { // copy it over here - sb.append(value, offset, i); + for (int j = 0; j < i; j++) { + sb.append(value.charAt(j)); + } changed = true; - if (i == offset) { + if (i == 0) { sb.append(Character.toLowerCase(c)); } else { sb.append('_'); @@ -2117,7 +2135,7 @@ private static String toUnderscoreCase(String value, final int offset) { } } if (changed == false) { - return offset == 0 ? 
value : value.substring(offset); + return value; } return sb.toString(); } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 82e70c2fd69f5..9c5e6ba9bed51 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -96,7 +96,6 @@ static TransportVersion def(int id) { public static final TransportVersion V_8_17_4 = def(8_797_0_04); public static final TransportVersion V_8_17_5 = def(8_797_0_05); public static final TransportVersion INITIAL_ELASTICSEARCH_8_17_6 = def(8_797_0_06); - public static final TransportVersion INITIAL_ELASTICSEARCH_8_17_7 = def(8_797_0_07); public static final TransportVersion INDEXING_PRESSURE_THROTTLING_STATS = def(8_798_0_00); public static final TransportVersion REINDEX_DATA_STREAMS = def(8_799_0_00); public static final TransportVersion ESQL_REMOVE_NODE_LEVEL_PLAN = def(8_800_0_00); @@ -143,7 +142,6 @@ static TransportVersion def(int id) { public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_18 = def(8_840_0_01); public static final TransportVersion V_8_18_0 = def(8_840_0_02); public static final TransportVersion INITIAL_ELASTICSEARCH_8_18_1 = def(8_840_0_03); - public static final TransportVersion INITIAL_ELASTICSEARCH_8_18_2 = def(8_840_0_04); public static final TransportVersion INITIAL_ELASTICSEARCH_8_19 = def(8_841_0_00); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_841_0_01); public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02); @@ -162,7 +160,6 @@ static TransportVersion def(int id) { public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); - public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 = def(8_841_0_18); public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19); public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20); public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21); @@ -174,9 +171,9 @@ static TransportVersion def(int id) { public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27); public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28); public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29); + public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_X = def(8_841_0_30); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); - public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00); @@ -252,6 +249,7 @@ static TransportVersion def(int id) { public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00); public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00); 
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); + public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_076_0_00); /* * STOP! READ THIS FIRST! No, really, @@ -318,7 +316,7 @@ static TransportVersion def(int id) { * Reference to the minimum transport version that can be used with CCS. * This should be the transport version used by the previous minor release. */ - public static final TransportVersion MINIMUM_CCS_VERSION = INITIAL_ELASTICSEARCH_9_0_1; + public static final TransportVersion MINIMUM_CCS_VERSION = V_9_0_0; /** * Sorted list of all versions defined in this class diff --git a/server/src/main/java/org/elasticsearch/Version.java b/server/src/main/java/org/elasticsearch/Version.java index 70c5792f25e49..58b6cbe57850e 100644 --- a/server/src/main/java/org/elasticsearch/Version.java +++ b/server/src/main/java/org/elasticsearch/Version.java @@ -202,14 +202,11 @@ public class Version implements VersionId, ToXContentFragment { public static final Version V_8_17_4 = new Version(8_17_04_99); public static final Version V_8_17_5 = new Version(8_17_05_99); public static final Version V_8_17_6 = new Version(8_17_06_99); - public static final Version V_8_17_7 = new Version(8_17_07_99); public static final Version V_8_18_0 = new Version(8_18_00_99); public static final Version V_8_18_1 = new Version(8_18_01_99); - public static final Version V_8_18_2 = new Version(8_18_02_99); public static final Version V_8_19_0 = new Version(8_19_00_99); public static final Version V_9_0_0 = new Version(9_00_00_99); public static final Version V_9_0_1 = new Version(9_00_01_99); - public static final Version V_9_0_2 = new Version(9_00_02_99); public static final Version V_9_1_0 = new Version(9_01_00_99); public static final Version CURRENT = V_9_1_0; diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java index efbde6264e91c..88eb2ef4fb13d 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java @@ -22,7 +22,6 @@ import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -38,8 +37,6 @@ public final class FieldCapabilitiesRequest extends ActionRequest implements Ind public static final String NAME = "field_caps_request"; public static final IndicesOptions DEFAULT_INDICES_OPTIONS = IndicesOptions.strictExpandOpenAndForbidClosed(); - private String clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; - private String[] indices = Strings.EMPTY_ARRAY; private IndicesOptions indicesOptions = DEFAULT_INDICES_OPTIONS; private String[] fields = Strings.EMPTY_ARRAY; @@ -70,11 +67,6 @@ public FieldCapabilitiesRequest(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { includeEmptyFields = in.readBoolean(); } - if (in.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_ADD_CLUSTER_ALIAS)) { - clusterAlias = in.readOptionalString(); - } else { - clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; - } } public FieldCapabilitiesRequest() {} @@ -98,14 +90,6 @@ public void setMergeResults(boolean mergeResults) { 
this.mergeResults = mergeResults; } - void clusterAlias(String clusterAlias) { - this.clusterAlias = clusterAlias; - } - - String clusterAlias() { - return clusterAlias; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -124,9 +108,6 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { out.writeBoolean(includeEmptyFields); } - if (out.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_ADD_CLUSTER_ALIAS)) { - out.writeOptionalString(clusterAlias); - } } @Override diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/RequestDispatcher.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/RequestDispatcher.java index c56fd985c9e2b..93095e872858a 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/RequestDispatcher.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/RequestDispatcher.java @@ -26,14 +26,8 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.RunOnce; -import org.elasticsearch.index.query.CoordinatorRewriteContextProvider; import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.search.SearchService; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.internal.AliasFilter; -import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; @@ -78,7 +72,6 @@ final class RequestDispatcher { ClusterService clusterService, TransportService transportService, ProjectResolver projectResolver, - CoordinatorRewriteContextProvider coordinatorRewriteContextProvider, Task parentTask, FieldCapabilitiesRequest fieldCapsRequest, OriginalIndices originalIndices, @@ -112,14 +105,8 @@ final class RequestDispatcher { onIndexFailure.accept(index, e); continue; } - final IndexSelector indexResult = new IndexSelector( - fieldCapsRequest.clusterAlias(), - shardIts, - fieldCapsRequest.indexFilter(), - nowInMillis, - coordinatorRewriteContextProvider - ); - if (indexResult.nodeToShards.isEmpty() && indexResult.unmatchedShardIds.isEmpty()) { + final IndexSelector indexResult = new IndexSelector(shardIts); + if (indexResult.nodeToShards.isEmpty()) { onIndexFailure.accept(index, new NoShardAvailableActionException(null, "index [" + index + "] has no active shard copy")); } else { this.indexSelectors.put(index, indexResult); @@ -268,34 +255,10 @@ private static class IndexSelector { private final Set unmatchedShardIds = new HashSet<>(); private final Map failures = new HashMap<>(); - IndexSelector( - String clusterAlias, - List shardIts, - QueryBuilder indexFilter, - long nowInMillis, - CoordinatorRewriteContextProvider coordinatorRewriteContextProvider - ) { + IndexSelector(List shardIts) { for (ShardIterator shardIt : shardIts) { - boolean canMatch = true; - final ShardId shardId = shardIt.shardId(); - if (indexFilter != null && indexFilter instanceof MatchAllQueryBuilder == false) { - var coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContext(shardId.getIndex()); - if (coordinatorRewriteContext != null) { - var shardRequest = new 
ShardSearchRequest(shardId, nowInMillis, AliasFilter.EMPTY, clusterAlias); - shardRequest.source(new SearchSourceBuilder().query(indexFilter)); - try { - canMatch = SearchService.queryStillMatchesAfterRewrite(shardRequest, coordinatorRewriteContext); - } catch (Exception e) { - // treat as if shard is still a potential match - } - } - } - if (canMatch) { - for (ShardRouting shard : shardIt) { - nodeToShards.computeIfAbsent(shard.currentNodeId(), node -> new ArrayList<>()).add(shard); - } - } else { - unmatchedShardIds.add(shardId); + for (ShardRouting shard : shardIt) { + nodeToShards.computeIfAbsent(shard.currentNodeId(), node -> new ArrayList<>()).add(shard); } } } diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java index d64da10e0b2f7..1868cd649f0ee 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java @@ -250,7 +250,6 @@ private void doExecuteForked( clusterService, transportService, projectResolver, - indicesService.getCoordinatorRewriteContextProvider(() -> nowInMillis), task, request, localIndices, @@ -274,7 +273,7 @@ private void doExecuteForked( singleThreadedExecutor, RemoteClusterService.DisconnectedStrategy.RECONNECT_UNLESS_SKIP_UNAVAILABLE ); - FieldCapabilitiesRequest remoteRequest = prepareRemoteRequest(clusterAlias, request, originalIndices, nowInMillis); + FieldCapabilitiesRequest remoteRequest = prepareRemoteRequest(request, originalIndices, nowInMillis); ActionListener remoteListener = ActionListener.wrap(response -> { for (FieldCapabilitiesIndexResponse resp : response.getIndexResponses()) { String indexName = RemoteClusterAware.buildRemoteIndexName(clusterAlias, resp.getIndexName()); @@ -384,13 +383,11 @@ private static void mergeIndexResponses( } private static FieldCapabilitiesRequest prepareRemoteRequest( - String clusterAlias, FieldCapabilitiesRequest request, OriginalIndices originalIndices, long nowInMillis ) { FieldCapabilitiesRequest remoteRequest = new FieldCapabilitiesRequest(); - remoteRequest.clusterAlias(clusterAlias); remoteRequest.setMergeResults(false); // we need to merge on this node remoteRequest.indicesOptions(originalIndices.indicesOptions()); remoteRequest.indices(originalIndices.indices()); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseExecutionException.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseExecutionException.java index 458c9c9ec2505..fc79e85f9326d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseExecutionException.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseExecutionException.java @@ -106,6 +106,22 @@ public Throwable getCause() { return cause; } + private static String buildMessage(String phaseName, String msg, ShardSearchFailure[] shardFailures) { + StringBuilder sb = new StringBuilder(); + sb.append("Failed to execute phase [").append(phaseName).append("], ").append(msg); + if (CollectionUtils.isEmpty(shardFailures) == false) { + sb.append("; shardFailures "); + for (ShardSearchFailure shardFailure : shardFailures) { + if (shardFailure.shard() != null) { + sb.append("{").append(shardFailure.shard()).append(": ").append(shardFailure.reason()).append("}"); + } else { + 
sb.append("{").append(shardFailure.reason()).append("}"); + } + } + } + return sb.toString(); + } + @Override protected void metadataToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("phase", phaseName); @@ -128,7 +144,17 @@ protected XContentBuilder toXContent(XContentBuilder builder, Params params, int // We don't have a cause when all shards failed, but we do have shards failures so we can "guess" a cause // (see {@link #getCause()}). Here, we use super.getCause() because we don't want the guessed exception to // be rendered twice (one in the "cause" field, one in "failed_shards") - innerToXContent(builder, params, this, getHeaders(), getMetadata(), super.getCause(), nestedLevel); + innerToXContent( + builder, + params, + this, + getExceptionName(), + getMessage(), + getHeaders(), + getMetadata(), + super.getCause(), + nestedLevel + ); } return builder; } @@ -146,23 +172,7 @@ public ElasticsearchException[] guessRootCauses() { @Override public String toString() { - return "Failed to execute phase [" - + phaseName - + "], " - + getMessage() - + (CollectionUtils.isEmpty(shardFailures) ? "" : buildShardFailureString()); - } - - private String buildShardFailureString() { - StringBuilder sb = new StringBuilder("; shardFailures "); - for (ShardSearchFailure shardFailure : shardFailures) { - sb.append("{"); - if (shardFailure.shard() != null) { - sb.append(shardFailure.shard()).append(": "); - } - sb.append(shardFailure.reason()).append("}"); - } - return sb.toString(); + return buildMessage(phaseName, getMessage(), shardFailures); } public String getPhaseName() { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 9aee331bb106b..39e1c30f658d8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; @@ -819,7 +820,7 @@ void onShardDone() { out.close(); } } - ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(out.moveToBytesReference())); + ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(new ReleasableBytesReference(out.bytes(), out))); } private void maybeFreeContext(SearchPhaseResult result, BitSet relevantShardIndices) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 02cf294ac6f1c..21eeaedb7ea54 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.CollectionUtil; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -1429,7 +1430,7 
@@ static List mergeShardsIterators( } else { shards = CollectionUtils.concatLists(remoteShardIterators, localShardIterators); } - shards.sort(SearchShardIterator::compareTo); + CollectionUtil.timSort(shards); return shards; } diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Bootstrap.java b/server/src/main/java/org/elasticsearch/bootstrap/Bootstrap.java index 56d185645e149..4c7fb96c5b1d5 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Bootstrap.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Bootstrap.java @@ -33,6 +33,7 @@ class Bootstrap { // arguments from the CLI process private final ServerArgs args; + private final boolean useEntitlements; // controller for spawning component subprocesses private final Spawner spawner = new Spawner(); @@ -46,10 +47,11 @@ class Bootstrap { // loads information about plugins required for entitlements in phase 2, used by plugins service in phase 3 private final SetOnce pluginsLoader = new SetOnce<>(); - Bootstrap(PrintStream out, PrintStream err, ServerArgs args) { + Bootstrap(PrintStream out, PrintStream err, ServerArgs args, boolean useEntitlements) { this.out = out; this.err = err; this.args = args; + this.useEntitlements = useEntitlements; } ServerArgs args() { @@ -60,6 +62,10 @@ Spawner spawner() { return spawner; } + public boolean useEntitlements() { + return useEntitlements; + } + void setSecureSettings(SecureSettings secureSettings) { this.secureSettings.set(secureSettings); } diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java index 246c9a781dc68..b5b616fff0182 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java @@ -21,6 +21,7 @@ import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.discovery.DiscoveryModule; import org.elasticsearch.index.IndexModule; +import org.elasticsearch.jdk.RuntimeVersionFeature; import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.monitor.process.ProcessProbe; import org.elasticsearch.nativeaccess.NativeAccess; @@ -32,6 +33,7 @@ import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; +import java.security.AllPermission; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -709,6 +711,36 @@ public ReferenceDocs referenceDocs() { } + static class AllPermissionCheck implements BootstrapCheck { + + @Override + public final BootstrapCheckResult check(BootstrapContext context) { + if (isAllPermissionGranted()) { + return BootstrapCheck.BootstrapCheckResult.failure("granting the all permission effectively disables security"); + } + return BootstrapCheckResult.success(); + } + + boolean isAllPermissionGranted() { + if (RuntimeVersionFeature.isSecurityManagerAvailable() == false) { + return false; + } + final SecurityManager sm = System.getSecurityManager(); + assert sm != null; + try { + sm.checkPermission(new AllPermission()); + } catch (final SecurityException e) { + return false; + } + return true; + } + + @Override + public ReferenceDocs referenceDocs() { + return ReferenceDocs.BOOTSTRAP_CHECK_ALL_PERMISSION; + } + } + static class DiscoveryConfiguredCheck implements BootstrapCheck { @Override public BootstrapCheckResult check(BootstrapContext context) { diff --git a/server/src/main/java/org/elasticsearch/bootstrap/ESPolicy.java 
b/server/src/main/java/org/elasticsearch/bootstrap/ESPolicy.java new file mode 100644 index 0000000000000..845303abe6baf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/bootstrap/ESPolicy.java @@ -0,0 +1,319 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.core.Predicates; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.FilePermission; +import java.io.IOException; +import java.net.SocketPermission; +import java.net.URL; +import java.security.AllPermission; +import java.security.CodeSource; +import java.security.Permission; +import java.security.PermissionCollection; +import java.security.Permissions; +import java.security.Policy; +import java.security.ProtectionDomain; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** custom policy for union of static and dynamic permissions */ +final class ESPolicy extends Policy { + + /** template policy file, the one used in tests */ + static final String POLICY_RESOURCE = "security.policy"; + /** limited policy for scripts */ + static final String UNTRUSTED_RESOURCE = "untrusted.policy"; + + private static final String ALL_FILE_MASK = "read,readlink,write,delete,execute"; + private static final AllPermission ALL_PERMISSION = new AllPermission(); + + final Policy template; + final Policy untrusted; + final Policy system; + final PermissionCollection dynamic; + final PermissionCollection dataPathPermission; + final Map plugins; + final PermissionCollection allSecuredFiles; + final Map> securedFiles; + + @SuppressForbidden(reason = "Need to access and check file permissions directly") + ESPolicy( + Policy template, + PermissionCollection dynamic, + Map plugins, + boolean filterBadDefaults, + List dataPathPermissions, + Map> securedFiles + ) { + this.template = template; + this.dataPathPermission = createPermission(dataPathPermissions); + this.untrusted = PolicyUtil.readPolicy(getClass().getResource(UNTRUSTED_RESOURCE), Collections.emptyMap()); + if (filterBadDefaults) { + this.system = new SystemPolicy(Policy.getPolicy()); + } else { + this.system = Policy.getPolicy(); + } + this.dynamic = dynamic; + this.plugins = plugins; + + this.securedFiles = securedFiles.entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap(e -> new FilePermission(e.getKey(), ALL_FILE_MASK), e -> Set.copyOf(e.getValue()))); + this.allSecuredFiles = createPermission(this.securedFiles.keySet()); + } + + private static PermissionCollection createPermission(Collection permissions) { + PermissionCollection coll; + var it = permissions.iterator(); + if (it.hasNext() == false) { + coll = new Permissions(); + } else { + Permission p = it.next(); + coll = p.newPermissionCollection(); + coll.add(p); + it.forEachRemaining(coll::add); + } + + coll.setReadOnly(); + return coll; + } + + private static PermissionCollection createPermission(List permissions) { + PermissionCollection coll = null; + for 
(FilePermission permission : permissions) { + if (coll == null) { + coll = permission.newPermissionCollection(); + } + coll.add(permission); + } + if (coll == null) { + coll = new Permissions(); + } + coll.setReadOnly(); + return coll; + } + + @Override + @SuppressForbidden(reason = "fast equals check is desired") + public boolean implies(ProtectionDomain domain, Permission permission) { + CodeSource codeSource = domain.getCodeSource(); + // codesource can be null when reducing privileges via doPrivileged() + if (codeSource == null) { + return false; + } + + URL location = codeSource.getLocation(); + if (allSecuredFiles.implies(permission)) { + /* + * Check if location can access this secured file + * The permission this is generated from, SecuredFileAccessPermission, doesn't have a mask, + * it just grants all access (and so disallows all access from others) + * It's helpful to use the infrastructure around FilePermission here to do the directory structure check with implies + * so we use ALL_FILE_MASK mask to check if we can do something with this file, whatever the actual operation we're requesting + */ + return canAccessSecuredFile(domain, new FilePermission(permission.getName(), ALL_FILE_MASK)); + } + + if (location != null) { + // run scripts with limited permissions + if (BootstrapInfo.UNTRUSTED_CODEBASE.equals(location.getFile())) { + return untrusted.implies(domain, permission); + } + // check for an additional plugin permission: plugin policy is + // only consulted for its codesources. + Policy plugin = plugins.get(location); + if (plugin != null && plugin.implies(domain, permission)) { + return true; + } + } + + // The FilePermission to check access to the path.data is the hottest permission check in + // Elasticsearch, so we explicitly check it here. + if (dataPathPermission.implies(permission)) { + return true; + } + + // Special handling for broken Hadoop code: "let me execute or my classes will not load" + // yeah right, REMOVE THIS when hadoop is fixed + if (permission instanceof FilePermission && "<>".equals(permission.getName())) { + hadoopHack(); + } + + // otherwise defer to template + dynamic file permissions + return template.implies(domain, permission) || dynamic.implies(permission) || system.implies(domain, permission); + } + + @SuppressForbidden(reason = "We get given an URL by the security infrastructure") + private boolean canAccessSecuredFile(ProtectionDomain domain, FilePermission permission) { + if (domain == null || domain.getCodeSource() == null || domain.getCodeSource().getLocation() == null) { + return false; + } + + // If the domain in question has AllPermission - only true of sources built into the JDK, as we prevent AllPermission from being + // configured in Elasticsearch - then it has access to this file. 
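+ // Illustration for the lookups below (paths hypothetical): the directory-structure
+ // matching relies on FilePermission#implies, which understands the JDK's recursive
+ // ("/dir/-") and single-level ("/dir/*") target names, e.g.
+ //   var grant = new FilePermission("/config/certs/-", ALL_FILE_MASK);
+ //   var request = new FilePermission("/config/certs/http.p12", ALL_FILE_MASK);
+ //   grant.implies(request) // -> true: "-" covers the whole subtree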
+ + if (system.implies(domain, ALL_PERMISSION)) { + return true; + } + URL location = domain.getCodeSource().getLocation(); + + // check the source + Set accessibleSources = securedFiles.get(permission); + if (accessibleSources != null) { + // simple case - single-file referenced directly + + return accessibleSources.contains(location); + } else { + // there's a directory reference in there somewhere + // do a manual search :( + // there may be several permissions that potentially match, + // grant access if any of them cover this file + return securedFiles.entrySet() + .stream() + .filter(e -> e.getKey().implies(permission)) + .anyMatch(e -> e.getValue().contains(location)); + } + } + + private static void hadoopHack() { + for (StackTraceElement element : Thread.currentThread().getStackTrace()) { + if ("org.apache.hadoop.util.Shell".equals(element.getClassName()) && "runCommand".equals(element.getMethodName())) { + // we found the horrible method: the hack begins! + // force the hadoop code to back down, by throwing an exception that it catches. + rethrow(new IOException("no hadoop, you cannot do this.")); + } + } + } + + /** + * Classy puzzler to rethrow any checked exception as an unchecked one. + */ + private static class Rethrower { + @SuppressWarnings("unchecked") + private void rethrow(Throwable t) throws T { + throw (T) t; + } + } + + /** + * Rethrows t (identical object). + */ + private static void rethrow(Throwable t) { + new Rethrower().rethrow(t); + } + + @Override + public PermissionCollection getPermissions(CodeSource codesource) { + // code should not rely on this method, or at least use it correctly: + // https://bugs.openjdk.java.net/browse/JDK-8014008 + // return them a new empty permissions object so jvisualvm etc work + for (StackTraceElement element : Thread.currentThread().getStackTrace()) { + if ("sun.rmi.server.LoaderHandler".equals(element.getClassName()) && "loadClass".equals(element.getMethodName())) { + return new Permissions(); + } + } + // return UNSUPPORTED_EMPTY_COLLECTION since it is safe. + return super.getPermissions(codesource); + } + + // TODO: remove this hack when insecure defaults are removed from java + + /** + * Wraps a bad default permission, applying a pre-implies to any permissions before checking if the wrapped bad default permission + * implies a permission. + */ + private static class BadDefaultPermission extends Permission { + + private final Permission badDefaultPermission; + private final Predicate preImplies; + + /** + * Construct an instance with a pre-implies check to apply to desired permissions. 
+ * + * @param badDefaultPermission the bad default permission to wrap + * @param preImplies a test that is applied to a desired permission before checking if the bad default permission that + * this instance wraps implies the desired permission + */ + BadDefaultPermission(final Permission badDefaultPermission, final Predicate preImplies) { + super(badDefaultPermission.getName()); + this.badDefaultPermission = badDefaultPermission; + this.preImplies = preImplies; + } + + @Override + public final boolean implies(Permission permission) { + return preImplies.test(permission) && badDefaultPermission.implies(permission); + } + + @Override + public final boolean equals(Object obj) { + return badDefaultPermission.equals(obj); + } + + @Override + public int hashCode() { + return badDefaultPermission.hashCode(); + } + + @Override + public String getActions() { + return badDefaultPermission.getActions(); + } + + } + + // default policy file states: + // "It is strongly recommended that you either remove this permission + // from this policy file or further restrict it to code sources + // that you specify, because Thread.stop() is potentially unsafe." + // not even sure this method still works... + private static final Permission BAD_DEFAULT_NUMBER_ONE = new BadDefaultPermission( + new RuntimePermission("stopThread"), + Predicates.always() + ); + + // default policy file states: + // "allows anyone to listen on dynamic ports" + // specified exactly because that is what we want, and fastest since it won't imply any + // expensive checks for the implicit "resolve" + private static final Permission BAD_DEFAULT_NUMBER_TWO = new BadDefaultPermission( + new SocketPermission("localhost:0", "listen"), + // we apply this pre-implies test because some SocketPermission#implies calls do expensive reverse-DNS resolves + p -> p instanceof SocketPermission && p.getActions().contains("listen") + ); + + /** + * Wraps the Java system policy, filtering out bad default permissions that + * are granted to all domains. Note, before java 8 these were even worse. 
+ */ + static class SystemPolicy extends Policy { + final Policy delegate; + + SystemPolicy(Policy delegate) { + this.delegate = delegate; + } + + @Override + public boolean implies(ProtectionDomain domain, Permission permission) { + if (BAD_DEFAULT_NUMBER_ONE.implies(permission) || BAD_DEFAULT_NUMBER_TWO.implies(permission)) { + return false; + } + return delegate.implies(domain, permission); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index d759abd366e7c..83c1db21acc3c 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -63,6 +63,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.security.Security; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -127,6 +128,7 @@ private static Bootstrap initPhase1() { final PrintStream err = getStderr(); final ServerArgs args; + final boolean useEntitlements = true; try { initSecurityProperties(); LogConfigurator.registerErrorListener(); @@ -154,7 +156,7 @@ private static Bootstrap initPhase1() { return null; // unreachable, to satisfy compiler } - return new Bootstrap(out, err, args); + return new Bootstrap(out, err, args, useEntitlements); } /** @@ -400,7 +402,11 @@ protected void validateNodeBeforeAcceptingRequests( final BoundTransportAddress boundTransportAddress, List checks ) throws NodeValidationException { - BootstrapChecks.check(context, boundTransportAddress, checks); + var additionalChecks = new ArrayList<>(checks); + if (bootstrap.useEntitlements() == false) { + additionalChecks.add(new BootstrapChecks.AllPermissionCheck()); + } + BootstrapChecks.check(context, boundTransportAddress, additionalChecks); } }; INSTANCE = new Elasticsearch(bootstrap.spawner(), node); @@ -562,6 +568,9 @@ private static void initSecurityProperties() { } } } + + // policy file codebase declarations in security.policy rely on property expansion, see PolicyUtil.readPolicy + Security.setProperty("policy.expandProperties", "true"); } private static Environment createEnvironment(Path configDir, Settings initialSettings, SecureSettings secureSettings) { diff --git a/server/src/main/java/org/elasticsearch/bootstrap/ElasticsearchUncaughtExceptionHandler.java b/server/src/main/java/org/elasticsearch/bootstrap/ElasticsearchUncaughtExceptionHandler.java index 20ba8a9dd5e8c..b2c1bcb1d544a 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/ElasticsearchUncaughtExceptionHandler.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/ElasticsearchUncaughtExceptionHandler.java @@ -14,6 +14,8 @@ import org.elasticsearch.core.SuppressForbidden; import java.io.IOError; +import java.security.AccessController; +import java.security.PrivilegedAction; class ElasticsearchUncaughtExceptionHandler implements Thread.UncaughtExceptionHandler { private static final Logger logger = LogManager.getLogger(ElasticsearchUncaughtExceptionHandler.class); @@ -51,17 +53,41 @@ static boolean isFatalUncaught(Throwable e) { void onFatalUncaught(final String threadName, final Throwable t) { final String message = "fatal error in thread [" + threadName + "], exiting"; - logger.error(message, t); + logErrorMessage(t, message); } void onNonFatalUncaught(final String threadName, final Throwable t) { final String message = "uncaught exception in thread [" + threadName + "]"; - logger.error(message, 
t); + logErrorMessage(t, message); + } + + private static void logErrorMessage(Throwable t, String message) { + AccessController.doPrivileged((PrivilegedAction) () -> { + logger.error(message, t); + return null; + }); } - @SuppressForbidden(reason = "intentionally halting") void halt(int status) { - // we halt to prevent shutdown hooks from running - Runtime.getRuntime().halt(status); + AccessController.doPrivileged(new PrivilegedHaltAction(status)); } + + static class PrivilegedHaltAction implements PrivilegedAction { + + private final int status; + + private PrivilegedHaltAction(final int status) { + this.status = status; + } + + @SuppressForbidden(reason = "halt") + @Override + public Void run() { + // we halt to prevent shutdown hooks from running + Runtime.getRuntime().halt(status); + return null; + } + + } + } diff --git a/server/src/main/java/org/elasticsearch/bootstrap/PluginPolicyInfo.java b/server/src/main/java/org/elasticsearch/bootstrap/PluginPolicyInfo.java new file mode 100644 index 0000000000000..c5fb06a1bcba3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/bootstrap/PluginPolicyInfo.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import java.net.URL; +import java.nio.file.Path; +import java.security.Policy; +import java.util.Set; + +public record PluginPolicyInfo(Path file, Set jars, Policy policy) {} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/PolicyUtil.java b/server/src/main/java/org/elasticsearch/bootstrap/PolicyUtil.java new file mode 100644 index 0000000000000..78cf0ee93a0e5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/bootstrap/PolicyUtil.java @@ -0,0 +1,439 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.SecuredConfigFileAccessPermission; +import org.elasticsearch.SecuredConfigFileSettingAccessPermission; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.PathUtils; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.plugins.PluginDescriptor; +import org.elasticsearch.script.ClassPermission; +import org.elasticsearch.secure_sm.ThreadPermission; + +import java.io.FilePermission; +import java.io.IOException; +import java.lang.reflect.ReflectPermission; +import java.net.NetPermission; +import java.net.SocketPermission; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLPermission; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.CodeSource; +import java.security.NoSuchAlgorithmException; +import java.security.Permission; +import java.security.PermissionCollection; +import java.security.Permissions; +import java.security.Policy; +import java.security.ProtectionDomain; +import java.security.SecurityPermission; +import java.security.URIParameter; +import java.security.UnresolvedPermission; +import java.security.cert.Certificate; +import java.sql.SQLPermission; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.PropertyPermission; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import javax.management.MBeanPermission; +import javax.management.MBeanServerPermission; +import javax.management.MBeanTrustPermission; +import javax.management.ObjectName; +import javax.security.auth.AuthPermission; +import javax.security.auth.PrivateCredentialPermission; +import javax.security.auth.kerberos.DelegationPermission; +import javax.security.auth.kerberos.ServicePermission; + +import static java.util.Map.entry; + +public class PolicyUtil { + + // this object is checked by reference, so the value in the list does not matter + static final List ALLOW_ALL_NAMES = List.of("ALLOW ALL NAMES SENTINEL"); + + static class PermissionMatcher implements Predicate { + + PermissionCollection namedPermissions; + Map> classPermissions; + + PermissionMatcher(PermissionCollection namedPermissions, Map> classPermissions) { + this.namedPermissions = namedPermissions; + this.classPermissions = classPermissions; + } + + @Override + public boolean test(Permission permission) { + if (namedPermissions.implies(permission)) { + return true; + } + String clazz = permission.getClass().getCanonicalName(); + String name = permission.getName(); + if (permission.getClass().equals(UnresolvedPermission.class)) { + UnresolvedPermission up = (UnresolvedPermission) permission; + clazz = up.getUnresolvedType(); + name = up.getUnresolvedName(); + } + List allowedNames = classPermissions.get(clazz); + return allowedNames != null && (allowedNames == ALLOW_ALL_NAMES || allowedNames.contains(name)); + } + } + + private static final PermissionMatcher ALLOWED_PLUGIN_PERMISSIONS; + private static final PermissionMatcher ALLOWED_MODULE_PERMISSIONS; + static { + List namedPermissions = List.of( + // TODO: remove read permission, see https://github.com/elastic/elasticsearch/issues/69464 + createFilePermission("<>", "read"), + + new ReflectPermission("suppressAccessChecks"), 
+ new RuntimePermission("getClassLoader"), + new RuntimePermission("setContextClassLoader"), + new RuntimePermission("setFactory"), + new RuntimePermission("loadLibrary.*"), + new RuntimePermission("accessClassInPackage.*"), + new RuntimePermission("accessDeclaredMembers"), + new NetPermission("requestPasswordAuthentication"), + new NetPermission("getProxySelector"), + new NetPermission("getCookieHandler"), + new NetPermission("getResponseCache"), + new SocketPermission("*", "accept,connect,listen,resolve"), + new SecurityPermission("createAccessControlContext"), + new SecurityPermission("insertProvider"), + new SecurityPermission("putProviderProperty.*"), + // apache abuses the SecurityPermission class for it's own purposes + new SecurityPermission("org.apache.*"), + // write is needed because of HdfsPlugin + new PropertyPermission("*", "read,write"), + new AuthPermission("doAs"), + new AuthPermission("doAsPrivileged"), + new AuthPermission("getSubject"), + new AuthPermission("getSubjectFromDomainCombiner"), + new AuthPermission("setReadOnly"), + new AuthPermission("modifyPrincipals"), + new AuthPermission("modifyPublicCredentials"), + new AuthPermission("modifyPrivateCredentials"), + new AuthPermission("refreshCredential"), + new AuthPermission("destroyCredential"), + new AuthPermission("createLoginContext.*"), + new AuthPermission("getLoginConfiguration"), + new AuthPermission("setLoginConfiguration"), + new AuthPermission("createLoginConfiguration.*"), + new AuthPermission("refreshLoginConfiguration"), + new MBeanPermission( + "*", + "*", + ObjectName.WILDCARD, + "addNotificationListener,getAttribute,getDomains,getMBeanInfo,getObjectInstance,instantiate,invoke," + + "isInstanceOf,queryMBeans,queryNames,registerMBean,removeNotificationListener,setAttribute,unregisterMBean" + ), + new MBeanServerPermission("*"), + new MBeanTrustPermission("register") + ); + // While it would be ideal to represent all allowed permissions with concrete instances so that we can + // use the builtin implies method to match them against the parsed policy, this does not work in all + // cases for two reasons: + // (1) Some permissions classes do not have a name argument that can represent all possible variants. + // For example, FilePermission has "<< ALL FILES >>" so all paths can be matched, but DelegationPermission + // does not have anything to represent all principals. + // (2) Some permissions classes are in java modules that are not accessible from the classloader used by + // the policy parser. This results in those permissions being in UnresolvedPermission instances. Those + // are normally resolved at runtime when that permission is checked by SecurityManager. But there is + // no general purpose utility to resolve those permissions, so we must be able to match those + // unresolved permissions in the policy by class and name values. + // Given the above, the below map is from permission class to the list of allowed name values. A sentinel value + // is used to mean names are accepted. We do not use this model for all permissions because many permission + // classes have their own meaning for some form of wildcard matching of the name, which we want to delegate + // to those permissions if possible. 
+ Map> classPermissions = Stream.of( + entry(URLPermission.class, ALLOW_ALL_NAMES), + entry(DelegationPermission.class, ALLOW_ALL_NAMES), + entry(ServicePermission.class, ALLOW_ALL_NAMES), + entry(PrivateCredentialPermission.class, ALLOW_ALL_NAMES), + entry(SQLPermission.class, List.of("callAbort", "setNetworkTimeout")), + entry(ClassPermission.class, ALLOW_ALL_NAMES), + entry(SecuredConfigFileAccessPermission.class, ALLOW_ALL_NAMES), + entry(SecuredConfigFileSettingAccessPermission.class, ALLOW_ALL_NAMES) + ).collect(Collectors.toMap(e -> e.getKey().getCanonicalName(), Map.Entry::getValue)); + PermissionCollection pluginPermissionCollection = new Permissions(); + namedPermissions.forEach(pluginPermissionCollection::add); + pluginPermissionCollection.setReadOnly(); + ALLOWED_PLUGIN_PERMISSIONS = new PermissionMatcher(pluginPermissionCollection, classPermissions); + + // Modules are allowed a few extra permissions. While we should strive to keep this list small, modules + // are essentially part of core, so these are permissions we need for various reasons in core functionality, + // but that we do not think plugins in general should need. + List modulePermissions = List.of( + createFilePermission("<>", "read,write"), + new RuntimePermission("createClassLoader"), + new RuntimePermission("getFileStoreAttributes"), + new RuntimePermission("accessUserInformation"), + new AuthPermission("modifyPrivateCredentials"), + new RuntimePermission("accessSystemModules") + ); + PermissionCollection modulePermissionCollection = new Permissions(); + namedPermissions.forEach(modulePermissionCollection::add); + modulePermissions.forEach(modulePermissionCollection::add); + modulePermissionCollection.setReadOnly(); + Map> moduleClassPermissions = new HashMap<>(classPermissions); + moduleClassPermissions.put( + // Not available to the SecurityManager ClassLoader. See classPermissions comment. + ThreadPermission.class.getCanonicalName(), + List.of("modifyArbitraryThreadGroup") + ); + moduleClassPermissions = Collections.unmodifiableMap(moduleClassPermissions); + ALLOWED_MODULE_PERMISSIONS = new PermissionMatcher(modulePermissionCollection, moduleClassPermissions); + } + + @SuppressForbidden(reason = "create permission for test") + private static FilePermission createFilePermission(String path, String actions) { + return new FilePermission(path, actions); + } + + /** + * Return a map from codebase name to codebase url of jar codebases used by ES core. + */ + @SuppressForbidden(reason = "find URL path") + public static Map getCodebaseJarMap(Set urls) { + Map codebases = new LinkedHashMap<>(); // maintain order + for (URL url : urls) { + try { + String fileName = PathUtils.get(url.toURI()).getFileName().toString(); + if (fileName.endsWith(".jar") == false) { + // tests :( + continue; + } + codebases.put(fileName, url); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + return codebases; + } + + /** + * Reads and returns the specified {@code policyFile}. + *

+ * Jar files listed in {@code codebases} location will be provided to the policy file via + * a system property of the short name: e.g. ${codebase.my-dep-1.2.jar} + * would map to full URL. + */ + @SuppressForbidden(reason = "accesses fully qualified URLs to configure security") + public static Policy readPolicy(URL policyFile, Map codebases) { + try { + Properties originalProps = System.getProperties(); + // allow missing while still setting values + Set unknownCodebases = new HashSet<>(); + Map codebaseProperties = new HashMap<>(); + Properties tempProps = new Properties(originalProps) { + @Override + public String getProperty(String key) { + if (key.startsWith("codebase.")) { + String value = codebaseProperties.get(key); + if (value == null) { + unknownCodebases.add(key); + } + return value; + } else { + return super.getProperty(key); + } + } + }; + + try { + System.setProperties(tempProps); + // set codebase properties + for (Map.Entry codebase : codebases.entrySet()) { + String name = codebase.getKey(); + URL url = codebase.getValue(); + + // We attempt to use a versionless identifier for each codebase. This assumes a specific version + // format in the jar filename. While we cannot ensure all jars in all plugins use this format, nonconformity + // only means policy grants would need to include the entire jar filename as they always have before. + String property = "codebase." + name; + String aliasProperty = "codebase." + name.replaceFirst("-\\d+\\.\\d+.*\\.jar", ""); + if (aliasProperty.equals(property) == false) { + + Object previous = codebaseProperties.put(aliasProperty, url.toString()); + if (previous != null) { + throw new IllegalStateException( + "codebase property already set: " + aliasProperty + " -> " + previous + ", cannot set to " + url.toString() + ); + } + } + Object previous = codebaseProperties.put(property, url.toString()); + if (previous != null) { + throw new IllegalStateException( + "codebase property already set: " + property + " -> " + previous + ", cannot set to " + url.toString() + ); + } + } + Policy policy = Policy.getInstance("JavaPolicy", new URIParameter(policyFile.toURI())); + if (unknownCodebases.isEmpty() == false) { + throw new IllegalArgumentException( + "Unknown codebases " + + unknownCodebases + + " in policy file [" + + policyFile + + "]" + + "\nAvailable codebases: \n " + + String.join("\n ", codebaseProperties.keySet().stream().sorted().toList()) + ); + } + return policy; + } finally { + System.setProperties(originalProps); + } + } catch (NoSuchAlgorithmException | URISyntaxException e) { + throw new IllegalArgumentException("unable to parse policy file `" + policyFile + "`", e); + } + } + + // package private for tests + static PluginPolicyInfo readPolicyInfo(Path pluginRoot) throws IOException { + Path policyFile = pluginRoot.resolve(PluginDescriptor.ES_PLUGIN_POLICY); + if (Files.exists(policyFile) == false) { + return null; + } + + // first get a list of URLs for the plugins' jars: + // we resolve symlinks so map is keyed on the normalize codebase name + Set jars = new LinkedHashSet<>(); // order is already lost, but some filesystems have it + try (DirectoryStream jarStream = Files.newDirectoryStream(pluginRoot, "*.jar")) { + for (Path jar : jarStream) { + URL url = jar.toRealPath().toUri().toURL(); + if (jars.add(url) == false) { + throw new IllegalStateException("duplicate module/plugin: " + url); + } + } + } + // also add spi jars + // TODO: move this to a shared function, or fix plugin layout to have jar files in lib directory + Path 
spiDir = pluginRoot.resolve("spi"); + if (Files.exists(spiDir)) { + try (DirectoryStream jarStream = Files.newDirectoryStream(spiDir, "*.jar")) { + for (Path jar : jarStream) { + URL url = jar.toRealPath().toUri().toURL(); + if (jars.add(url) == false) { + throw new IllegalStateException("duplicate module/plugin: " + url); + } + } + } + } + + // parse the plugin's policy file into a set of permissions + Policy policy = readPolicy(policyFile.toUri().toURL(), getCodebaseJarMap(jars)); + + return new PluginPolicyInfo(policyFile, jars, policy); + } + + private static void validatePolicyPermissionsForJar( + String type, + Path file, + URL jar, + Policy policy, + PermissionMatcher allowedPermissions, + Path tmpDir + ) throws IOException { + Set jarPermissions = getPolicyPermissions(jar, policy, tmpDir); + for (Permission permission : jarPermissions) { + if (allowedPermissions.test(permission) == false) { + String scope = jar == null ? " in global grant" : " for jar " + jar; + throw new IllegalArgumentException(type + " policy [" + file + "] contains illegal permission " + permission + scope); + } + } + } + + private static void validatePolicyPermissions(String type, PluginPolicyInfo info, PermissionMatcher allowedPermissions, Path tmpDir) + throws IOException { + if (info == null) { + return; + } + validatePolicyPermissionsForJar(type, info.file(), null, info.policy(), allowedPermissions, tmpDir); + for (URL jar : info.jars()) { + validatePolicyPermissionsForJar(type, info.file(), jar, info.policy(), allowedPermissions, tmpDir); + } + } + + /** + * Return info about the security policy for a plugin. + */ + public static PluginPolicyInfo getPluginPolicyInfo(Path pluginRoot, Path tmpDir) throws IOException { + PluginPolicyInfo info = readPolicyInfo(pluginRoot); + validatePolicyPermissions("plugin", info, ALLOWED_PLUGIN_PERMISSIONS, tmpDir); + return info; + } + + /** + * Return info about the security policy for a module. + */ + public static PluginPolicyInfo getModulePolicyInfo(Path moduleRoot, Path tmpDir) throws IOException { + PluginPolicyInfo info = readPolicyInfo(moduleRoot); + validatePolicyPermissions("module", info, ALLOWED_MODULE_PERMISSIONS, tmpDir); + return info; + } + + /** + * Return permissions for a policy that apply to a jar. + * + * @param url The url of a jar to find permissions for, or {@code null} for global permissions. + */ + public static Set getPolicyPermissions(URL url, Policy policy, Path tmpDir) throws IOException { + // create a zero byte file for "comparison" + // this is necessary because the default policy impl automatically grants two permissions: + // 1. permission to exitVM (which we ignore) + // 2. read permission to the code itself (e.g. 
jar file of the code) + + Path emptyPolicyFile = Files.createTempFile(tmpDir, "empty", "tmp"); + final Policy emptyPolicy; + try { + emptyPolicy = Policy.getInstance("JavaPolicy", new URIParameter(emptyPolicyFile.toUri())); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + IOUtils.rm(emptyPolicyFile); + + final ProtectionDomain protectionDomain; + if (url == null) { + // global, use PolicyUtil since it is part of core ES + protectionDomain = PolicyUtil.class.getProtectionDomain(); + } else { + // we may not have the url loaded, so create a fake protection domain + protectionDomain = new ProtectionDomain(new CodeSource(url, (Certificate[]) null), null); + } + + PermissionCollection permissions = policy.getPermissions(protectionDomain); + // this method is supported with the specific implementation we use, but just check for safety. + if (permissions == Policy.UNSUPPORTED_EMPTY_COLLECTION) { + throw new UnsupportedOperationException("JavaPolicy implementation does not support retrieving permissions"); + } + + Set actualPermissions = new HashSet<>(); + for (Permission permission : Collections.list(permissions.elements())) { + if (emptyPolicy.implies(protectionDomain, permission) == false) { + actualPermissions.add(permission); + } + } + + return actualPermissions; + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java b/server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java index 79e1654aae758..32f57d1cbd1db 100644 --- a/server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java @@ -391,7 +391,9 @@ private ReleasableBytesReference maybeSerializeClusterState( } assert clusterState.nodes().isLocalNodeElectedMaster(); - try (var bytesStream = transportService.newNetworkBytesStream()) { + final var bytesStream = transportService.newNetworkBytesStream(); + var success = false; + try { try ( var stream = new OutputStreamStreamOutput( CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.flushOnCloseStream(bytesStream)) @@ -402,16 +404,22 @@ private ReleasableBytesReference maybeSerializeClusterState( } catch (IOException e) { throw new ElasticsearchException("failed to serialize cluster state for publishing to node {}", e, discoveryNode); } + final var newBytes = new ReleasableBytesReference(bytesStream.bytes(), bytesStream); logger.trace( "serialized join validation cluster state version [{}] for transport version [{}] with size [{}]", clusterState.version(), version, - bytesStream.position() + newBytes.length() ); - var newBytes = bytesStream.moveToBytesReference(); final var previousBytes = statesByVersion.put(version, newBytes); assert previousBytes == null; + success = true; return newBytes; + } finally { + if (success == false) { + bytesStream.close(); + assert false; + } } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java b/server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java index 1d4ab1346ef36..af3fdc317c8a7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java @@ -262,7 +262,9 @@ public PublicationContext newPublicationContext(ClusterStatePublicationEvent clu } private ReleasableBytesReference 
serializeFullClusterState(ClusterState clusterState, DiscoveryNode node, TransportVersion version) { - try (RecyclerBytesStreamOutput bytesStream = transportService.newNetworkBytesStream()) { + final RecyclerBytesStreamOutput bytesStream = transportService.newNetworkBytesStream(); + boolean success = false; + try { final long uncompressedBytes; try ( StreamOutput stream = new PositionTrackingOutputStreamStreamOutput( @@ -276,15 +278,20 @@ private ReleasableBytesReference serializeFullClusterState(ClusterState clusterS } catch (IOException e) { throw new ElasticsearchException("failed to serialize cluster state for publishing to node {}", e, node); } - final int size = bytesStream.size(); - serializationStatsTracker.serializedFullState(uncompressedBytes, size); + final ReleasableBytesReference result = new ReleasableBytesReference(bytesStream.bytes(), bytesStream); + serializationStatsTracker.serializedFullState(uncompressedBytes, result.length()); logger.trace( "serialized full cluster state version [{}] using transport version [{}] with size [{}]", clusterState.version(), version, - size + result.length() ); - return bytesStream.moveToBytesReference(); + success = true; + return result; + } finally { + if (success == false) { + bytesStream.close(); + } } } @@ -295,7 +302,9 @@ private ReleasableBytesReference serializeDiffClusterState( TransportVersion version ) { final long clusterStateVersion = newState.version(); - try (RecyclerBytesStreamOutput bytesStream = transportService.newNetworkBytesStream()) { + final RecyclerBytesStreamOutput bytesStream = transportService.newNetworkBytesStream(); + boolean success = false; + try { final long uncompressedBytes; try ( StreamOutput stream = new PositionTrackingOutputStreamStreamOutput( @@ -313,15 +322,20 @@ private ReleasableBytesReference serializeDiffClusterState( } catch (IOException e) { throw new ElasticsearchException("failed to serialize cluster state diff for publishing to node {}", e, node); } - final int size = bytesStream.size(); - serializationStatsTracker.serializedDiff(uncompressedBytes, size); + final ReleasableBytesReference result = new ReleasableBytesReference(bytesStream.bytes(), bytesStream); + serializationStatsTracker.serializedDiff(uncompressedBytes, result.length()); logger.trace( "serialized cluster state diff for version [{}] using transport version [{}] with size [{}]", clusterStateVersion, version, - size + result.length() ); - return bytesStream.moveToBytesReference(); + success = true; + return result; + } finally { + if (success == false) { + bytesStream.close(); + } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index c22e799dc9506..495403e963e45 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -27,7 +27,6 @@ import java.util.Objects; import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG; -import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19; /** * Contains inference field data for fields. 
@@ -75,8 +74,7 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { this.searchInferenceId = this.inferenceId; } this.sourceFields = input.readStringArray(); - if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG) - || input.getTransportVersion().isPatchFrom(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) { + if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { this.chunkingSettings = input.readGenericMap(); } else { this.chunkingSettings = null; @@ -91,8 +89,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(searchInferenceId); } out.writeStringArray(sourceFields); - if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG) - || out.getTransportVersion().isPatchFrom(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) { + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { out.writeGenericMap(chunkingSettings); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java index 6709d0c5f89f5..80146029e0d9d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java @@ -28,7 +28,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -259,14 +258,22 @@ public ShardIterator activeInitializingShardsRankedIt( return new ShardIterator(shardId, ordered); } + private static Set<String> getAllNodeIds(final List<ShardRouting> shards) { + final Set<String> nodeIds = new HashSet<>(); + for (ShardRouting shard : shards) { + nodeIds.add(shard.currentNodeId()); + } + return nodeIds; + } + private static Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> getNodeStats( - List<ShardRouting> shardRoutings, + final Set<String> nodeIds, final ResponseCollectorService collector ) { - final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = new HashMap<>(); - for (ShardRouting shardRouting : shardRoutings) { - nodeStats.computeIfAbsent(shardRouting.currentNodeId(), collector::getNodeStatistics); + final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = Maps.newMapWithExpectedSize(nodeIds.size()); + for (String nodeId : nodeIds) { + nodeStats.put(nodeId, collector.getNodeStatistics(nodeId)); } return nodeStats; } @@ -335,28 +342,32 @@ private static List<ShardRouting> rankShardsAndUpdateStats( } // Retrieve which nodes we can potentially send the query to - final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = getNodeStats(shards, collector); + final Set<String> nodeIds = getAllNodeIds(shards); + final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = getNodeStats(nodeIds, collector); // Retrieve all the nodes the shards exist on + final Map<String, Double> nodeRanks = rankNodes(nodeStats, nodeSearchCounts); // sort all shards based on the shard rank ArrayList<ShardRouting> sortedShards = new ArrayList<>(shards); - sortedShards.sort(new NodeRankComparator(rankNodes(nodeStats, nodeSearchCounts))); + Collections.sort(sortedShards, new NodeRankComparator(nodeRanks)); // adjust the non-winner nodes' stats so they will get a chance to receive queries - ShardRouting minShard = sortedShards.get(0); - // If the winning shard is not started we are ranking initializing - // shards, don't bother to do adjustments - if (minShard.started()) { - String minNodeId = minShard.currentNodeId(); - Optional<ResponseCollectorService.ComputedNodeStats> maybeMinStats = nodeStats.get(minNodeId); - if (maybeMinStats.isPresent()) { - adjustStats(collector, nodeStats, minNodeId, maybeMinStats.get()); - // Increase the number of searches for the "winning" node by one. - // Note that this doesn't actually affect the "real" counts, instead - // it only affects the captured node search counts, which is - // captured once for each query in TransportSearchAction - nodeSearchCounts.compute(minNodeId, (id, conns) -> conns == null ? 1 : conns + 1); + if (sortedShards.size() > 1) { + ShardRouting minShard = sortedShards.get(0); + // If the winning shard is not started we are ranking initializing + // shards, don't bother to do adjustments + if (minShard.started()) { + String minNodeId = minShard.currentNodeId(); + Optional<ResponseCollectorService.ComputedNodeStats> maybeMinStats = nodeStats.get(minNodeId); + if (maybeMinStats.isPresent()) { + adjustStats(collector, nodeStats, minNodeId, maybeMinStats.get()); + // Increase the number of searches for the "winning" node by one. + // Note that this doesn't actually affect the "real" counts, instead + // it only affects the captured node search counts, which is + // captured once for each query in TransportSearchAction + nodeSearchCounts.compute(minNodeId, (id, conns) -> conns == null ? 1 : conns + 1); + } } }
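The serializer rewrites in JoinValidationService and PublicationTransportHandler above share one idiom: try-with-resources is replaced by an explicit success flag because, on the happy path, ownership of the pooled stream passes to the returned ReleasableBytesReference and the buffer must not be released twice. The shape of that idiom, reduced to a sketch with hypothetical PooledBuffer and Wrapped types:

import java.io.Closeable;

class OwnershipTransfer {
    // hypothetical stand-ins for RecyclerBytesStreamOutput and ReleasableBytesReference
    record PooledBuffer(byte[] pages) implements Closeable {
        @Override
        public void close() { /* return pages to the pool */ }
    }

    record Wrapped(PooledBuffer buffer) implements Closeable {
        @Override
        public void close() { buffer.close(); }
    }

    static Wrapped serialize() {
        PooledBuffer buffer = new PooledBuffer(new byte[16]);
        boolean success = false;
        try {
            // ... write into the buffer; any of this may throw ...
            Wrapped result = new Wrapped(buffer); // ownership moves to the wrapper here
            success = true;
            return result;
        } finally {
            if (success == false) {
                buffer.close(); // release only on the failure path; the caller closes the wrapper otherwise
            }
        }
    }
}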
diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java index 978e6c19566b1..e140a81bdfbb7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.routing; +import org.apache.lucene.util.CollectionUtil; import org.elasticsearch.cluster.ProjectState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.ProjectMetadata; @@ -18,6 +19,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; @@ -108,9 +110,9 @@ public List<ShardIterator> searchShards( @Nullable ResponseCollectorService collectorService, @Nullable Map<String, Long> nodeCounts ) { - final Set<IndexShardRoutingTable> shards = computeTargetedShards(projectState, concreteIndices, routing); + Set<IndexShardRoutingTable> shards = computeTargetedShards(projectState, concreteIndices, routing); DiscoveryNodes nodes = projectState.cluster().nodes(); - List<ShardIterator> res = new ArrayList<>(shards.size()); + Set<ShardIterator> set = Sets.newHashSetWithExpectedSize(shards.size()); for (IndexShardRoutingTable shard : shards) { ShardIterator iterator = preferenceActiveShardIterator( shard, @@ -121,10 +123,11 @@ public List<ShardIterator> searchShards( nodeCounts ); if (iterator != null) { - res.add(ShardIterator.allSearchableShards(iterator)); + set.add(ShardIterator.allSearchableShards(iterator)); } } - res.sort(ShardIterator::compareTo); + List<ShardIterator> res = new ArrayList<>(set); + CollectionUtil.timSort(res); return res; }
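searchShards above now collects the per-shard iterators into a pre-sized HashSet before sorting, so duplicates collapse before the sort rather than being sorted along with everything else. The collect-dedupe-sort shape with plain JDK collections (Sets.newHashSetWithExpectedSize and CollectionUtil.timSort are the ES/Lucene equivalents of the pre-sizing and the sort):

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class DedupeThenSort {
    static List<String> searchTargets(List<String> candidates) {
        // pre-size so the set never rehashes; this is what Sets.newHashSetWithExpectedSize does
        Set<String> unique = new HashSet<>((int) (candidates.size() / 0.75f) + 1);
        for (String candidate : candidates) {
            unique.add(candidate); // duplicates collapse here
        }
        List<String> result = new ArrayList<>(unique);
        result.sort(null); // natural ordering, as timSort does for Comparable elements
        return result;
    }
}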
diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java index 938aa1ddbc9fc..9a368483d46c0 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java @@ -18,6 +18,8 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Iterator; import java.util.List; @@ -55,11 +57,14 @@ public int bufferSizeInBytes() { public BlobContainer blobContainer(BlobPath path) { Path f = buildPath(path); if (readOnly == false) { - try { - Files.createDirectories(f); - } catch (IOException ex) { - throw new ElasticsearchException("failed to create blob container", ex); - } + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + try { + Files.createDirectories(f); + } catch (IOException ex) { + throw new ElasticsearchException("failed to create blob container", ex); + } + return null; + }); } return new FsBlobContainer(this, path, f); } diff --git a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java index ff8e68d462829..4343892428c9a 100644 --- a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java +++ b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java @@ -29,7 +29,7 @@ public final class ReleasableBytesReference implements RefCounted, Releasable, B private static final ReleasableBytesReference EMPTY = new ReleasableBytesReference(BytesArray.EMPTY, RefCounted.ALWAYS_REFERENCED); - private BytesReference delegate; + private final BytesReference delegate; private final RefCounted refCounted; public static ReleasableBytesReference empty() { @@ -63,24 +63,16 @@ public boolean tryIncRef() { @Override public boolean decRef() { - boolean res = refCounted.decRef(); - if (res) { - delegate = null; - } - return res; + return refCounted.decRef(); } @Override public boolean hasReferences() { - boolean hasRef = refCounted.hasReferences(); - // delegate is nulled out when the ref-count reaches zero but only via a plain store, and also we could be racing with a concurrent - // decRef so need to check #refCounted again in case we run into a non-null delegate but saw a reference before - assert delegate != null || refCounted.hasReferences() == false; - return hasRef; + return refCounted.hasReferences(); } public ReleasableBytesReference retain() { - refCounted.mustIncRef(); + refCounted.incRef(); return this; } @@ -90,7 +82,6 @@ public ReleasableBytesReference retain() { * retaining unnecessary buffers.
*/ public ReleasableBytesReference retainedSlice(int from, int length) { - assert hasReferences(); if (from == 0 && length() == length) { return retain(); } @@ -145,7 +136,6 @@ public int indexOf(byte marker, int from) { @Override public int length() { - assert hasReferences(); return delegate.length(); } @@ -164,7 +154,6 @@ public ReleasableBytesReference slice(int from, int length) { @Override public long ramBytesUsed() { - assert hasReferences(); return delegate.ramBytesUsed(); } @@ -244,7 +233,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean isFragment() { - assert hasReferences(); return delegate.isFragment(); } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java index 1dabad2d62e4c..ddf97a5151690 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; -import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -38,7 +37,7 @@ public class RecyclerBytesStreamOutput extends BytesStream implements Releasable static final VarHandle VH_BE_LONG = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN); static final VarHandle VH_LE_LONG = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN); - private ArrayList<Recycler.V<BytesRef>> pages = new ArrayList<>(); + private final ArrayList<Recycler.V<BytesRef>> pages = new ArrayList<>(); private final Recycler<BytesRef> recycler; private final int pageSize; private int pageIndex = -1; @@ -238,26 +237,13 @@ public void skip(int length) { @Override public void close() { - var pages = this.pages; - if (pages != null) { - this.pages = null; + try { Releasables.close(pages); + } finally { + pages.clear(); } } - /** - * Move the contents written to this stream to a {@link ReleasableBytesReference}. Closing this instance becomes a noop after - * this method returns successfully and its buffers need to be released by releasing the returned bytes reference. - * - * @return a {@link ReleasableBytesReference} that must be released once no longer needed - */ - public ReleasableBytesReference moveToBytesReference() { - var bytes = bytes(); - var pages = this.pages; - this.pages = null; - return new ReleasableBytesReference(bytes, () -> Releasables.close(pages)); - } - /** * Returns the current size of the buffer. *
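With moveToBytesReference removed, close() above keeps the now-final pages list and clears it after releasing, so a second close() walks an empty list instead of double-releasing pooled pages. A minimal sketch of that idempotent-close shape, with a hypothetical Page type standing in for Recycler.V<BytesRef>:

import java.util.ArrayList;
import java.util.List;

class PooledOutput implements AutoCloseable {
    // hypothetical stand-in for Recycler.V<BytesRef>
    interface Page extends AutoCloseable {
        @Override
        void close();
    }

    private final List<Page> pages = new ArrayList<>();

    @Override
    public void close() {
        try {
            for (Page page : pages) {
                page.close(); // return each page to the recycler
            }
        } finally {
            pages.clear(); // a second close() now iterates an empty list: no double-release
        }
    }
}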
diff --git a/server/src/main/java/org/elasticsearch/common/logging/DeprecationLogger.java index 8f778f8c05209..ef5b318a8b426 100644 --- a/server/src/main/java/org/elasticsearch/common/logging/DeprecationLogger.java +++ b/server/src/main/java/org/elasticsearch/common/logging/DeprecationLogger.java @@ -15,6 +15,8 @@ import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Collections; import java.util.List; @@ -117,11 +119,18 @@ private DeprecationLogger logDeprecation(Level level, DeprecationCategory catego String opaqueId = HeaderWarning.getXOpaqueId(); String productOrigin = HeaderWarning.getProductOrigin(); ESLogMessage deprecationMessage = DeprecatedMessage.of(category, key, opaqueId, productOrigin, msg, params); - logger.log(level, deprecationMessage); + doPrivilegedLog(level, deprecationMessage); } return this; } + private void doPrivilegedLog(Level level, ESLogMessage deprecationMessage) { + AccessController.doPrivileged((PrivilegedAction<Void>) () -> { + logger.log(level, deprecationMessage); + return null; + }); + } + /** * Used for handling previous version RestApiCompatible logic. * Logs a message at the {@link DeprecationLogger#CRITICAL} level diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index a4b239c10ba6a..a84bc7d00578c 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -36,7 +36,6 @@ import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.index.store.FsDirectoryFactory; import org.elasticsearch.index.store.Store; @@ -158,7 +157,6 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.INDEX_TRANSLOG_RETENTION_AGE_SETTING, IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING, IndexSettings.INDEX_SEARCH_IDLE_AFTER, - DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY, IndexSettings.IGNORE_ABOVE_SETTING, FieldMapper.IGNORE_MALFORMED_SETTING, diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java index 4fd5225a29167..28849a825bf25 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java @@ -17,6 +17,8 @@ import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.node.Node; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.List; import java.util.Optional; import java.util.concurrent.AbstractExecutorService; @@ -391,9 +393,11 @@ static class EsThreadFactory implements ThreadFactory { @Override public Thread newThread(Runnable r) { - Thread t = new EsThread(group, r, namePrefix + "[T#" + threadNumber.getAndIncrement() + "]", 0, isSystem); - t.setDaemon(true); - return t;
+ return AccessController.doPrivileged((PrivilegedAction<Thread>) () -> { + Thread t = new EsThread(group, r, namePrefix + "[T#" + threadNumber.getAndIncrement() + "]", 0, isSystem); + t.setDaemon(true); + return t; + }); } }
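newThread above is wrapped in doPrivileged so that thread creation succeeds even when code further up the stack (a plugin, say) lacks the relevant thread permissions; only the factory's own protection domain is consulted. The same pattern as a standalone factory, using a plain Thread where ES constructs its EsThread:

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

@SuppressWarnings("removal")
class PrivilegedDaemonThreadFactory implements ThreadFactory {
    private final AtomicInteger threadNumber = new AtomicInteger(1);
    private final String namePrefix;

    PrivilegedDaemonThreadFactory(String namePrefix) {
        this.namePrefix = namePrefix;
    }

    @Override
    public Thread newThread(Runnable r) {
        // run the construction with this class's privileges, not the caller's
        return AccessController.doPrivileged((PrivilegedAction<Thread>) () -> {
            Thread t = new Thread(r, namePrefix + "[T#" + threadNumber.getAndIncrement() + "]");
            t.setDaemon(true);
            return t;
        });
    }
}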
diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java index 2b1a5ff6e9c0c..71f57bcc16754 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutor.java @@ -10,17 +10,9 @@ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.common.ExponentiallyWeightedMovingAverage; -import org.elasticsearch.common.metrics.ExponentialBucketHistogram; import org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.telemetry.metric.DoubleWithAttributes; -import org.elasticsearch.telemetry.metric.Instrument; -import org.elasticsearch.telemetry.metric.LongWithAttributes; -import org.elasticsearch.telemetry.metric.MeterRegistry; -import org.elasticsearch.threadpool.ThreadPool; - -import java.util.Arrays; -import java.util.List; + import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; @@ -30,17 +22,11 @@ import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; -import static org.elasticsearch.threadpool.ThreadPool.THREAD_POOL_METRIC_NAME_QUEUE_TIME; -import static org.elasticsearch.threadpool.ThreadPool.THREAD_POOL_METRIC_NAME_UTILIZATION; - /** * An extension to thread pool executor, which tracks statistics for the task execution time. */ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThreadPoolExecutor { - public static final int QUEUE_LATENCY_HISTOGRAM_BUCKETS = 18; - private static final int[] LATENCY_PERCENTILES_TO_REPORT = { 50, 90, 99 }; - private final Function<Runnable, WrappedRunnable> runnableWrapper; private final ExponentiallyWeightedMovingAverage executionEWMA; private final LongAdder totalExecutionTime = new LongAdder(); @@ -49,7 +35,6 @@ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThrea private final Map<Runnable, Long> ongoingTasks = new ConcurrentHashMap<>(); private volatile long lastPollTime = System.nanoTime(); private volatile long lastTotalExecutionTime = 0; - private final ExponentialBucketHistogram queueLatencyMillisHistogram = new ExponentialBucketHistogram(QUEUE_LATENCY_HISTOGRAM_BUCKETS); TaskExecutionTimeTrackingEsThreadPoolExecutor( String name, @@ -70,36 +55,6 @@ public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThrea this.trackOngoingTasks = trackingConfig.trackOngoingTasks(); } - public List<Instrument> setupMetrics(MeterRegistry meterRegistry, String threadPoolName) { - return List.of( - meterRegistry.registerLongsGauge( - ThreadPool.THREAD_POOL_METRIC_PREFIX + threadPoolName + THREAD_POOL_METRIC_NAME_QUEUE_TIME, - "Time tasks spent in the queue for the " + threadPoolName + " thread pool", - "milliseconds", - () -> { - long[] snapshot = queueLatencyMillisHistogram.getSnapshot(); - int[] bucketUpperBounds = queueLatencyMillisHistogram.calculateBucketUpperBounds(); - List<LongWithAttributes> metricValues = Arrays.stream(LATENCY_PERCENTILES_TO_REPORT) - .mapToObj( - percentile -> new LongWithAttributes( - queueLatencyMillisHistogram.getPercentile(percentile / 100f, snapshot, bucketUpperBounds), - Map.of("percentile", String.valueOf(percentile)) - ) - ) - .toList(); - queueLatencyMillisHistogram.clear(); - return metricValues; - } - ), - meterRegistry.registerDoubleGauge( - ThreadPool.THREAD_POOL_METRIC_PREFIX + threadPoolName + THREAD_POOL_METRIC_NAME_UTILIZATION, - "fraction of maximum thread time utilized for " + threadPoolName, - "fraction", - () -> new DoubleWithAttributes(pollUtilization(), Map.of()) - ) - ); - } - @Override protected Runnable wrapRunnable(Runnable command) { return super.wrapRunnable(this.runnableWrapper.apply(command)); @@ -161,12 +116,6 @@ protected void beforeExecute(Thread t, Runnable r) { if (trackOngoingTasks) { ongoingTasks.put(r, System.nanoTime()); } - assert super.unwrap(r) instanceof TimedRunnable : "expected only TimedRunnables in queue"; - final TimedRunnable timedRunnable = (TimedRunnable) super.unwrap(r); - timedRunnable.beforeExecute(); - final long taskQueueLatency = timedRunnable.getQueueTimeNanos(); - assert taskQueueLatency >= 0; - queueLatencyMillisHistogram.addObservation(TimeUnit.NANOSECONDS.toMillis(taskQueueLatency)); } @Override diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/TimedRunnable.java index de89ad0d8ea3f..63fbee7999324 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/TimedRunnable.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/TimedRunnable.java @@ -18,7 +18,6 @@ class TimedRunnable extends AbstractRunnable implements WrappedRunnable { private final Runnable original; private final long creationTimeNanos; - private long beforeExecuteTime = -1; private long startTimeNanos; private long finishTimeNanos = -1; private boolean failedOrRejected = false; @@ -59,19 +58,6 @@
public boolean isForceExecution() { return original instanceof AbstractRunnable && ((AbstractRunnable) original).isForceExecution(); } - /** - * Returns the time in nanoseconds between the creation time and the execution time - * - * @return The time in nanoseconds or -1 if the task was never de-queued - */ - long getQueueTimeNanos() { - if (beforeExecuteTime == -1) { - assert false : "beforeExecute must be called before getQueueTimeNanos"; - return -1; - } - return beforeExecuteTime - creationTimeNanos; - } - /** * Return the time this task spent being run. * If the task is still running or has not yet been run, returns -1. @@ -84,13 +70,6 @@ long getTotalExecutionNanos() { return Math.max(finishTimeNanos - startTimeNanos, 1); } - /** - * Called when the task has reached the front of the queue and is about to be executed - */ - public void beforeExecute() { - beforeExecuteTime = System.nanoTime(); - } - /** * If the task was failed or rejected, return true. * Otherwise, false. diff --git a/server/src/main/java/org/elasticsearch/env/NodeEnvironment.java b/server/src/main/java/org/elasticsearch/env/NodeEnvironment.java index ddc36cc81dda1..febde6b6a69ac 100644 --- a/server/src/main/java/org/elasticsearch/env/NodeEnvironment.java +++ b/server/src/main/java/org/elasticsearch/env/NodeEnvironment.java @@ -1501,15 +1501,10 @@ private static void tryWriteTempFile(Path path) throws IOException { /** * Get a useful version string to direct a user's downgrade operation - *
<p> - Assuming that the index was compatible with {@code previousNodeVersion}, - the user should downgrade to that {@code previousNodeVersion}, - unless it's prior to the minimum compatible version, - in which case the user should downgrade to that instead. - (If the index version is so old that the minimum compatible version is incompatible with the index, - then the cluster was already borked before the node upgrade began, - and we can't probably help them without more info than we have here.) * + * <p>
If a user is trying to install current major N but has incompatible indices, the user should + * downgrade to last minor of the previous major (N-1).last. We return (N-1).last, unless the user is trying to upgrade from + * a (N-1).last.x release, in which case we return the last installed version. * @return Version to downgrade to */ // visible for testing diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 7a3b2ad938d1b..e7ff0f6d1e137 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -29,7 +29,6 @@ import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.SourceFieldMapper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.ingest.IngestService; @@ -897,7 +896,6 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { private volatile int maxTokenCount; private volatile int maxNgramDiff; private volatile int maxShingleDiff; - private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic; private volatile TimeValue searchIdleAfter; private volatile int maxAnalyzedOffset; private volatile boolean weightMatchesEnabled; @@ -1093,7 +1091,6 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti logsdbAddHostNameField = scopedSettings.get(LOGSDB_ADD_HOST_NAME_FIELD); skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); - hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC); indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING); recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings); recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false @@ -1206,7 +1203,6 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti this::setSkipIgnoredSourceWrite ); scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead); - scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic); } private void setSearchIdleAfter(TimeValue searchIdleAfter) { @@ -1825,16 +1821,4 @@ public TimestampBounds getTimestampBounds() { public IndexRouting getIndexRouting() { return indexRouting; } - - /** - * The heuristic to utilize when executing filtered search on vectors indexed - * in HNSW format. 
- */ - public DenseVectorFieldMapper.FilterHeuristic getHnswFilterHeuristic() { - return this.hnswFilterHeuristic; - } - - private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuristic) { - this.hnswFilterHeuristic = heuristic; - } } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 74eb6721dda81..53d2ad0d19707 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -138,9 +138,6 @@ private static Version parseUnchecked(String version) { public static final IndexVersion USE_SYNTHETIC_SOURCE_FOR_RECOVERY_BY_DEFAULT_BACKPORT = def(8_526_0_00, parseUnchecked("9.12.1")); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_BACKPORT_8_X = def(8_527_0_00, Version.LUCENE_9_12_1); public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X = def(8_528_0_00, Version.LUCENE_9_12_1); - public static final IndexVersion RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS_BACKPORT_8_X = def(8_529_0_00, Version.LUCENE_9_12_1); - public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ_BACKPORT_8_X = def(8_530_0_00, Version.LUCENE_9_12_1); - public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X = def(8_531_0_00, Version.LUCENE_9_12_1); public static final IndexVersion UPGRADE_TO_LUCENE_10_0_0 = def(9_000_0_00, Version.LUCENE_10_0_0); public static final IndexVersion LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT = def(9_001_0_00, Version.LUCENE_10_0_0); public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID = def(9_002_0_00, Version.LUCENE_10_0_0); @@ -167,7 +164,6 @@ private static Version parseUnchecked(String version) { public static final IndexVersion UPGRADE_TO_LUCENE_10_2_1 = def(9_023_00_0, Version.LUCENE_10_2_1); public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ = def(9_024_0_00, Version.LUCENE_10_2_1); public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ = def(9_025_0_00, Version.LUCENE_10_2_1); - public static final IndexVersion DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC = def(9_026_0_00, Version.LUCENE_10_2_1); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/AssertingKnnVectorsReaderReflect.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/AssertingKnnVectorsReaderReflect.java index b22fa88fb49b4..bf47564c11b3a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/AssertingKnnVectorsReaderReflect.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/AssertingKnnVectorsReaderReflect.java @@ -14,6 +14,8 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.security.AccessController; +import java.security.PrivilegedAction; /** * Reflective access to unwrap non-accessible delegate in AssertingKnnVectorsReader. 
@@ -50,13 +52,25 @@ private static MethodHandle getDelegateFieldHandle() { if (cls == null) { return MethodHandles.throwException(KnnVectorsReader.class, AssertionError.class); } - var lookup = MethodHandles.privateLookupIn(cls, MethodHandles.lookup()); + var lookup = privilegedPrivateLookupIn(cls, MethodHandles.lookup()); return lookup.findGetter(cls, "delegate", KnnVectorsReader.class); } catch (ReflectiveOperationException e) { throw new AssertionError(e); } } + @SuppressWarnings("removal") + static MethodHandles.Lookup privilegedPrivateLookupIn(Class<?> cls, MethodHandles.Lookup lookup) throws IllegalAccessException { + PrivilegedAction<MethodHandles.Lookup> pa = () -> { + try { + return MethodHandles.privateLookupIn(cls, lookup); + } catch (IllegalAccessException e) { + throw new AssertionError("should not happen, check opens", e); + } + }; + return AccessController.doPrivileged(pa); + } + static void handleThrowable(Throwable t) { if (t instanceof Error error) { throw error;
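Both reflection utilities in this area gain the same helper: MethodHandles.privateLookupIn needs a security-manager permission, which doPrivileged supplies, and an opens directive from the target module, which it cannot, hence the AssertionError hint. The helper extracted on its own:

import java.lang.invoke.MethodHandles;
import java.security.AccessController;
import java.security.PrivilegedAction;

@SuppressWarnings("removal")
class Lookups {
    static MethodHandles.Lookup privilegedPrivateLookupIn(Class<?> cls, MethodHandles.Lookup lookup) {
        PrivilegedAction<MethodHandles.Lookup> pa = () -> {
            try {
                return MethodHandles.privateLookupIn(cls, lookup);
            } catch (IllegalAccessException e) {
                // doPrivileged cannot grant module access; the target package must be opened to us
                throw new AssertionError("should not happen, check opens", e);
            }
        };
        return AccessController.doPrivileged(pa);
    }
}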
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/OffHeapReflectionUtils.java index 49950bb4df4e9..599a205508385 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/OffHeapReflectionUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/reflect/OffHeapReflectionUtils.java @@ -26,6 +26,8 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Map; import static java.lang.invoke.MethodType.methodType; @@ -89,62 +91,62 @@ private OffHeapReflectionUtils() {} try { // Lucene99ScalarQuantizedVectorsReader var cls = Class.forName("org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader$FieldEntry"); - var lookup = MethodHandles.privateLookupIn(L99_SQ_VR_CLS, MethodHandles.lookup()); + var lookup = privilegedPrivateLookupIn(L99_SQ_VR_CLS, MethodHandles.lookup()); var mt = methodType(cls, String.class); GET_FIELD_ENTRY_HNDL_SQ = lookup.findVirtual(L99_SQ_VR_CLS, "getFieldEntry", mt); GET_VECTOR_DATA_LENGTH_HANDLE_SQ = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); RAW_VECTORS_READER_HNDL_SQ = lookup.findVarHandle(L99_SQ_VR_CLS, "rawVectorsReader", FlatVectorsReader.class); // Lucene99FlatVectorsReader cls = Class.forName("org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L99_FLT_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L99_FLT_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class, VectorEncoding.class); GET_FIELD_ENTRY_HANDLE_L99FLT = lookup.findVirtual(L99_FLT_VR_CLS, "getFieldEntry", mt); VECTOR_DATA_LENGTH_HANDLE_L99FLT = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // DirectIOLucene99FlatVectorsReader cls = Class.forName("org.elasticsearch.index.codec.vectors.es818.DirectIOLucene99FlatVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(DIOL99_FLT_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(DIOL99_FLT_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class, VectorEncoding.class); GET_FIELD_ENTRY_HANDLE_DIOL99FLT = lookup.findVirtual(DIOL99_FLT_VR_CLS, "getFieldEntry", mt); VECTOR_DATA_LENGTH_HANDLE_DIOL99FLT = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // Lucene99HnswVectorsReader cls = Class.forName("org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L99_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L99_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class, VectorEncoding.class); GET_FIELD_ENTRY_HANDLE_L99HNSW = lookup.findVirtual(L99_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L99HNSW = lookup.findVirtual(cls, "vectorIndexLength", methodType(long.class)); - lookup = MethodHandles.privateLookupIn(L99_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L99_HNSW_VR_CLS, MethodHandles.lookup()); FLAT_VECTORS_READER_HNDL_L99HNSW = lookup.findVarHandle(L99_HNSW_VR_CLS, "flatVectorsReader", FlatVectorsReader.class); // Lucene90HnswVectorsReader cls = Class.forName("org.apache.lucene.backward_codecs.lucene90.Lucene90HnswVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L90_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L90_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class); GET_FIELD_ENTRY_HANDLE_L90HNSW = lookup.findVirtual(L90_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L90HNSW = lookup.findVirtual(cls, "indexDataLength", methodType(long.class)); GET_VECTOR_DATA_LENGTH_HANDLE_L90HNSW = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // Lucene91HnswVectorsReader cls = Class.forName("org.apache.lucene.backward_codecs.lucene91.Lucene91HnswVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L91_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L91_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class); GET_FIELD_ENTRY_HANDLE_L91HNSW = lookup.findVirtual(L91_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L91HNSW = lookup.findVirtual(cls, "vectorIndexLength", methodType(long.class)); GET_VECTOR_DATA_LENGTH_HANDLE_L91HNSW = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // Lucene92HnswVectorsReader cls = Class.forName("org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L92_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L92_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class); GET_FIELD_ENTRY_HANDLE_L92HNSW = lookup.findVirtual(L92_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L92HNSW = lookup.findVirtual(cls, "vectorIndexLength", methodType(long.class)); GET_VECTOR_DATA_LENGTH_HANDLE_L92HNSW = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // Lucene94HnswVectorsReader cls = Class.forName("org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsReader$FieldEntry"); - lookup = MethodHandles.privateLookupIn(L94_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L94_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class, VectorEncoding.class); GET_FIELD_ENTRY_HANDLE_L94HNSW = lookup.findVirtual(L94_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L94HNSW = lookup.findVirtual(cls, "vectorIndexLength", methodType(long.class)); GET_VECTOR_DATA_LENGTH_HANDLE_L94HNSW = lookup.findVirtual(cls, "vectorDataLength", methodType(long.class)); // Lucene95HnswVectorsReader cls = Class.forName("org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsReader$FieldEntry");
- lookup = MethodHandles.privateLookupIn(L95_HNSW_VR_CLS, MethodHandles.lookup()); + lookup = privilegedPrivateLookupIn(L95_HNSW_VR_CLS, MethodHandles.lookup()); mt = methodType(cls, String.class, VectorEncoding.class); GET_FIELD_ENTRY_HANDLE_L95HNSW = lookup.findVirtual(L95_HNSW_VR_CLS, "getFieldEntry", mt); GET_VECTOR_INDEX_LENGTH_HANDLE_L95HNSW = lookup.findVirtual(cls, "vectorIndexLength", methodType(long.class)); @@ -276,6 +278,18 @@ static Map<String, Long> getOffHeapByteSizeL95HNSW(Lucene95HnswVectorsReader rea throw new AssertionError("should not reach here"); } + @SuppressWarnings("removal") + private static MethodHandles.Lookup privilegedPrivateLookupIn(Class<?> cls, MethodHandles.Lookup lookup) { + PrivilegedAction<MethodHandles.Lookup> pa = () -> { + try { + return MethodHandles.privateLookupIn(cls, lookup); + } catch (IllegalAccessException e) { + throw new AssertionError("should not happen, check opens", e); + } + }; + return AccessController.doPrivileged(pa); + } + private static void handleThrowable(Throwable t) { if (t instanceof Error error) { throw error; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java index 08cef586e1438..d77817915f387 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java @@ -1087,13 +1087,6 @@ public Builder builder(BlockFactory factory, int expectedCount) { * using whatever */ private BlockSourceReader.LeafIteratorLookup blockReaderDisiLookup(BlockLoaderContext blContext) { - if (isSyntheticSource && syntheticSourceDelegate != null) { - // Since we are using synthetic source and a delegate, we can't use this field - // to determine if the delegate has values in the document (f.e. handling of `null` is different - // between text and keyword). - return BlockSourceReader.lookupMatchingAll(); - } - if (isIndexed()) { if (getTextSearchInfo().hasNorms()) { return BlockSourceReader.lookupFromNorms(name()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 95caaa6ccf316..532cdfc08666c 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -33,12 +33,10 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.ParsingException; -import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexVersion; @@ -95,9 +93,9 @@ import java.util.function.Supplier; import java.util.stream.Stream; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_INDEX_VERSION_CREATED; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.index.IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; /** * A {@link FieldMapper} for indexing a dense vector of floats.
@@ -111,74 +109,20 @@ public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; } - /** - * The heuristic to utilize when executing a filtered search against vectors indexed in an HNSW graph. - */ - public enum FilterHeuristic { - /** - * This heuristic searches the entire graph, doing vector comparisons in all immediate neighbors - * but only collects vectors that match the filtering criteria. - */ - FANOUT { - static final KnnSearchStrategy FANOUT_STRATEGY = new KnnSearchStrategy.Hnsw(0); - - @Override - public KnnSearchStrategy getKnnSearchStrategy() { - return FANOUT_STRATEGY; - } - }, - /** - * This heuristic will only compare vectors that match the filtering criteria. - */ - ACORN { - static final KnnSearchStrategy ACORN_STRATEGY = new KnnSearchStrategy.Hnsw(60); - - @Override - public KnnSearchStrategy getKnnSearchStrategy() { - return ACORN_STRATEGY; - } - }; - - public abstract KnnSearchStrategy getKnnSearchStrategy(); - } - - public static final Setting HNSW_FILTER_HEURISTIC = Setting.enumSetting(FilterHeuristic.class, s -> { - IndexVersion version = SETTING_INDEX_VERSION_CREATED.get(s); - if (version.onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC)) { - return FilterHeuristic.ACORN.toString(); - } - return FilterHeuristic.FANOUT.toString(); - }, - "index.dense_vector.hnsw_filter_heuristic", - fh -> {}, - Setting.Property.IndexScope, - Setting.Property.ServerlessPublic, - Setting.Property.Dynamic - ); - private static boolean hasRescoreIndexVersion(IndexVersion version) { return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) || version.between(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0); } - private static boolean allowsZeroRescore(IndexVersion version) { - return version.onOrAfter(IndexVersions.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS) - || version.between( - IndexVersions.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS_BACKPORT_8_X, - IndexVersions.UPGRADE_TO_LUCENE_10_0_0 - ); - } - - private static boolean defaultOversampleForBBQ(IndexVersion version) { - return version.onOrAfter(IndexVersions.DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ) - || version.between(IndexVersions.DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0); - } - public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0; public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION; public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE; - public static final IndexVersion DEFAULT_TO_INT8 = IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; + public static final IndexVersion DEFAULT_TO_INT8 = DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersions.V_8_9_0; + public static final IndexVersion ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS = IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS; + public static final IndexVersion RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS = + IndexVersions.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS; + public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ = IndexVersions.DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ; public static final NodeFeature RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING = new NodeFeature("mapper.dense_vector.rescore_vector"); public static final NodeFeature RESCORE_ZERO_VECTOR_QUANTIZED_VECTOR_MAPPING = new 
NodeFeature( @@ -252,7 +196,7 @@ public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; final boolean indexedByDefault = indexVersionCreated.onOrAfter(INDEXED_BY_DEFAULT_INDEX_VERSION); - final boolean defaultInt8Hnsw = indexVersionCreated.onOrAfter(IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW); + final boolean defaultInt8Hnsw = indexVersionCreated.onOrAfter(DEFAULT_DENSE_VECTOR_TO_INT8_HNSW); this.indexed = Parameter.indexParam(m -> toType(m).fieldType().indexed, indexedByDefault); if (indexedByDefault) { // Only serialize on newer index versions to prevent breaking existing indices when upgrading @@ -1548,7 +1492,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); - if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) { + if (rescoreVector == null && indexVersion.onOrAfter(DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ)) { rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE); } } @@ -1572,7 +1516,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion); - if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) { + if (rescoreVector == null && indexVersion.onOrAfter(DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ)) { rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE); } } @@ -2106,7 +2050,7 @@ static RescoreVector fromIndexOptions(Map indexOptionsMap, IndexVersi throw new IllegalArgumentException("Invalid rescore_vector value. 
Missing required field " + OVERSAMPLE); } float oversampleValue = (float) XContentMapValues.nodeDoubleValue(oversampleNode); - if (oversampleValue == 0 && allowsZeroRescore(indexVersion) == false) { + if (oversampleValue == 0 && indexVersion.before(RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS)) { throw new IllegalArgumentException("oversample must be greater than 1"); } if (oversampleValue < 1 && oversampleValue != 0) { @@ -2258,25 +2202,15 @@ public Query createKnnQuery( Float oversample, Query filter, Float similarityThreshold, - BitSetProducer parentFilter, - DenseVectorFieldMapper.FilterHeuristic heuristic + BitSetProducer parentFilter ) { if (isIndexed() == false) { throw new IllegalArgumentException( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } - KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy(); return switch (getElementType()) { - case BYTE -> createKnnByteQuery( - queryVector.asByteVector(), - k, - numCands, - filter, - similarityThreshold, - parentFilter, - knnSearchStrategy - ); + case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); case FLOAT -> createKnnFloatQuery( queryVector.asFloatVector(), k, @@ -2284,18 +2218,9 @@ public Query createKnnQuery( oversample, filter, similarityThreshold, - parentFilter, - knnSearchStrategy - ); - case BIT -> createKnnBitQuery( - queryVector.asByteVector(), - k, - numCands, - filter, - similarityThreshold, - parentFilter, - knnSearchStrategy + parentFilter ); + case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } @@ -2313,13 +2238,12 @@ private Query createKnnBitQuery( int numCands, Query filter, Float similarityThreshold, - BitSetProducer parentFilter, - KnnSearchStrategy searchStrategy + BitSetProducer parentFilter ) { elementType.checkDimensions(dims, queryVector.length); Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2336,8 +2260,7 @@ private Query createKnnByteQuery( int numCands, Query filter, Float similarityThreshold, - BitSetProducer parentFilter, - KnnSearchStrategy searchStrategy + BitSetProducer parentFilter ) { elementType.checkDimensions(dims, queryVector.length); @@ -2346,8 +2269,8 @@ private Query createKnnByteQuery( elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + ? 
new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2365,8 +2288,7 @@ private Query createKnnFloatQuery( Float queryOversample, Query filter, Float similarityThreshold, - BitSetProducer parentFilter, - KnnSearchStrategy knnSearchStrategy + BitSetProducer parentFilter ) { elementType.checkDimensions(dims, queryVector.length); elementType.checkVectorBounds(queryVector); @@ -2400,16 +2322,8 @@ && isNotUnitVector(squaredMagnitude)) { numCands = Math.max(adjustedK, numCands); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery( - name(), - queryVector, - filter, - adjustedK, - numCands, - parentFilter, - knnSearchStrategy - ) - : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter); if (rescore) { knnQuery = new RescoreKnnVectorQuery( name(), diff --git a/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java b/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java index b884b2c850cc5..c64bd2aae6975 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java +++ b/server/src/main/java/org/elasticsearch/indices/IndexingMemoryController.java @@ -261,15 +261,15 @@ protected void deactivateThrottling(IndexShard shard) { @Override public void postIndex(ShardId shardId, Engine.Index index, Engine.IndexResult result) { - postOperation(index, result); + postOperation(shardId, index, result); } @Override public void postDelete(ShardId shardId, Engine.Delete delete, Engine.DeleteResult result) { - postOperation(delete, result); + postOperation(shardId, delete, result); } - private void postOperation(Engine.Operation operation, Engine.Result result) { + private void postOperation(ShardId shardId, Engine.Operation operation, Engine.Result result) { recordOperationBytes(operation, result); // Piggy back on indexing threads to write segments. We're not submitting a task to the index threadpool because we want memory to // be reclaimed rapidly. This has the downside of increasing the latency of _bulk requests though. Lucene does the same thing in diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 9ff23a73ff7ad..ef58b1e05e9a8 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1064,25 +1064,19 @@ public void afterIndexShardClosed(ShardId shardId, IndexShard indexShard, Settin /** * Deletes an index that is not assigned to this node. This method cleans up all disk folders relating to the index * but does not deal with in-memory structures. 
For those call {@link #removeIndex} - * - * @param reason the reason why this index should be deleted - * @param oldIndexMetadata the index metadata of the index that should be deleted - * @param currentProject the current project metadata which is used to verify that the index does not exist in the project - * anymore - can be null in case the whole project got deleted while there were still indices in it */ @Override - public void deleteUnassignedIndex(String reason, IndexMetadata oldIndexMetadata, @Nullable ProjectMetadata currentProject) { + public void deleteUnassignedIndex(String reason, IndexMetadata oldIndexMetadata, ClusterState clusterState) { if (nodeEnv.hasNodeFile()) { Index index = oldIndexMetadata.getIndex(); try { - if (currentProject != null && currentProject.hasIndex(index)) { - final IndexMetadata currentMetadata = currentProject.index(index); + if (clusterState.metadata().getProject().hasIndex(index)) { + final IndexMetadata currentMetadata = clusterState.metadata().getProject().index(index); throw new IllegalStateException( "Can't delete unassigned index store for [" + index.getName() - + "] - it's still part of project [" - + currentProject.id() - + "] with UUIDs [" + + "] - it's still part " + + "of the cluster state [" + currentMetadata.getIndexUUID() + "] [" + oldIndexMetadata.getIndexUUID() diff --git a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java index c0e897bc34319..3c84d7be8c6b4 100644 --- a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java @@ -400,12 +400,7 @@ private void deleteIndices(final ClusterChangedEvent event) { indexServiceClosedListener = SubscribableListener.nullSuccess(); final IndexMetadata metadata = project.get().index(index); indexSettings = new IndexSettings(metadata, settings); - final var projectId = project.get().id(); - indicesService.deleteUnassignedIndex( - "deleted index in project [" + projectId + "] was not assigned to local node", - metadata, - state.metadata().projects().get(projectId) - ); + indicesService.deleteUnassignedIndex("deleted index was not assigned to local node", metadata, state); } else { // The previous cluster state's metadata also does not contain the index, // which is what happens on node startup when an index was deleted while the @@ -1262,13 +1257,8 @@ U createIndex(IndexMetadata indexMetadata, List builtInIndex /** * Deletes an index that is not assigned to this node. This method cleans up all disk folders relating to the index * but does not deal with in-memory structures. For those call {@link #removeIndex} - * - * @param reason the reason why this index should be deleted - * @param oldIndexMetadata the index metadata of the index that should be deleted - * @param currentProject the current project metadata which is used to verify that the index does not exist in the project - * anymore - can be null in case the whole project got deleted while there were still indices in it */ - void deleteUnassignedIndex(String reason, IndexMetadata oldIndexMetadata, @Nullable ProjectMetadata currentProject); + void deleteUnassignedIndex(String reason, IndexMetadata metadata, ClusterState clusterState); /** * Removes the given index from this service and releases all associated resources. 
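The plugin-loading changes below reinstate doPrivileged wrapping around every ClassLoader construction: under a security manager the immediate caller needs createClassLoader, and wrapping asserts the core codebase's own privilege regardless of who is on the stack. A sketch of the shape, using a plain URLClassLoader rather than the ES loaders:

import java.net.URL;
import java.net.URLClassLoader;
import java.security.AccessController;
import java.security.PrivilegedAction;

@SuppressWarnings("removal")
class PrivilegedLoaders {
    static URLClassLoader create(ClassLoader parent, URL[] urls) {
        // construction runs with this class's permissions, not the (possibly unprivileged) caller's
        return AccessController.doPrivileged((PrivilegedAction<URLClassLoader>) () -> new URLClassLoader(urls, parent));
    }
}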
Persistent parts of the index diff --git a/server/src/main/java/org/elasticsearch/plugins/ExtendedPluginsClassLoader.java b/server/src/main/java/org/elasticsearch/plugins/ExtendedPluginsClassLoader.java index 7a78b9fbe7500..d9bf0d653bb62 100644 --- a/server/src/main/java/org/elasticsearch/plugins/ExtendedPluginsClassLoader.java +++ b/server/src/main/java/org/elasticsearch/plugins/ExtendedPluginsClassLoader.java @@ -9,6 +9,8 @@ package org.elasticsearch.plugins; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Collections; import java.util.List; @@ -41,6 +43,8 @@ protected Class findClass(String name) throws ClassNotFoundException { * Return a new classloader across the parent and extended loaders. */ public static ExtendedPluginsClassLoader create(ClassLoader parent, List extendedLoaders) { - return new ExtendedPluginsClassLoader(parent, extendedLoaders); + return AccessController.doPrivileged( + (PrivilegedAction) () -> new ExtendedPluginsClassLoader(parent, extendedLoaders) + ); } } diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginDescriptor.java b/server/src/main/java/org/elasticsearch/plugins/PluginDescriptor.java index 4e6af08a63ffa..ba40e9ad2bdd8 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginDescriptor.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginDescriptor.java @@ -46,6 +46,8 @@ public class PluginDescriptor implements Writeable, ToXContentObject { public static final String STABLE_DESCRIPTOR_FILENAME = "stable-plugin-descriptor.properties"; public static final String NAMED_COMPONENTS_FILENAME = "named_components.json"; + public static final String ES_PLUGIN_POLICY = "plugin-security.policy"; + private static final TransportVersion MODULE_NAME_SUPPORT = TransportVersions.V_8_3_0; private static final TransportVersion BOOTSTRAP_SUPPORT_REMOVED = TransportVersions.V_8_4_0; diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java b/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java index c30cc28d2f6fa..ac920d73fc666 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java @@ -27,6 +27,8 @@ import java.net.URL; import java.net.URLClassLoader; import java.nio.file.Path; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -421,7 +423,7 @@ static LayerAndLoader createModuleLayer( finder, Set.of(moduleName) ); - var controller = ModuleLayer.defineModulesWithOneLoader(configuration, parentLayersOrBoot(parentLayers), parentLoader); + var controller = privilegedDefineModulesWithOneLoader(configuration, parentLayersOrBoot(parentLayers), parentLoader); var pluginModule = controller.layer().findModule(moduleName).get(); ensureEntryPointAccessible(controller, pluginModule, className); // export/open upstream modules to this plugin module @@ -430,7 +432,7 @@ static LayerAndLoader createModuleLayer( addPluginExportsServices(qualifiedExports, controller); enableNativeAccess(moduleName, modulesWithNativeAccess, controller); logger.debug(() -> "Loading bundle: created module layer and loader for module " + moduleName); - return new LayerAndLoader(controller.layer(), controller.layer().findLoader(moduleName)); + return new LayerAndLoader(controller.layer(), privilegedFindLoader(controller.layer(), moduleName)); } /** Determines the module name of the SPI 
module, given its URL. */ @@ -488,6 +490,18 @@ private static void ensureEntryPointAccessible(Controller controller, Module plu } } + @SuppressWarnings("removal") + static Controller privilegedDefineModulesWithOneLoader(Configuration cf, List parentLayers, ClassLoader parentLoader) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> ModuleLayer.defineModulesWithOneLoader(cf, parentLayers, parentLoader) + ); + } + + @SuppressWarnings("removal") + static ClassLoader privilegedFindLoader(ModuleLayer layer, String name) { + return AccessController.doPrivileged((PrivilegedAction) () -> layer.findLoader(name)); + } + private static List parentLayersOrBoot(List parentLayers) { if (parentLayers == null || parentLayers.isEmpty()) { return List.of(ModuleLayer.boot()); diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginsService.java b/server/src/main/java/org/elasticsearch/plugins/PluginsService.java index 78a8650a5e920..6ef3cd17ba2e9 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginsService.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginsService.java @@ -32,6 +32,8 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.nio.file.Path; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -393,7 +395,7 @@ private void loadBundle(PluginLayer pluginLayer, Map loade // Set context class loader to plugin's class loader so that plugins // that have dependencies with their own SPI endpoints have a chance to load // and initialize them appropriately. - Thread.currentThread().setContextClassLoader(pluginLayer.pluginClassLoader()); + privilegedSetContextClassLoader(pluginLayer.pluginClassLoader()); Plugin plugin; if (pluginBundle.pluginDescriptor().isStable()) { @@ -426,7 +428,7 @@ We need to pass a name though so that we can show that a plugin was loaded (via } loadedPlugins.put(name, new LoadedPlugin(pluginBundle.plugin, plugin, pluginLayer.pluginClassLoader())); } finally { - Thread.currentThread().setContextClassLoader(cl); + privilegedSetContextClassLoader(cl); } } @@ -535,4 +537,12 @@ private static String signatureMessage(final Class clazz) { public final Stream filterPlugins(Class type) { return plugins().stream().filter(x -> type.isAssignableFrom(x.instance().getClass())).map(p -> ((T) p.instance())); } + + @SuppressWarnings("removal") + private static void privilegedSetContextClassLoader(ClassLoader loader) { + AccessController.doPrivileged((PrivilegedAction) () -> { + Thread.currentThread().setContextClassLoader(loader); + return null; + }); + } } diff --git a/server/src/main/java/org/elasticsearch/plugins/UberModuleClassLoader.java b/server/src/main/java/org/elasticsearch/plugins/UberModuleClassLoader.java index c47fac279f7e1..5e63f2e0b9aa9 100644 --- a/server/src/main/java/org/elasticsearch/plugins/UberModuleClassLoader.java +++ b/server/src/main/java/org/elasticsearch/plugins/UberModuleClassLoader.java @@ -23,8 +23,10 @@ import java.net.URL; import java.net.URLClassLoader; import java.nio.file.Path; +import java.security.AccessController; import java.security.CodeSigner; import java.security.CodeSource; +import java.security.PrivilegedAction; import java.security.SecureClassLoader; import java.util.Enumeration; import java.util.List; @@ -117,7 +119,7 @@ static UberModuleClassLoader getInstance( Set packageNames = 
finder.find(moduleName).map(ModuleReference::descriptor).map(ModuleDescriptor::packages).orElseThrow(); - return new UberModuleClassLoader( + PrivilegedAction pa = () -> new UberModuleClassLoader( parent, moduleName, jarUrls.toArray(new URL[0]), @@ -126,6 +128,7 @@ static UberModuleClassLoader getInstance( packageNames, modulesWithNativeAccess ); + return AccessController.doPrivileged(pa); } private static boolean isPackageInLayers(String packageName, ModuleLayer moduleLayer) { @@ -309,12 +312,17 @@ static Path urlToPathUnchecked(URL url) { } @Override + @SuppressWarnings("removal") public void close() throws Exception { - try { - internalLoader.close(); - } catch (IOException e) { - throw new IllegalStateException("Could not close internal URLClassLoader"); - } + PrivilegedAction pa = () -> { + try { + internalLoader.close(); + } catch (IOException e) { + throw new IllegalStateException("Could not close internal URLClassLoader"); + } + return null; + }; + AccessController.doPrivileged(pa); } // visible for testing diff --git a/server/src/main/java/org/elasticsearch/readiness/ReadinessService.java b/server/src/main/java/org/elasticsearch/readiness/ReadinessService.java index 165bcebb80a5d..1a169699d4131 100644 --- a/server/src/main/java/org/elasticsearch/readiness/ReadinessService.java +++ b/server/src/main/java/org/elasticsearch/readiness/ReadinessService.java @@ -32,6 +32,8 @@ import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Collection; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; @@ -120,20 +122,25 @@ ServerSocketChannel setupSocket() { int portNumber = PORT.get(settings); assert portNumber >= 0; - InetSocketAddress socketAddress; - try { - socketAddress = socketAddress(InetAddress.getByName("0"), portNumber); - } catch (IOException e) { - throw new IllegalArgumentException("Failed to resolve readiness host address", e); - } - - try { - serverChannel = socketChannelFactory.get(); + var socketAddress = AccessController.doPrivileged((PrivilegedAction) () -> { try { - serverChannel.bind(socketAddress); + return socketAddress(InetAddress.getByName("0"), portNumber); } catch (IOException e) { - throw new BindTransportException("Failed to bind to " + NetworkAddress.format(socketAddress), e); + throw new IllegalArgumentException("Failed to resolve readiness host address", e); } + }); + + try { + serverChannel = socketChannelFactory.get(); + + AccessController.doPrivileged((PrivilegedAction) () -> { + try { + serverChannel.bind(socketAddress); + } catch (IOException e) { + throw new BindTransportException("Failed to bind to " + NetworkAddress.format(socketAddress), e); + } + return null; + }); // First time bounding the socket, we notify any listeners if (boundSocket.get() == null) { @@ -173,11 +180,14 @@ synchronized void startListener() { assert serverChannel != null; try { while (serverChannel.isOpen()) { - try (SocketChannel channel = serverChannel.accept()) {} catch (IOException e) { - logger.debug("encountered exception while responding to readiness check request", e); - } catch (Exception other) { - logger.warn("encountered unknown exception while responding to readiness check request", other); - } + AccessController.doPrivileged((PrivilegedAction) () -> { + try (SocketChannel channel = serverChannel.accept()) {} catch (IOException e) { + logger.debug("encountered exception while responding to 
readiness check request", e); + } catch (Exception other) { + logger.warn("encountered unknown exception while responding to readiness check request", other); + } + return null; + }); } } finally { listenerThreadLatch.countDown(); diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java index 89e73673cacc0..694af7e1606cb 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java @@ -166,7 +166,10 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec if (serialization.hasNext() == false) { builder.close(); } - final var result = chunkStream.moveToBytesReference(); + final var result = new ReleasableBytesReference( + chunkStream.bytes(), + () -> Releasables.closeExpectNoException(chunkStream) + ); target = null; return result; } catch (Exception e) { diff --git a/server/src/main/java/org/elasticsearch/search/lookup/LeafDocLookup.java b/server/src/main/java/org/elasticsearch/search/lookup/LeafDocLookup.java index 00fc07043baf5..4eaf5c4bb077f 100644 --- a/server/src/main/java/org/elasticsearch/search/lookup/LeafDocLookup.java +++ b/server/src/main/java/org/elasticsearch/search/lookup/LeafDocLookup.java @@ -19,6 +19,8 @@ import org.elasticsearch.script.field.Field; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Collection; import java.util.Map; import java.util.Set; @@ -40,18 +42,23 @@ Field factories require a privileged action to advance to docids (files could be */ class FieldFactoryWrapper { final DocValuesScriptFieldFactory factory; + private final PrivilegedAction advancer; FieldFactoryWrapper(DocValuesScriptFieldFactory factory) { this.factory = factory; + this.advancer = () -> { + try { + factory.setNextDocId(docId); + } catch (IOException ioe) { + throw ExceptionsHelper.convertToElastic(ioe); + } + return null; + }; } // advances the factory to the current docid for the enclosing LeafDocLookup void advanceToDoc() { - try { - factory.setNextDocId(docId); - } catch (IOException ioe) { - throw ExceptionsHelper.convertToElastic(ioe); - } + AccessController.doPrivileged(this.advancer); } } @@ -94,26 +101,30 @@ private FieldFactoryWrapper getFactoryForField(String fieldName) { throw new IllegalArgumentException("No field found for [" + fieldName + "] in mapping"); } - IndexFieldData indexFieldData = fieldDataLookup.apply(fieldType, SCRIPT); + // Load the field data on behalf of the script. Otherwise, it would require + // additional permissions to deal with pagedbytes/ramusagestimator/etc. 
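+ // A minimal standalone sketch of the doPrivileged idiom used below (loadFactory is a hypothetical
+ // stand-in for the cache-and-load logic, not part of this change):
+ //   FieldFactoryWrapper wrapper = AccessController.doPrivileged(
+ //       (PrivilegedAction<FieldFactoryWrapper>) () -> loadFactory(fieldName)
+ //   );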
+ return AccessController.doPrivileged((PrivilegedAction<FieldFactoryWrapper>) () -> { IndexFieldData<?> indexFieldData = fieldDataLookup.apply(fieldType, SCRIPT); - FieldFactoryWrapper docFactory = null; + FieldFactoryWrapper docFactory = null; - if (docFactoryCache.isEmpty() == false) { - docFactory = docFactoryCache.get(fieldName); - } + if (docFactoryCache.isEmpty() == false) { + docFactory = docFactoryCache.get(fieldName); + } - // if this field has already been accessed via the doc-access API and the field-access API - // uses doc values then we share to avoid double-loading - FieldFactoryWrapper fieldFactory; - if (docFactory != null && indexFieldData instanceof SourceValueFetcherIndexFieldData == false) { - fieldFactory = docFactory; - } else { - fieldFactory = new FieldFactoryWrapper(indexFieldData.load(reader).getScriptFieldFactory(fieldName)); - } + // if this field has already been accessed via the doc-access API and the field-access API + // uses doc values then we share to avoid double-loading + FieldFactoryWrapper fieldFactory; + if (docFactory != null && indexFieldData instanceof SourceValueFetcherIndexFieldData == false) { + fieldFactory = docFactory; + } else { + fieldFactory = new FieldFactoryWrapper(indexFieldData.load(reader).getScriptFieldFactory(fieldName)); + } - fieldFactoryCache.put(fieldName, fieldFactory); + fieldFactoryCache.put(fieldName, fieldFactory); - return fieldFactory; + return fieldFactory; + }); } public Field<?> getScriptField(String fieldName) { @@ -135,31 +146,35 @@ private FieldFactoryWrapper getFactoryForDoc(String fieldName) { throw new IllegalArgumentException("No field found for [" + fieldName + "] in mapping"); } - FieldFactoryWrapper docFactory = null; - FieldFactoryWrapper fieldFactory = null; + // Load the field data on behalf of the script. Otherwise, it would require + // additional permissions to deal with pagedbytes/ramusagestimator/etc.
+ return AccessController.doPrivileged((PrivilegedAction) () -> { + FieldFactoryWrapper docFactory = null; + FieldFactoryWrapper fieldFactory = null; - if (fieldFactoryCache.isEmpty() == false) { - fieldFactory = fieldFactoryCache.get(fieldName); - } + if (fieldFactoryCache.isEmpty() == false) { + fieldFactory = fieldFactoryCache.get(fieldName); + } - if (fieldFactory != null) { - IndexFieldData fieldIndexFieldData = fieldDataLookup.apply(fieldType, SCRIPT); + if (fieldFactory != null) { + IndexFieldData fieldIndexFieldData = fieldDataLookup.apply(fieldType, SCRIPT); - // if this field has already been accessed via the field-access API and the field-access API - // uses doc values then we share to avoid double-loading - if (fieldIndexFieldData instanceof SourceValueFetcherIndexFieldData == false) { - docFactory = fieldFactory; + // if this field has already been accessed via the field-access API and the field-access API + // uses doc values then we share to avoid double-loading + if (fieldIndexFieldData instanceof SourceValueFetcherIndexFieldData == false) { + docFactory = fieldFactory; + } } - } - if (docFactory == null) { - IndexFieldData indexFieldData = fieldDataLookup.apply(fieldType, SEARCH); - docFactory = new FieldFactoryWrapper(indexFieldData.load(reader).getScriptFieldFactory(fieldName)); - } + if (docFactory == null) { + IndexFieldData indexFieldData = fieldDataLookup.apply(fieldType, SEARCH); + docFactory = new FieldFactoryWrapper(indexFieldData.load(reader).getScriptFieldFactory(fieldName)); + } - docFactoryCache.put(fieldName, docFactory); + docFactoryCache.put(fieldName, docFactory); - return docFactory; + return docFactory; + }); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 46b2f0a09cf7f..b7f129f674036 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -13,7 +13,6 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { @@ -26,10 +25,9 @@ public ESDiversifyingChildrenByteKnnVectorQuery( Query childFilter, Integer k, int numCands, - BitSetProducer parentsFilter, - KnnSearchStrategy strategy + BitSetProducer parentsFilter ) { - super(field, query, childFilter, numCands, parentsFilter, strategy); + super(field, query, childFilter, numCands, parentsFilter); this.kParam = k; } @@ -44,8 +42,4 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } - - public KnnSearchStrategy getStrategy() { - return searchStrategy; - } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index 5635281ab0e8a..cb323bbe3932a 100644 --- 
a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -13,7 +13,6 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { @@ -26,10 +25,9 @@ public ESDiversifyingChildrenFloatKnnVectorQuery( Query childFilter, Integer k, int numCands, - BitSetProducer parentsFilter, - KnnSearchStrategy strategy + BitSetProducer parentsFilter ) { - super(field, query, childFilter, numCands, parentsFilter, strategy); + super(field, query, childFilter, numCands, parentsFilter); this.kParam = k; } @@ -44,8 +42,4 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } - - public KnnSearchStrategy getStrategy() { - return searchStrategy; - } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 295efd8f9b05e..5c199f42093b1 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -12,15 +12,14 @@ import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; - public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) { - super(field, target, numCands, filter, strategy); + public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter) { + super(field, target, numCands, filter); this.kParam = k; } @@ -40,8 +39,4 @@ public void profile(QueryProfiler queryProfiler) { public Integer kParam() { return kParam; } - - public KnnSearchStrategy getStrategy() { - return searchStrategy; - } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index 8ef4aad147049..b7b9d092ceeac 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -12,15 +12,14 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; - public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy 
strategy) { - super(field, target, numCands, filter, strategy); + public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) { + super(field, target, numCands, filter); this.kParam = k; } @@ -40,8 +39,4 @@ public void profile(QueryProfiler queryProfiler) { public Integer kParam() { return kParam; } - - public KnnSearchStrategy getStrategy() { - return searchStrategy; - } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 87f9a50c64c17..565fd7325a5ac 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -552,17 +552,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } } - DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic(); - return vectorFieldType.createKnnQuery( - queryVector, - k, - adjustedNumCands, - oversample, - filterQuery, - vectorSimilarity, - parentBitSet, - heuristic - ); + + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index ac5938b047804..51cc87d58fcc2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -24,7 +24,6 @@ import java.util.Objects; import static org.elasticsearch.TransportVersions.RESCORE_VECTOR_ALLOW_ZERO; -import static org.elasticsearch.TransportVersions.RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19; public class RescoreVectorBuilder implements Writeable, ToXContentObject { @@ -58,9 +57,7 @@ public RescoreVectorBuilder(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { // We don't want to serialize a `0` oversample to a node that doesn't know what to do with it. 
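// Illustrative test-style sketch of the guarded failure (assumed helpers; the 0f == NO_OVERSAMPLE
// equivalence is an assumption, none of this is part of the change):
//   try (BytesStreamOutput out = new BytesStreamOutput()) {
//       out.setTransportVersion(TransportVersionUtils.getPreviousVersion(RESCORE_VECTOR_ALLOW_ZERO));
//       expectThrows(ElasticsearchStatusException.class, () -> new RescoreVectorBuilder(0f).writeTo(out));
//   }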
- if (oversample == NO_OVERSAMPLE - && out.getTransportVersion().before(RESCORE_VECTOR_ALLOW_ZERO) - && out.getTransportVersion().isPatchFrom(RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19) == false) { + if (oversample == NO_OVERSAMPLE && out.getTransportVersion().before(RESCORE_VECTOR_ALLOW_ZERO)) { throw new ElasticsearchStatusException( "[rescore_vector] does not support a 0 for [" + OVERSAMPLE_FIELD.getPreferredName() diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java index 1287dadd36928..1533e616b8f28 100644 --- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java @@ -31,6 +31,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.Node; import org.elasticsearch.node.ReportingService; +import org.elasticsearch.telemetry.metric.DoubleWithAttributes; import org.elasticsearch.telemetry.metric.Instrument; import org.elasticsearch.telemetry.metric.LongAsyncCounter; import org.elasticsearch.telemetry.metric.LongGauge; @@ -153,7 +154,6 @@ public static class Names { public static final String THREAD_POOL_METRIC_NAME_UTILIZATION = ".threads.utilization.current"; public static final String THREAD_POOL_METRIC_NAME_LARGEST = ".threads.largest.current"; public static final String THREAD_POOL_METRIC_NAME_REJECTED = ".threads.rejected.total"; - public static final String THREAD_POOL_METRIC_NAME_QUEUE_TIME = ".queue.latency.histogram"; public enum ThreadPoolType { FIXED("fixed"), @@ -379,7 +379,14 @@ private static ArrayList setupMetrics(MeterRegistry meterRegistry, S } if (threadPoolExecutor instanceof TaskExecutionTimeTrackingEsThreadPoolExecutor timeTrackingExecutor) { - instruments.addAll(timeTrackingExecutor.setupMetrics(meterRegistry, name)); + instruments.add( + meterRegistry.registerDoubleGauge( + prefix + THREAD_POOL_METRIC_NAME_UTILIZATION, + "fraction of maximum thread time utilized for " + name, + "fraction", + () -> new DoubleWithAttributes(timeTrackingExecutor.pollUtilization(), at) + ) + ); } } return instruments; diff --git a/server/src/main/resources/org/elasticsearch/TransportVersions.csv b/server/src/main/resources/org/elasticsearch/TransportVersions.csv index 53388b45dc15f..3d9dd9bf37c9f 100644 --- a/server/src/main/resources/org/elasticsearch/TransportVersions.csv +++ b/server/src/main/resources/org/elasticsearch/TransportVersions.csv @@ -146,8 +146,5 @@ 8.17.3,8797003 8.17.4,8797004 8.17.5,8797005 -8.17.6,8797006 8.18.0,8840002 -8.18.1,8840003 9.0.0,9000009 -9.0.1,9000010 diff --git a/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv b/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv index c00aa14888fc2..467ef95c740ab 100644 --- a/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv +++ b/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv @@ -146,8 +146,5 @@ 8.17.3,8521000 8.17.4,8521000 8.17.5,8521000 -8.17.6,8521000 8.18.0,8525000 -8.18.1,8525000 9.0.0,9009000 -9.0.1,9009000 diff --git a/server/src/test/java/org/elasticsearch/action/fieldcaps/RequestDispatcherTests.java b/server/src/test/java/org/elasticsearch/action/fieldcaps/RequestDispatcherTests.java index ddbdd849e97bf..f3b9f32ef2930 100644 --- a/server/src/test/java/org/elasticsearch/action/fieldcaps/RequestDispatcherTests.java +++ b/server/src/test/java/org/elasticsearch/action/fieldcaps/RequestDispatcherTests.java @@ -49,7 +49,6 @@ import 
org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.query.CoordinatorRewriteContextProvider; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.shard.ShardId; @@ -136,7 +135,6 @@ public void testHappyCluster() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -208,7 +206,6 @@ public void testRetryThenOk() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -331,7 +328,6 @@ public void testRetryButFails() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -456,7 +452,6 @@ public void testSuccessWithAnyMatch() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -555,7 +550,6 @@ public void testStopAfterAllShardsUnmatched() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -648,7 +642,6 @@ public void testFailWithSameException() throws Exception { mockClusterService(clusterState), transportService, TestProjectResolvers.singleProject(projectId), - coordinatorRewriteContextProvider(), newRandomParentTask(), randomFieldCapRequest(withFilter), OriginalIndices.NONE, @@ -1039,8 +1032,4 @@ static ClusterService mockClusterService(ClusterState clusterState) { when(clusterService.operationRouting()).thenReturn(operationRouting); return clusterService; } - - static CoordinatorRewriteContextProvider coordinatorRewriteContextProvider() { - return mock(CoordinatorRewriteContextProvider.class); - } } diff --git a/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java b/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java index 6ca38b200b55c..07c694e502cff 100644 --- a/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java +++ b/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java @@ -657,6 +657,28 @@ String javaVersion() { } + public void testAllPermissionCheck() throws NodeValidationException { + final AtomicBoolean isAllPermissionGranted = new AtomicBoolean(true); + final BootstrapChecks.AllPermissionCheck allPermissionCheck = new BootstrapChecks.AllPermissionCheck() { + @Override + boolean isAllPermissionGranted() { + return isAllPermissionGranted.get(); + } + }; + + final List checks = Collections.singletonList(allPermissionCheck); + final NodeValidationException e = expectThrows( + NodeValidationException.class, + () -> BootstrapChecks.check(emptyContext, true, checks) + ); + assertThat(e, hasToString(containsString("granting the all permission effectively disables security"))); + 
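+ // the failure message should also carry the docs link that bootstrap-check errors append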
assertThat(e.getMessage(), containsString("; for more information see [https://www.elastic.co/docs/")); + + // if all permissions are not granted, nothing should happen + isAllPermissionGranted.set(false); + BootstrapChecks.check(emptyContext, true, checks); + } + public void testAlwaysEnforcedChecks() { final BootstrapCheck check = new BootstrapCheck() { @Override diff --git a/server/src/test/java/org/elasticsearch/bootstrap/ESPolicyTests.java b/server/src/test/java/org/elasticsearch/bootstrap/ESPolicyTests.java new file mode 100644 index 0000000000000..1660eeee837b3 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/bootstrap/ESPolicyTests.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.jdk.RuntimeVersionFeature; +import org.elasticsearch.test.ESTestCase; + +import java.security.AccessControlContext; +import java.security.AccessController; +import java.security.PermissionCollection; +import java.security.Permissions; +import java.security.PrivilegedAction; +import java.security.ProtectionDomain; + +/** + * Tests for ESPolicy + */ +public class ESPolicyTests extends ESTestCase { + + /** + * test restricting privileges to no permissions actually works + */ + public void testRestrictPrivileges() { + assumeTrue( + "test requires security manager", + RuntimeVersionFeature.isSecurityManagerAvailable() && System.getSecurityManager() != null + ); + try { + System.getProperty("user.home"); + } catch (SecurityException e) { + fail("this test needs to be fixed: user.home not available by policy"); + } + + PermissionCollection noPermissions = new Permissions(); + AccessControlContext noPermissionsAcc = new AccessControlContext( + new ProtectionDomain[] { new ProtectionDomain(null, noPermissions) } + ); + try { + AccessController.doPrivileged(new PrivilegedAction() { + public Void run() { + System.getProperty("user.home"); + fail("access should have been denied"); + return null; + } + }, noPermissionsAcc); + } catch (SecurityException expected) { + // expected exception + } + } +} diff --git a/server/src/test/java/org/elasticsearch/common/logging/DeprecationLoggerTests.java b/server/src/test/java/org/elasticsearch/common/logging/DeprecationLoggerTests.java index d891b0bb41198..52439f4b59447 100644 --- a/server/src/test/java/org/elasticsearch/common/logging/DeprecationLoggerTests.java +++ b/server/src/test/java/org/elasticsearch/common/logging/DeprecationLoggerTests.java @@ -17,6 +17,11 @@ import org.elasticsearch.test.ESTestCase; import org.mockito.Mockito; +import java.security.AccessControlContext; +import java.security.AccessController; +import java.security.Permissions; +import java.security.PrivilegedAction; +import java.security.ProtectionDomain; import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.equalTo; @@ -70,7 +75,13 @@ public void testLogPermissions() { DeprecationLogger deprecationLogger = DeprecationLogger.getLogger("name"); - deprecationLogger.warn(DeprecationCategory.API, "key", "foo", "bar"); + AccessControlContext noPermissionsAcc = 
new AccessControlContext( + new ProtectionDomain[] { new ProtectionDomain(null, new Permissions()) } + ); + AccessController.doPrivileged((PrivilegedAction) () -> { + deprecationLogger.warn(DeprecationCategory.API, "key", "foo", "bar"); + return null; + }, noPermissionsAcc); assertThat("supplier called", supplierCalled.get(), is(true)); } finally { LogManager.setFactory(originalFactory); diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java index 7f720721aebf2..b112ac0ceb6be 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/TaskExecutionTimeTrackingEsThreadPoolExecutorTests.java @@ -9,19 +9,11 @@ package org.elasticsearch.common.util.concurrent; -import org.elasticsearch.common.metrics.ExponentialBucketHistogram; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig; -import org.elasticsearch.telemetry.InstrumentType; -import org.elasticsearch.telemetry.Measurement; -import org.elasticsearch.telemetry.RecordingMeterRegistry; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ThreadPool; -import java.util.List; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -29,7 +21,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; /** * Tests for the automatic queue resizing of the {@code QueueResizingEsThreadPoolExecutorTests} @@ -156,85 +147,6 @@ public void testGetOngoingTasks() throws Exception { executor.awaitTermination(10, TimeUnit.SECONDS); } - public void testQueueLatencyMetrics() { - RecordingMeterRegistry meterRegistry = new RecordingMeterRegistry(); - final var threadPoolName = randomIdentifier(); - var executor = new TaskExecutionTimeTrackingEsThreadPoolExecutor( - threadPoolName, - 1, - 1, - 1000, - TimeUnit.MILLISECONDS, - ConcurrentCollections.newBlockingQueue(), - TimedRunnable::new, - EsExecutors.daemonThreadFactory("queuetest"), - new EsAbortPolicy(), - new ThreadContext(Settings.EMPTY), - new TaskTrackingConfig(true, DEFAULT_EWMA_ALPHA) - ); - executor.setupMetrics(meterRegistry, threadPoolName); - - try { - final var barrier = new CyclicBarrier(2); - final ExponentialBucketHistogram expectedHistogram = new ExponentialBucketHistogram( - TaskExecutionTimeTrackingEsThreadPoolExecutor.QUEUE_LATENCY_HISTOGRAM_BUCKETS - ); - - /* - * The thread pool has a single thread, so we submit a task that will occupy that thread - * and cause subsequent tasks to be queued - */ - Future runningTask = executor.submit(() -> { - safeAwait(barrier); - safeAwait(barrier); - }); - safeAwait(barrier); // wait till the first task starts - expectedHistogram.addObservation(0L); // the first task should not be delayed - - /* - * On each iteration we submit a task - which will be queued because of the - * currently running task, pause for some random interval, then unblock the - * new task by releasing the currently running task. 
This gives us a lower - * bound for the real delays (the real delays will be greater than or equal - * to the synthetic delays we add, i.e. each percentile should be >= our - * expected values) - */ - for (int i = 0; i < 10; i++) { - Future waitingTask = executor.submit(() -> { - safeAwait(barrier); - safeAwait(barrier); - }); - final long delayTimeMs = randomLongBetween(1, 50); - safeSleep(delayTimeMs); - safeAwait(barrier); // let the running task complete - safeAwait(barrier); // wait for the next task to start - safeGet(runningTask); // ensure previous task is complete - expectedHistogram.addObservation(delayTimeMs); - runningTask = waitingTask; - } - safeAwait(barrier); // let the last task finish - safeGet(runningTask); - meterRegistry.getRecorder().collect(); - - List measurements = meterRegistry.getRecorder() - .getMeasurements( - InstrumentType.LONG_GAUGE, - ThreadPool.THREAD_POOL_METRIC_PREFIX + threadPoolName + ThreadPool.THREAD_POOL_METRIC_NAME_QUEUE_TIME - ); - assertThat(measurements, hasSize(3)); - // we have to use greater than or equal to because the actual delay might be higher than what we imposed - assertThat(getPercentile(measurements, "99"), greaterThanOrEqualTo(expectedHistogram.getPercentile(0.99f))); - assertThat(getPercentile(measurements, "90"), greaterThanOrEqualTo(expectedHistogram.getPercentile(0.9f))); - assertThat(getPercentile(measurements, "50"), greaterThanOrEqualTo(expectedHistogram.getPercentile(0.5f))); - } finally { - ThreadPool.terminate(executor, 10, TimeUnit.SECONDS); - } - } - - private long getPercentile(List measurements, String percentile) { - return measurements.stream().filter(m -> m.attributes().get("percentile").equals(percentile)).findFirst().orElseThrow().getLong(); - } - /** * The returned function outputs a WrappedRunnabled that simulates the case * where {@link TimedRunnable#getTotalExecutionNanos()} always returns {@code timeTakenNanos}. diff --git a/server/src/test/java/org/elasticsearch/env/NodeEnvironmentTests.java b/server/src/test/java/org/elasticsearch/env/NodeEnvironmentTests.java index 9a7fadad33a53..497795ec4d2c5 100644 --- a/server/src/test/java/org/elasticsearch/env/NodeEnvironmentTests.java +++ b/server/src/test/java/org/elasticsearch/env/NodeEnvironmentTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.test.MockLog; import org.elasticsearch.test.NodeRoles; import org.elasticsearch.test.junit.annotations.TestLogging; +import org.hamcrest.Matchers; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -581,9 +582,9 @@ public void testIndexCompatibilityChecks() throws IOException { containsString("it holds metadata for indices with version [" + oldIndexVersion.toReleaseVersion() + "]"), containsString( "Revert this node to version [" - + (previousNodeVersion.onOrAfter(Version.CURRENT.minimumCompatibilityVersion()) - ? previousNodeVersion - : Version.CURRENT.minimumCompatibilityVersion()) + + (previousNodeVersion.major == Version.CURRENT.major + ? 
Version.CURRENT.minimumCompatibilityVersion() + : previousNodeVersion) + "]" ) ) @@ -638,37 +639,29 @@ public void testSymlinkDataDirectory() throws Exception { } public void testGetBestDowngradeVersion() { - int prev = Version.CURRENT.minimumCompatibilityVersion().major; - int last = Version.CURRENT.minimumCompatibilityVersion().minor; - int old = prev - 1; - - assumeTrue("The current compatibility rules are active only from 8.x onward", prev >= 7); - assertEquals(Version.CURRENT.major - 1, prev); - - assertEquals( - "From an old major, recommend prev.last", - NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString(old + ".0.0")), - BuildVersion.fromString(prev + "." + last + ".0") + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("8.18.0")), + Matchers.equalTo(BuildVersion.fromString("8.18.0")) ); - - if (last >= 1) { - assertEquals( - "From an old minor of the previous major, recommend prev.last", - NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString(prev + "." + (last - 1) + ".0")), - BuildVersion.fromString(prev + "." + last + ".0") - ); - } - - assertEquals( - "From an old patch of prev.last, return that version itself", - NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString(prev + "." + last + ".1")), - BuildVersion.fromString(prev + "." + last + ".1") + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("8.18.5")), + Matchers.equalTo(BuildVersion.fromString("8.18.5")) ); - - assertEquals( - "From the first version of this major, return that version itself", - NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString(Version.CURRENT.major + ".0.0")), - BuildVersion.fromString(Version.CURRENT.major + ".0.0") + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("8.18.12")), + Matchers.equalTo(BuildVersion.fromString("8.18.12")) + ); + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("8.19.0")), + Matchers.equalTo(BuildVersion.fromString("8.19.0")) + ); + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("8.17.0")), + Matchers.equalTo(BuildVersion.fromString("8.18.0")) + ); + assertThat( + NodeEnvironment.getBestDowngradeVersion(BuildVersion.fromString("7.17.0")), + Matchers.equalTo(BuildVersion.fromString("8.18.0")) ); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/blockloader/TextFieldBlockLoaderTests.java b/server/src/test/java/org/elasticsearch/index/mapper/blockloader/TextFieldBlockLoaderTests.java index 77c42740451ee..5c9acaf18a45d 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/blockloader/TextFieldBlockLoaderTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/blockloader/TextFieldBlockLoaderTests.java @@ -62,8 +62,15 @@ public static Object expectedValue(Map fieldMapping, Object valu if (params.syntheticSource() && testContext.forceFallbackSyntheticSource() == false && usingSyntheticSourceDelegate) { var nullValue = (String) keywordMultiFieldMapping.get("null_value"); + // Due to how TextFieldMapper#blockReaderDisiLookup works this is complicated. + // If we are using lookupMatchingAll() then we'll see all docs, generate synthetic source using syntheticSourceDelegate, + // parse it and see null_value inside. + // But if we are using lookupFromNorms() we will skip the document (since the text field itself does not exist). + // Same goes for lookupFromFieldNames(). 
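+ // Concrete example of the branches below (assumed mapping, for illustration only): a text field with
+ // "index": false plus a keyword subfield with null_value "N/A" surfaces a null input as "N/A",
+ // while with "index": true an all-null document is skipped and the expected value is null.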
+ boolean textFieldIndexed = (boolean) fieldMapping.getOrDefault("index", true); + if (value == null) { - if (nullValue != null && nullValue.length() <= (int) ignoreAbove) { + if (textFieldIndexed == false && nullValue != null && nullValue.length() <= (int) ignoreAbove) { return new BytesRef(nullValue); } @@ -75,6 +82,12 @@ public static Object expectedValue(Map fieldMapping, Object valu } var values = (List) value; + + // See note above about TextFieldMapper#blockReaderDisiLookup. + if (textFieldIndexed && values.stream().allMatch(Objects::isNull)) { + return null; + } + var indexed = values.stream() .map(s -> s == null ? nullValue : s) .filter(Objects::nonNull) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index c1c21ccda580a..d57acf786d715 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -904,17 +904,14 @@ public void testRescoreVectorForNonQuantized() { } public void testRescoreVectorOldIndexVersion() { - IndexVersion incompatibleVersion = randomFrom( + IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( + random(), IndexVersionUtils.randomVersionBetween( random(), IndexVersionUtils.getLowestReadCompatibleVersion(), IndexVersionUtils.getPreviousVersion(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X) ), - IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.UPGRADE_TO_LUCENE_10_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) - ) + IndexVersionUtils.getPreviousVersion(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) ); for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { expectThrows( @@ -935,17 +932,10 @@ public void testRescoreVectorOldIndexVersion() { } public void testRescoreZeroVectorOldIndexVersion() { - IndexVersion incompatibleVersion = randomFrom( - IndexVersionUtils.randomVersionBetween( - random(), - IndexVersionUtils.getLowestReadCompatibleVersion(), - IndexVersionUtils.getPreviousVersion(IndexVersions.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS_BACKPORT_8_X) - ), - IndexVersionUtils.randomVersionBetween( - random(), - IndexVersions.UPGRADE_TO_LUCENE_10_0_0, - IndexVersionUtils.getPreviousVersion(IndexVersions.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS) - ) + IndexVersion incompatibleVersion = IndexVersionUtils.randomVersionBetween( + random(), + IndexVersionUtils.getLowestReadCompatibleVersion(), + IndexVersionUtils.getPreviousVersion(DenseVectorFieldMapper.RESCORE_PARAMS_ALLOW_ZERO_TO_QUANTIZED_VECTORS) ); for (String indexType : List.of("int8_hnsw", "int8_flat", "int4_hnsw", "int4_flat", "bbq_hnsw", "bbq_flat")) { expectThrows( @@ -1881,16 +1871,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 128, 0, 0 }), - 3, - 3, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1906,8 +1887,7 @@ public void 
testByteVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat( @@ -1917,16 +1897,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), - 3, - 3, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1935,16 +1906,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), - 3, - 3, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1960,8 +1922,7 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1975,8 +1936,7 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat( @@ -1993,8 +1953,7 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat( @@ -2028,8 +1987,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -2043,8 +2001,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat( @@ -2061,8 +2018,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5877ce9003ff5..b6df46d17b598 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -15,8 +15,6 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; -import org.apache.lucene.search.knn.KnnSearchStrategy; -import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.mapper.FieldTypeTestCase; @@ -34,10 +32,8 @@ 
import java.util.Collections; import java.util.List; import java.util.Set; -import java.util.function.Function; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; @@ -220,16 +216,7 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery( - VectorData.fromFloats(queryVector), - 10, - 10, - null, - null, - null, - producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } @@ -253,29 +240,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery( - vectorData, - 10, - 10, - null, - null, - null, - producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery( - vectorData, - 10, - 10, - null, - null, - null, - producer, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -343,8 +312,7 @@ public void testFloatCreateKnnQuery() { null, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -365,16 +333,7 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery( - VectorData.fromFloats(queryVector), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -390,16 +349,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery( - VectorData.fromFloats(new float[BBQ_MIN_DIMS]), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -420,16 +370,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery( - 
VectorData.fromFloats(queryVector), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } @@ -452,16 +393,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery( - vectorData, - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -479,16 +411,7 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery( - VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -504,31 +427,13 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery( - VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery( - new VectorData(null, new byte[] { 0, 0, 0 }), - 10, - 10, - null, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -553,8 +458,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { randomFloatBetween(1.0F, 10.0F, false), null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ); if (elementType == BYTE) { @@ -600,16 +504,7 @@ public void testRescoreOversampleQueryOverrides() { randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.1f, 9.9f, false))), Collections.emptyMap() ); - Query query = fieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 1, 4, 10 }), - 10, - 100, - 0f, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + Query query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 0f, null, null, null); assertTrue(query instanceof ESKnnFloatVectorQuery); // verify we can override a `0` to a positive number @@ -623,16 +518,7 @@ public void testRescoreOversampleQueryOverrides() { randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(0)), 
Collections.emptyMap() ); - query = fieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 1, 4, 10 }), - 10, - 100, - 2f, - null, - null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) - ); + query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 2f, null, null, null); assertTrue(query instanceof RescoreKnnVectorQuery); assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10)); ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery(); @@ -640,55 +526,6 @@ public void testRescoreOversampleQueryOverrides() { } - public void testFilterSearchThreshold() { - List<Tuple<DenseVectorFieldMapper.ElementType, Function<Query, KnnSearchStrategy>>> cases = List.of( - Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), - Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), - Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) - ); - for (var tuple : cases) { - DenseVectorFieldType fieldType = new DenseVectorFieldType( - "f", - IndexVersion.current(), - tuple.v1(), - tuple.v1() == BIT ? 3 * 8 : 3, - true, - VectorSimilarity.COSINE, - randomIndexOptionsHnswQuantized(), - Collections.emptyMap() - ); - - // Test with a filter search threshold - Query query = fieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 1, 4, 10 }), - 10, - 100, - 0f, - null, - null, - null, - DenseVectorFieldMapper.FilterHeuristic.FANOUT - ); - KnnSearchStrategy strategy = tuple.v2().apply(query); - assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); - assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0)); - - query = fieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 1, 4, 10 }), - 10, - 100, - 0f, - null, - null, - null, - DenseVectorFieldMapper.FilterHeuristic.ACORN - ); - strategy = tuple.v2().apply(query); - assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); - assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60)); - } - } - private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, int k, @@ -705,8 +542,7 @@ private static void checkRescoreQueryParameters( oversample, null, null, - null, - randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + null ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); diff --git a/server/src/test/java/org/elasticsearch/plugins/ExtensionLoaderTests.java b/server/src/test/java/org/elasticsearch/plugins/ExtensionLoaderTests.java index d877f4d5f6bb8..4f9bbc24fd9e3 100644 --- a/server/src/test/java/org/elasticsearch/plugins/ExtensionLoaderTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/ExtensionLoaderTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.plugins; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations.ClosableURLClassLoader; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; @@ -34,7 +35,7 @@ public interface TestService { int getValue(); } - private URLClassLoader buildProviderJar(Map<String, CharSequence> sources) throws Exception { + private ClosableURLClassLoader buildProviderJar(Map<String, CharSequence> sources) throws Exception { var classToBytes = InMemoryJavaCompiler.compile(sources); Map<String, byte[]> jarEntries = new HashMap<>(); @@ -54,7 +55,7 @@ private URLClassLoader buildProviderJar(Map<String, CharSequence> sources) throw JarUtils.createJarWithEntries(jar, jarEntries); URL[] urls = new URL[] { jar.toUri().toURL() }; - return
URLClassLoader.newInstance(urls, this.getClass().getClassLoader()); + return new ClosableURLClassLoader(URLClassLoader.newInstance(urls, this.getClass().getClassLoader())); } private String defineProvider(String name, int value) { @@ -78,7 +79,7 @@ public void testNoProvider() { public void testOneProvider() throws Exception { Map sources = Map.of("p.FooService", defineProvider("FooService", 1)); try (var loader = buildProviderJar(sources)) { - TestService service = ExtensionLoader.loadSingleton(ServiceLoader.load(TestService.class, loader)) + TestService service = ExtensionLoader.loadSingleton(ServiceLoader.load(TestService.class, loader.classloader())) .orElseThrow(AssertionError::new); assertThat(service, not(nullValue())); assertThat(service.getValue(), equalTo(1)); @@ -95,7 +96,7 @@ public void testManyProviders() throws Exception { try (var loader = buildProviderJar(sources)) { var e = expectThrows( IllegalStateException.class, - () -> ExtensionLoader.loadSingleton(ServiceLoader.load(TestService.class, loader)) + () -> ExtensionLoader.loadSingleton(ServiceLoader.load(TestService.class, loader.classloader())) ); assertThat(e.getMessage(), containsString("More than one extension found")); assertThat(e.getMessage(), containsString("TestService")); diff --git a/server/src/test/java/org/elasticsearch/plugins/PluginIntrospectorTests.java b/server/src/test/java/org/elasticsearch/plugins/PluginIntrospectorTests.java index f1d77083228a9..df7a72cfd59ed 100644 --- a/server/src/test/java/org/elasticsearch/plugins/PluginIntrospectorTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/PluginIntrospectorTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; import org.elasticsearch.ingest.Processor; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; @@ -232,8 +233,11 @@ public final class FooPlugin extends q.AbstractFooPlugin { } JarUtils.createJarWithEntries(jar, jarEntries); URL[] urls = new URL[] { jar.toUri().toURL() }; - try (URLClassLoader loader = URLClassLoader.newInstance(urls, PluginIntrospectorTests.class.getClassLoader())) { + URLClassLoader loader = URLClassLoader.newInstance(urls, PluginIntrospectorTests.class.getClassLoader()); + try { assertThat(pluginIntrospector.interfaces(loader.loadClass("r.FooPlugin")), contains("ActionPlugin")); + } finally { + PrivilegedOperations.closeURLClassLoader(loader); } } diff --git a/server/src/test/java/org/elasticsearch/plugins/PluginsLoaderTests.java b/server/src/test/java/org/elasticsearch/plugins/PluginsLoaderTests.java index efc3c069b4ab3..f17132a028391 100644 --- a/server/src/test/java/org/elasticsearch/plugins/PluginsLoaderTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/PluginsLoaderTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.nativeaccess.NativeAccessUtil; import org.elasticsearch.plugin.analysis.CharFilterFactory; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; @@ -350,13 +351,13 @@ static void closePluginLoaders(PluginsLoader pluginsLoader) { pluginsLoader.pluginLayers().forEach(lp -> { if (lp.pluginClassLoader() instanceof URLClassLoader urlClassLoader) { try { - urlClassLoader.close(); + PrivilegedOperations.closeURLClassLoader(urlClassLoader); } catch 
(IOException unexpected) { throw new UncheckedIOException(unexpected); } } else if (lp.pluginClassLoader() instanceof UberModuleClassLoader loader) { try { - loader.getInternalLoader().close(); + PrivilegedOperations.closeURLClassLoader(loader.getInternalLoader()); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/server/src/test/java/org/elasticsearch/plugins/PluginsServiceTests.java b/server/src/test/java/org/elasticsearch/plugins/PluginsServiceTests.java index d18b7f52b8d08..fa2c800cfbe28 100644 --- a/server/src/test/java/org/elasticsearch/plugins/PluginsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/PluginsServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.plugins.spi.BarTestService; import org.elasticsearch.plugins.spi.TestService; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.PrivilegedOperations; import org.elasticsearch.test.compiler.InMemoryJavaCompiler; import org.elasticsearch.test.jar.JarUtils; @@ -670,11 +671,9 @@ public String name() { } public void testLoadServiceProviders() throws Exception { - - try ( - URLClassLoader fakeClassLoader = buildTestProviderPlugin("integer"); - URLClassLoader fakeClassLoader1 = buildTestProviderPlugin("string") - ) { + URLClassLoader fakeClassLoader = buildTestProviderPlugin("integer"); + URLClassLoader fakeClassLoader1 = buildTestProviderPlugin("string"); + try { @SuppressWarnings("unchecked") Class<? extends Plugin> fakePluginClass = (Class<? extends Plugin>) fakeClassLoader.loadClass("r.FooPlugin"); @SuppressWarnings("unchecked") @@ -700,6 +699,9 @@ public void testLoadServiceProviders() throws Exception { providers = service.loadServiceProviders(TestService.class); assertEquals(0, providers.size()); + } finally { + PrivilegedOperations.closeURLClassLoader(fakeClassLoader); + PrivilegedOperations.closeURLClassLoader(fakeClassLoader1); } } @@ -875,13 +877,13 @@ static void closePluginLoaders(PluginsService pluginService) { for (var lp : pluginService.plugins()) { if (lp.classLoader() instanceof URLClassLoader urlClassLoader) { try { - urlClassLoader.close(); + PrivilegedOperations.closeURLClassLoader(urlClassLoader); } catch (IOException unexpected) { throw new UncheckedIOException(unexpected); } } else if (lp.classLoader() instanceof UberModuleClassLoader loader) { try { - loader.getInternalLoader().close(); + PrivilegedOperations.closeURLClassLoader(loader.getInternalLoader()); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/server/src/test/java/org/elasticsearch/search/lookup/LeafDocLookupTests.java b/server/src/test/java/org/elasticsearch/search/lookup/LeafDocLookupTests.java index 895206b551250..6ddffbef37f7d 100644 --- a/server/src/test/java/org/elasticsearch/search/lookup/LeafDocLookupTests.java +++ b/server/src/test/java/org/elasticsearch/search/lookup/LeafDocLookupTests.java @@ -24,6 +24,10 @@ import org.junit.Before; import java.io.IOException; +import java.security.AccessControlContext; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.security.ProtectionDomain; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -423,7 +427,12 @@ public void testParallelCache() { public void testLookupPrivilegesAdvanceDoc() { nextDocCallback = i -> SpecialPermission.check(); - ScriptDocValues<?> fetchedDocValues = docLookup.get("field"); - assertEquals(docValues, fetchedDocValues); + // mimic the untrusted codebase, which gets no permissions + var restrictedContext = new AccessControlContext(new
ProtectionDomain[] { new ProtectionDomain(null, null) }); + AccessController.doPrivileged((PrivilegedAction<Object>) () -> { + ScriptDocValues<?> fetchedDocValues = docLookup.get("field"); + assertEquals(docValues, fetchedDocValues); + return null; + }, restrictedContext); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 27549b3c4030b..9499edc71b4a6 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -13,7 +13,6 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.support.PlainActionFuture; @@ -22,7 +21,6 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.query.InnerHitsRewriteContext; @@ -218,29 +216,9 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que numCands = Math.max(numCands, k); } - final KnnSearchStrategy expectedStrategy = context.getIndexSettings() - .getIndexVersionCreated() - .onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC) - ?
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() - : DenseVectorFieldMapper.FilterHeuristic.FANOUT.getKnnSearchStrategy(); - Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asByteVector(), - k, - numCands, - filterQuery, - expectedStrategy - ); - case FLOAT -> new ESKnnFloatVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asFloatVector(), - k, - numCands, - filterQuery, - expectedStrategy - ); + case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); + case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerRawMessageTests.java b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerRawMessageTests.java index 58ca2cbab9530..89e929d7029f6 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerRawMessageTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerRawMessageTests.java @@ -27,6 +27,8 @@ import java.net.ServerSocket; import java.net.Socket; import java.nio.charset.StandardCharsets; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.greaterThan; @@ -191,6 +193,8 @@ public void testOutboundHandshake() throws Exception { private Socket openTransportConnection() throws Exception { final var transportAddress = randomFrom(getInstanceFromNode(TransportService.class).boundAddress().boundAddresses()).address(); - return new Socket(transportAddress.getAddress(), transportAddress.getPort()); + return AccessController.doPrivileged( + (PrivilegedExceptionAction<Socket>) (() -> new Socket(transportAddress.getAddress(), transportAddress.getPort())) + ); } } diff --git a/test/external-modules/jvm-crash/src/main/plugin-metadata/plugin-security.policy b/test/external-modules/jvm-crash/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..860ae72b058db --- /dev/null +++ b/test/external-modules/jvm-crash/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,6 @@ +grant { + // various permissions to fiddle with Unsafe + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + permission java.lang.RuntimePermission "accessClassInPackage.sun.misc"; +}; diff --git a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java index 318f2ce863173..09735ad0c3d5c 100644 --- a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java +++ b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java @@ -16,10 +16,15 @@ import org.elasticsearch.core.Booleans; import org.elasticsearch.core.PathUtils; import org.elasticsearch.jdk.JarHell; +import org.elasticsearch.test.PrivilegedOperations; +import org.elasticsearch.test.mockito.SecureMockMaker; import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Map;
import java.util.Objects; /** @@ -69,10 +74,54 @@ public class BootstrapForTesting { throw new RuntimeException("found jar hell in test classpath", e); } + // init mockito + SecureMockMaker.init(); + + // init the privileged operation + try { + MethodHandles.publicLookup().ensureInitialized(PrivilegedOperations.class); + } catch (IllegalAccessException unexpected) { + throw new AssertionError(unexpected); + } + // Log ifconfig output before SecurityManager is installed IfConfig.logIfNecessary(); } + static Map<String, URL> getCodebases() { + Map<String, URL> codebases = PolicyUtil.getCodebaseJarMap(JarHell.parseClassPath()); + // when testing server, the main elasticsearch code is not yet in a jar, so we need to manually add it + addClassCodebase(codebases, "elasticsearch", "org.elasticsearch.plugins.PluginsService"); + addClassCodebase(codebases, "elasticsearch-plugin-classloader", "org.elasticsearch.plugins.loader.ExtendedPluginsClassLoader"); + addClassCodebase(codebases, "elasticsearch-nio", "org.elasticsearch.nio.ChannelFactory"); + addClassCodebase(codebases, "elasticsearch-secure-sm", "org.elasticsearch.secure_sm.SecureSM"); + addClassCodebase(codebases, "elasticsearch-rest-client", "org.elasticsearch.client.RestClient"); + addClassCodebase(codebases, "elasticsearch-core", "org.elasticsearch.core.Booleans"); + addClassCodebase(codebases, "elasticsearch-cli", "org.elasticsearch.cli.Command"); + addClassCodebase(codebases, "elasticsearch-simdvec", "org.elasticsearch.simdvec.VectorScorerFactory"); + addClassCodebase(codebases, "framework", "org.elasticsearch.test.ESTestCase"); + return codebases; + } + + /** Add the codebase url of the given classname to the codebases map, if the class exists. */ + private static void addClassCodebase(Map<String, URL> codebases, String name, String classname) { + try { + if (codebases.containsKey(name)) { + return; // the codebase already exists, from the classpath + } + Class<?> clazz = BootstrapForTesting.class.getClassLoader().loadClass(classname); + URL location = clazz.getProtectionDomain().getCodeSource().getLocation(); + if (location.toString().endsWith(".jar") == false) { + if (codebases.put(name, location) != null) { + throw new IllegalStateException("Already added " + name + " codebase for testing"); + } + } + } catch (ClassNotFoundException e) { + // no class, fall through to not add. this can happen for any tests that do not include + // the given class. eg only core tests include plugin-classloader + } + } + + // does nothing, just easy way to make sure the class is loaded.
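// A typical caller forces this class's static initializer to run by touching the method below from its own static initializer, e.g. (illustrative): static { BootstrapForTesting.ensureInitialized(); }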
public static void ensureInitialized() {} } diff --git a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java index 03d6ac6342b42..c786fcd3f288f 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RoutingNode; @@ -214,7 +213,7 @@ public IndexMetadata verifyIndexIsDeleted(Index index, ClusterState state) { } @Override - public void deleteUnassignedIndex(String reason, IndexMetadata oldIndexMetadata, ProjectMetadata currentProject) { + public void deleteUnassignedIndex(String reason, IndexMetadata metadata, ClusterState clusterState) { } diff --git a/test/framework/src/main/java/org/elasticsearch/test/PrivilegedOperations.java b/test/framework/src/main/java/org/elasticsearch/test/PrivilegedOperations.java new file mode 100644 index 0000000000000..275adf47d3637 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/PrivilegedOperations.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test; + +import org.elasticsearch.core.SuppressForbidden; + +import java.io.FilePermission; +import java.io.IOException; +import java.net.URLClassLoader; +import java.security.AccessControlContext; +import java.security.AccessController; +import java.security.CodeSigner; +import java.security.CodeSource; +import java.security.DomainCombiner; +import java.security.Permission; +import java.security.PermissionCollection; +import java.security.PrivilegedAction; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.security.ProtectionDomain; +import java.util.Enumeration; + +import javax.tools.JavaCompiler; + +/** + * A small set of privileged operations that can be executed by unprivileged test code. + * The set of operations is deliberately small, and the permissions narrow. 
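+ * <p>A sketch of the intended usage (the jar path and {@code TestService} type here are illustrative, not part of this class): + * <pre>{@code + * try (var loader = JarUtils.loadJar(jarPath)) { // a ClosableURLClassLoader + * var service = ServiceLoader.load(TestService.class, loader.classloader()); + * // ... exercise the service ... + * } // close() delegates to closeURLClassLoader, which runs privileged with RuntimePermission("closeClassLoader") + * }</pre>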
+ */ +public final class PrivilegedOperations { + + private PrivilegedOperations() {} + + public static void closeURLClassLoader(URLClassLoader loader) throws IOException { + try { + AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> { + loader.close(); + return null; + }, context, new RuntimePermission("closeClassLoader")); + } catch (PrivilegedActionException pae) { + Exception e = pae.getException(); + if (e instanceof IOException ioe) { + throw ioe; + } else { + throw new IOException(e); + } + } + } + + public record ClosableURLClassLoader(URLClassLoader classloader) implements AutoCloseable { + @Override + public void close() throws Exception { + closeURLClassLoader(classloader); + } + } + + public static Boolean compilationTaskCall(JavaCompiler.CompilationTask compilationTask) { + return AccessController.doPrivileged( + (PrivilegedAction<Boolean>) () -> compilationTask.call(), + context, + new RuntimePermission("createClassLoader"), + new RuntimePermission("closeClassLoader"), + new RuntimePermission("accessSystemModules"), + newAllFilesReadPermission() + ); + } + + @SuppressForbidden(reason = "need to create file permission") + private static FilePermission newAllFilesReadPermission() { + return new FilePermission("<<ALL FILES>>", "read"); + } + + // -- security manager related stuff, to facilitate asserting permissions for test operations. + + @SuppressWarnings("removal") + private static AccessControlContext getContext() { + ProtectionDomain[] pda = new ProtectionDomain[] { + new ProtectionDomain(new CodeSource(null, (CodeSigner[]) null), new PermissivePermissionCollection()) }; + DomainCombiner combiner = (ignoreCurrent, ignoreAssigned) -> pda; + AccessControlContext acc = new AccessControlContext(AccessController.getContext(), combiner); + // getContext must be called with the new acc so that a combined context will be created + return AccessController.doPrivileged((PrivilegedAction<AccessControlContext>) AccessController::getContext, acc); + } + + // An all-powerful context for wrapping calls + @SuppressWarnings("removal") + private static final AccessControlContext context = getContext(); + + // A permissive permission collection - implies all permissions.
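+ // Granting the wrapping ProtectionDomain every permission does not make callers all-powerful: each + // doPrivileged call above still narrows the effective rights to the Permission arguments passed at + // the call site, e.g. RuntimePermission("closeClassLoader") for closeURLClassLoader.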
+ private static final class PermissivePermissionCollection extends PermissionCollection { + + private PermissivePermissionCollection() {} + + @Override + public void add(Permission permission) {} + + @Override + public boolean implies(Permission permission) { + return true; + } + + @Override + public Enumeration<Permission> elements() { + return null; + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/compiler/InMemoryJavaCompiler.java b/test/framework/src/main/java/org/elasticsearch/test/compiler/InMemoryJavaCompiler.java index 106b27a9172ab..185fe32ec3bc0 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/compiler/InMemoryJavaCompiler.java +++ b/test/framework/src/main/java/org/elasticsearch/test/compiler/InMemoryJavaCompiler.java @@ -9,6 +9,8 @@ package org.elasticsearch.test.compiler; +import org.elasticsearch.test.PrivilegedOperations; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -135,7 +137,7 @@ public static Map<String, byte[]> compile(Map<String, CharSequence> sources, Str try (FileManagerWrapper wrapper = new FileManagerWrapper(files)) { CompilationTask task = getCompilationTask(wrapper, options); - boolean result = task.call(); + boolean result = PrivilegedOperations.compilationTaskCall(task); if (result == false) { throw new RuntimeException("Could not compile " + sources.entrySet().stream().toList()); } @@ -160,7 +162,7 @@ public static byte[] compile(String className, CharSequence sourceCode, String.. try (FileManagerWrapper wrapper = new FileManagerWrapper(file)) { CompilationTask task = getCompilationTask(wrapper, options); - boolean result = task.call(); + boolean result = PrivilegedOperations.compilationTaskCall(task); if (result == false) { throw new RuntimeException("Could not compile " + className + " with source code " + sourceCode); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/jar/JarUtils.java b/test/framework/src/main/java/org/elasticsearch/test/jar/JarUtils.java index 98ccd0c16e888..0da392cb7fb01 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/jar/JarUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/test/jar/JarUtils.java @@ -9,6 +9,8 @@ package org.elasticsearch.test.jar; +import org.elasticsearch.test.PrivilegedOperations.ClosableURLClassLoader; + import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.OutputStream; @@ -99,10 +101,10 @@ public static void createJarWithEntriesUTF(Path jarfile, Map<String, String> ent * @param path Path to the jar file to load * @return A URLClassLoader that will load classes from the jar. It should be closed when no longer needed. */ - public static URLClassLoader loadJar(Path path) { + public static ClosableURLClassLoader loadJar(Path path) { try { URL[] urls = new URL[] { path.toUri().toURL() }; - return URLClassLoader.newInstance(urls, JarUtils.class.getClassLoader()); + return new ClosableURLClassLoader(URLClassLoader.newInstance(urls, JarUtils.class.getClassLoader())); } catch (MalformedURLException e) { throw new RuntimeException(e); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureAnnotationEngine.java b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureAnnotationEngine.java new file mode 100644 index 0000000000000..bc226b78ccdf2 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureAnnotationEngine.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V.
under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test.mockito; + +import org.mockito.internal.configuration.InjectingAnnotationEngine; +import org.mockito.plugins.AnnotationEngine; + +import static org.elasticsearch.test.mockito.SecureMockUtil.wrap; + +public class SecureAnnotationEngine implements AnnotationEngine { + private final AnnotationEngine delegate; + + public SecureAnnotationEngine() { + delegate = wrap(InjectingAnnotationEngine::new); + } + + @Override + public AutoCloseable process(Class<?> clazz, Object testInstance) { + return wrap(() -> delegate.process(clazz, testInstance)); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockMaker.java b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockMaker.java new file mode 100644 index 0000000000000..a56a09a4ef08c --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockMaker.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test.mockito; + +import org.mockito.MockedConstruction; +import org.mockito.internal.creation.bytebuddy.SubclassByteBuddyMockMaker; +import org.mockito.internal.util.reflection.LenientCopyTool; +import org.mockito.invocation.MockHandler; +import org.mockito.mock.MockCreationSettings; +import org.mockito.plugins.MockMaker; + +import java.util.Optional; +import java.util.function.Function; + +import static org.elasticsearch.test.mockito.SecureMockUtil.wrap; + +/** + * A {@link MockMaker} that works with {@link SecurityManager}.
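+ * <p>Mockito discovers this class via the {@code mockito-extensions/org.mockito.plugins.MockMaker} service + * file added later in this change; every delegated call is funnelled through {@code SecureMockUtil.wrap}, + * which runs it under mockito's own protection domain so the code under test needs no extra grants. A rough + * sketch of that wrapper (see SecureMockUtil below for the actual code): + * <pre>{@code + * static <T> T wrap(Supplier<T> call) { + * return AccessController.doPrivileged((PrivilegedAction<T>) call::get, context); + * } + * }</pre>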
+ */ +public class SecureMockMaker implements MockMaker { + + // delegates to initializing util, which we don't want to have public + public static void init() { + SecureMockUtil.init(); + } + + // TODO: consider using InlineByteBuddyMockMaker, but this requires using a java agent for instrumentation + private final SubclassByteBuddyMockMaker delegate; + + public SecureMockMaker() { + delegate = wrap(SubclassByteBuddyMockMaker::new); + } + + @SuppressWarnings("rawtypes") + @Override + public <T> T createMock(MockCreationSettings<T> mockCreationSettings, MockHandler mockHandler) { + return wrap(() -> delegate.createMock(mockCreationSettings, mockHandler)); + } + + @SuppressWarnings("rawtypes") + @Override + public <T> Optional<T> createSpy(MockCreationSettings<T> settings, MockHandler handler, T object) { + // spies are not implemented by the bytebuddy delegate implementation + return wrap(() -> { + T instance = delegate.createMock(settings, handler); + new LenientCopyTool().copyToMock(object, instance); + return Optional.of(instance); + }); + } + + @SuppressWarnings("rawtypes") + @Override + public MockHandler getHandler(Object o) { + return delegate.getHandler(o); + } + + @SuppressWarnings("rawtypes") + @Override + public void resetMock(Object o, MockHandler mockHandler, MockCreationSettings mockCreationSettings) { + wrap(() -> { + delegate.resetMock(o, mockHandler, mockCreationSettings); + return (Void) null; + }); + } + + @Override + public TypeMockability isTypeMockable(Class<?> type) { + return delegate.isTypeMockable(type); + } + + @SuppressWarnings("rawtypes") + @Override + public <T> StaticMockControl<T> createStaticMock(Class<T> type, MockCreationSettings<T> settings, MockHandler handler) { + return delegate.createStaticMock(type, settings, handler); + } + + @Override + public <T> ConstructionMockControl<T> createConstructionMock( + Class<T> type, + Function<MockedConstruction.Context, MockCreationSettings<T>> settingsFactory, + Function<MockedConstruction.Context, MockHandler<T>> handlerFactory, + MockedConstruction.MockInitializer<T> mockInitializer + ) { + return delegate.createConstructionMock(type, settingsFactory, handlerFactory, mockInitializer); + } + + @Override + public void clearAllCaches() { + delegate.clearAllCaches(); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockUtil.java b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockUtil.java new file mode 100644 index 0000000000000..dd97c2b30dc3b --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureMockUtil.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.test.mockito; + +import org.mockito.plugins.MockMaker; + +import java.security.AccessControlContext; +import java.security.AccessController; +import java.security.DomainCombiner; +import java.security.PrivilegedAction; +import java.security.ProtectionDomain; +import java.util.function.Supplier; + +class SecureMockUtil { + + // we use the protection domain of mockito for wrapped calls so that + // Elasticsearch server jar does not need additional permissions + private static final AccessControlContext context = getContext(); + + private static AccessControlContext getContext() { + ProtectionDomain[] pda = new ProtectionDomain[] { wrap(MockMaker.class::getProtectionDomain) }; + DomainCombiner combiner = (current, assigned) -> pda; + AccessControlContext acc = new AccessControlContext(AccessController.getContext(), combiner); + // getContext must be called with the new acc so that a combined context will be created + return AccessController.doPrivileged((PrivilegedAction<AccessControlContext>) AccessController::getContext, acc); + } + + // forces static init to run + public static void init() {} + + // wrap the given call to play nice with SecurityManager + static <T> T wrap(Supplier<T> call) { + return AccessController.doPrivileged((PrivilegedAction<T>) call::get, context); + } + + // no construction + private SecureMockUtil() {} +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiator.java b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiator.java new file mode 100644 index 0000000000000..4c3ef53ccce47 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiator.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test.mockito; + +import org.mockito.creation.instance.Instantiator; + +/** + * A wrapper for instantiating objects reflectively, but plays nice with SecurityManager. + */ +class SecureObjectInstantiator implements Instantiator { + private final Instantiator delegate; + + SecureObjectInstantiator(Instantiator delegate) { + this.delegate = delegate; + } + + @Override + public <T> T newInstance(Class<T> cls) { + return SecureMockUtil.wrap(() -> delegate.newInstance(cls)); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiatorProvider.java b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiatorProvider.java new file mode 100644 index 0000000000000..fb2798cdaf392 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/mockito/SecureObjectInstantiatorProvider.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test.mockito; + +import org.mockito.creation.instance.Instantiator; +import org.mockito.internal.creation.instance.DefaultInstantiatorProvider; +import org.mockito.mock.MockCreationSettings; +import org.mockito.plugins.InstantiatorProvider2; + +/** + * A wrapper around the default provider which itself just wraps + * {@link Instantiator} instances to play nice with {@link SecurityManager}. + */ +public class SecureObjectInstantiatorProvider implements InstantiatorProvider2 { + private final DefaultInstantiatorProvider delegate; + + public SecureObjectInstantiatorProvider() { + delegate = new DefaultInstantiatorProvider(); + } + + @Override + public Instantiator getInstantiator(MockCreationSettings<?> settings) { + return new SecureObjectInstantiator(delegate.getInstantiator(settings)); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java index ef77154c76b5d..7b8709c908357 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java @@ -2057,16 +2057,12 @@ protected Map<String, Object> getIndexMappingAsMap(String index) throws IOExcept } protected static boolean indexExists(String index) throws IOException { - return indexExists(client(), index); - } - - protected static boolean indexExists(RestClient client, String index) throws IOException { // We use the /_cluster/health/{index} API to ensure the index exists on the master node - which means all nodes see the index. Request request = new Request("GET", "/_cluster/health/" + index); request.addParameter("timeout", "0"); request.addParameter("level", "indices"); try { - final var response = client.performRequest(request); + final var response = client().performRequest(request); @SuppressWarnings("unchecked") final var indices = (Map<String, Object>) entityAsMap(response).get("indices"); return indices.containsKey(index); @@ -2123,19 +2119,12 @@ protected static boolean aliasExists(String index, String alias) throws IOExcept return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode(); } - /** - * Returns a list of the data stream's backing index names. - */ - protected static List<String> getDataStreamBackingIndexNames(String dataStreamName) throws IOException { - return getDataStreamBackingIndexNames(client(), dataStreamName); - } - /** * Returns a list of the data stream's backing index names.
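* <p>For a data stream named {@code logs-foo}, the returned names follow the backing-index convention {@code .ds-logs-foo-<yyyy.MM.dd>-NNNNNN}, in generation order (the date part depends on when each backing index was created).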
*/ @SuppressWarnings("unchecked") - protected static List<String> getDataStreamBackingIndexNames(RestClient client, String dataStreamName) throws IOException { - Map<String, Object> response = getAsMap(client, "/_data_stream/" + dataStreamName); + protected static List<String> getDataStreamBackingIndexNames(String dataStreamName) throws IOException { + Map<String, Object> response = getAsMap(client(), "/_data_stream/" + dataStreamName); List<Object> dataStreams = (List<Object>) response.get("data_streams"); assertThat(dataStreams.size(), equalTo(1)); Map<String, Object> dataStream = (Map<String, Object>) dataStreams.getFirst(); diff --git a/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.AnnotationEngine b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.AnnotationEngine new file mode 100644 index 0000000000000..be695fd30f509 --- /dev/null +++ b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.AnnotationEngine @@ -0,0 +1 @@ +org.elasticsearch.test.mockito.SecureAnnotationEngine diff --git a/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.InstantiatorProvider2 b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.InstantiatorProvider2 new file mode 100644 index 0000000000000..ba41d233143f9 --- /dev/null +++ b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.InstantiatorProvider2 @@ -0,0 +1 @@ +org.elasticsearch.test.mockito.SecureObjectInstantiatorProvider diff --git a/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.MockMaker b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 0000000000000..e19b9d550f81b --- /dev/null +++ b/test/framework/src/main/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +org.elasticsearch.test.mockito.SecureMockMaker diff --git a/x-pack/plugin/analytics/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/analytics/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/async-search/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/async-search/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/async/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/async/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AbstractCCRRestTestCase.java b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AbstractCCRRestTestCase.java index 7360953a135f0..81f84c6b0cf0a 100644 --- a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AbstractCCRRestTestCase.java +++ b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AbstractCCRRestTestCase.java @@ -16,11 +16,13 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; -import org.elasticsearch.cluster.metadata.DataStreamTestHelper; +import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.rest.RestStatus; import
org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xcontent.ToXContent; @@ -29,6 +31,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashSet; @@ -43,6 +46,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; @TestCaseOrdering(AbstractCCRRestTestCase.TargetClusterTestOrdering.class) public abstract class AbstractCCRRestTestCase extends ESRestTestCase { @@ -350,16 +354,33 @@ protected Set<CcrNodeTask> getCcrNodeTasks() throws IOException { protected record CcrNodeTask(String remoteCluster, String leaderIndex, String followerIndex, int shardId) {} - /** - * Verify that the specified data stream has the expected backing index generations. - */ - protected static List<String> verifyDataStream(final RestClient client, final String name, final int... expectedBackingIndices) + protected static boolean indexExists(String index) throws IOException { + Response response = adminClient().performRequest(new Request("HEAD", "/" + index)); + return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode(); + } + + protected static List<String> verifyDataStream(final RestClient client, final String name, final String... expectedBackingIndices) throws IOException { - final List<String> actualBackingIndices = getDataStreamBackingIndexNames(client, name); + Request request = new Request("GET", "/_data_stream/" + name); + Map<String, Object> response = toMap(client.performRequest(request)); + List<?> retrievedDataStreams = (List<?>) response.get("data_streams"); + assertThat(retrievedDataStreams, hasSize(1)); + List<?> actualBackingIndexItems = (List<?>) ((Map<?, ?>) retrievedDataStreams.get(0)).get("indices"); + assertThat(actualBackingIndexItems, hasSize(expectedBackingIndices.length)); + final List<String> actualBackingIndices = new ArrayList<>(); for (int i = 0; i < expectedBackingIndices.length; i++) { - String actualBackingIndex = actualBackingIndices.get(i); - int expectedBackingIndexGeneration = expectedBackingIndices[i]; - assertThat(actualBackingIndex, DataStreamTestHelper.backingIndexEqualTo(name, expectedBackingIndexGeneration)); + Map<?, ?> actualBackingIndexItem = (Map<?, ?>) actualBackingIndexItems.get(i); + String actualBackingIndex = (String) actualBackingIndexItem.get("index_name"); + String expectedBackingIndex = expectedBackingIndices[i]; + + String actualDataStreamName = actualBackingIndex.substring(5, actualBackingIndex.indexOf('-', 5)); + String expectedDataStreamName = expectedBackingIndex.substring(5, expectedBackingIndex.indexOf('-', 5)); + assertThat(actualDataStreamName, equalTo(expectedDataStreamName)); + + int actualGeneration = Integer.parseInt(actualBackingIndex.substring(actualBackingIndex.lastIndexOf('-'))); + int expectedGeneration = Integer.parseInt(expectedBackingIndex.substring(expectedBackingIndex.lastIndexOf('-'))); + assertThat(actualGeneration, equalTo(expectedGeneration)); + actualBackingIndices.add(actualBackingIndex); } return List.copyOf(actualBackingIndices); } @@ -387,6 +408,17 @@ protected static void createAutoFollowPattern( assertOK(client.performRequest(request)); } + /** + * Fixes the point in time at which a data stream backing index is first queried. + * This is required to avoid failures when running tests at midnight.
+ * (index is created for day0, but assertions are executed for day1 assuming different time based index name that does not exist) + */ + private final LazyInitializable time = new LazyInitializable<>(System::currentTimeMillis); + + protected String backingIndexName(String dataStreamName, int generation) { + return DataStream.getDefaultBackingIndexName(dataStreamName, generation, time.getOrCompute()); + } + protected RestClient buildLeaderClient() throws IOException { assert targetCluster != TargetCluster.LEADER; return buildClient(getLeaderCluster().getHttpAddresses()); diff --git a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AutoFollowIT.java b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AutoFollowIT.java index 5679bbce59fd8..533a77c84e22d 100644 --- a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AutoFollowIT.java +++ b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/AutoFollowIT.java @@ -18,7 +18,7 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; import org.elasticsearch.client.WarningFailureException; -import org.elasticsearch.cluster.metadata.DataStreamTestHelper; +import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; @@ -250,7 +250,7 @@ public void testAutoFollowPatterns() throws Exception { } else { assertThat(getIndexSettingsAsMap("metrics-20210101"), hasEntry("index.number_of_replicas", "1")); } - assertThat(indexExists(adminClient(), excludedIndex), is(false)); + assertThat(indexExists(excludedIndex), is(false)); }); assertLongBusy(() -> verifyCcrMonitoring("metrics-20210101", "metrics-20210101")); @@ -324,12 +324,12 @@ public void testDataStreams() throws Exception { indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, numDocs); } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamName, 1); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs); }); @@ -338,7 +338,7 @@ public void testDataStreams() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -348,7 +348,7 @@ public void testDataStreams() throws Exception { } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 2)); - verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); 
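+ // backingIndexName(name, n) delegates to DataStream.getDefaultBackingIndexName, yielding names of the form + // .ds-<name>-<yyyy.MM.dd>-000001, .ds-<name>-<yyyy.MM.dd>-000002, ... for the fixed creation time captured above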
ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs + 1); }); @@ -357,7 +357,13 @@ public void testDataStreams() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2, 3); + verifyDataStream( + leaderClient, + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -367,14 +373,28 @@ public void testDataStreams() throws Exception { } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 3)); - verifyDataStream(client(), dataStreamName, 1, 2, 3); + verifyDataStream( + client(), + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs + 2); }); } finally { - cleanUpFollower(List.of(), List.of(dataStreamName), List.of(autoFollowPatternName)); - cleanUpLeader(List.of(), List.of(dataStreamName), List.of()); + cleanUpFollower( + List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2), backingIndexName(dataStreamName, 3)), + List.of(dataStreamName), + List.of(autoFollowPatternName) + ); + cleanUpLeader( + List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2), backingIndexName(dataStreamName, 3)), + List.of(dataStreamName), + List.of() + ); } } @@ -401,13 +421,19 @@ public void testDataStreamsRenameFollowDataStream() throws Exception { indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, numDocs); } - logger.info("--> checking {} has been auto followed to {}", dataStreamName, dataStreamNameFollower); + logger.info( + "--> checking {} with index {} has been auto followed to {} with backing index {}", + dataStreamName, + backingIndexName(dataStreamName, 1), + dataStreamNameFollower, + backingIndexName(dataStreamNameFollower, 1) + ); assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamNameFollower, 1); + verifyDataStream(client(), dataStreamNameFollower, backingIndexName(dataStreamNameFollower, 1)); ensureYellow(dataStreamNameFollower); verifyDocuments(client(), dataStreamNameFollower, numDocs); }); @@ -417,7 +443,7 @@ public void testDataStreamsRenameFollowDataStream() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ 
-427,7 +453,12 @@ public void testDataStreamsRenameFollowDataStream() throws Exception { } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 2)); - verifyDataStream(client(), dataStreamNameFollower, 1, 2); + verifyDataStream( + client(), + dataStreamNameFollower, + backingIndexName(dataStreamNameFollower, 1), + backingIndexName(dataStreamNameFollower, 2) + ); ensureYellow(dataStreamNameFollower); verifyDocuments(client(), dataStreamNameFollower, numDocs + 1); }); @@ -437,7 +468,13 @@ public void testDataStreamsRenameFollowDataStream() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2, 3); + verifyDataStream( + leaderClient, + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -447,14 +484,32 @@ public void testDataStreamsRenameFollowDataStream() throws Exception { } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 3)); - verifyDataStream(client(), dataStreamNameFollower, 1, 2, 3); + verifyDataStream( + client(), + dataStreamNameFollower, + backingIndexName(dataStreamNameFollower, 1), + backingIndexName(dataStreamNameFollower, 2), + backingIndexName(dataStreamNameFollower, 3) + ); ensureYellow(dataStreamNameFollower); verifyDocuments(client(), dataStreamNameFollower, numDocs + 2); }); } finally { - cleanUpFollower(List.of(), List.of(dataStreamNameFollower), List.of(autoFollowPatternName)); - cleanUpLeader(List.of(), List.of(dataStreamName), List.of()); + cleanUpFollower( + List.of( + backingIndexName(dataStreamNameFollower, 1), + backingIndexName(dataStreamNameFollower, 2), + backingIndexName(dataStreamNameFollower, 3) + ), + List.of(dataStreamNameFollower), + List.of(autoFollowPatternName) + ); + cleanUpLeader( + List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2), backingIndexName(dataStreamName, 3)), + List.of(dataStreamName), + List.of() + ); } } @@ -486,7 +541,7 @@ public void testDataStreams_autoFollowAfterDataStreamCreated() throws Exception indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, initialNumDocs); } @@ -494,11 +549,10 @@ public void testDataStreams_autoFollowAfterDataStreamCreated() throws Exception createAutoFollowPattern(client(), autoFollowPatternName, dataStreamName + "*", "leader_cluster", null); // Rollover and ensure only second backing index is replicated: - final List backingIndexNames; try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - backingIndexNames = verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); 
Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -508,23 +562,31 @@ public void testDataStreams_autoFollowAfterDataStreamCreated() throws Exception } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamName, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 2)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, 1); }); // Explicitly follow the first backing index and check that the data stream in follow cluster is updated correctly: - followIndex(backingIndexNames.getFirst(), backingIndexNames.getFirst()); + followIndex(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 1)); assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, initialNumDocs + 1); }); } finally { - cleanUpFollower(List.of(), List.of(dataStreamName), List.of(autoFollowPatternName)); - cleanUpLeader(List.of(), List.of(dataStreamName), List.of()); + cleanUpFollower( + List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)), + List.of(dataStreamName), + List.of(autoFollowPatternName) + ); + cleanUpLeader( + List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)), + List.of(dataStreamName), + List.of() + ); } } @@ -552,7 +614,7 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, initialNumDocs); } @@ -560,11 +622,10 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { createAutoFollowPattern(client(), autoFollowPatternName, dataStreamName + "*", "leader_cluster", null); // Rollover and ensure only second backing index is replicated: - final List backingIndexNames; try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); Request indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -574,17 +635,22 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { assertOK(leaderClient.performRequest(rolloverRequest)); assertOK(leaderClient.performRequest(indexRequest)); - backingIndexNames = verifyDataStream(leaderClient, dataStreamName, 1, 2, 3); + verifyDataStream( + leaderClient, + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); } - awaitIndexExists(backingIndexNames.get(1)); - awaitIndexExists(backingIndexNames.get(2)); + assertBusy(() -> 
assertThat(indexExists(backingIndexName(dataStreamName, 2)), is(true))); + assertBusy(() -> assertThat(indexExists(backingIndexName(dataStreamName, 3)), is(true))); // Replace a backing index in the follower data stream with one that has a prefix (simulating a shrink) - final String secondBackingIndex = backingIndexNames.get(1); - String shrunkIndexName = SHRUNKEN_INDEX_PREFIX + secondBackingIndex; + String shrunkIndexName = SHRUNKEN_INDEX_PREFIX + DataStream.getDefaultBackingIndexName(dataStreamName, 2); Request indexRequest = new Request("POST", "/" + shrunkIndexName + "/_doc"); indexRequest.addParameter("refresh", "true"); indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); @@ -600,7 +666,7 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { + dataStreamName + "\",\n" + " \"index\": \"" - + secondBackingIndex + + DataStream.getDefaultBackingIndexName(dataStreamName, 2) + "\"\n" + " }\n" + " },\n" @@ -623,7 +689,14 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2, 3, 4); + verifyDataStream( + leaderClient, + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3), + backingIndexName(dataStreamName, 4) + ); indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -632,11 +705,14 @@ public void testDataStreamsBackingIndicesOrdering() throws Exception { } assertBusy(() -> { - List<String> actualBackingIndexItems = getDataStreamBackingIndexNames(dataStreamName); + Request request = new Request("GET", "/_data_stream/" + dataStreamName); + Map<String, Object> response = toMap(client().performRequest(request)); + List<?> retrievedDataStreams = (List<?>) response.get("data_streams"); + List<?> actualBackingIndexItems = (List<?>) ((Map<?, ?>) retrievedDataStreams.get(0)).get("indices"); assertThat(actualBackingIndexItems.size(), is(3)); - String writeIndex = actualBackingIndexItems.get(2); - assertThat(writeIndex, not(shrunkIndexName)); - assertThat(writeIndex, DataStreamTestHelper.backingIndexEqualTo(dataStreamName, 4)); + Map<?, ?> writeIndexMap = (Map<?, ?>) actualBackingIndexItems.get(2); + assertThat(writeIndexMap.get("index_name"), not(shrunkIndexName)); + assertThat(writeIndexMap.get("index_name"), is(backingIndexName(dataStreamName, 4))); assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 3)); }); } @@ -664,12 +740,12 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, numDocs); } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamName, 1); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs); }); @@ -678,7
+754,7 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { try (var leaderClient = buildLeaderClient()) { var rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); var indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -688,7 +764,7 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 2)); - verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs + 1); }); @@ -703,13 +779,12 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { "data stream [" + dataStreamName + "] cannot be rolled over, " + "because it is a replicated data stream" ) ); - backingIndexNames = verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); // Unfollow .ds-logs-tomcat-prod-000001 - final String writeIndex = backingIndexNames.getFirst(); - pauseFollow(writeIndex); - closeIndex(writeIndex); - unfollow(writeIndex); + pauseFollow(backingIndexName(dataStreamName, 1)); + closeIndex(backingIndexName(dataStreamName, 1)); + unfollow(backingIndexName(dataStreamName, 1)); // Try again var rolloverRequest2 = new Request("POST", "/" + dataStreamName + "/_rollover"); @@ -720,7 +795,7 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { "data stream [" + dataStreamName + "] cannot be rolled over, " + "because it is a replicated data stream" ) ); - verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); // Promote local data stream var promoteRequest = new Request("POST", "/_data_stream/_promote/" + dataStreamName); @@ -729,7 +804,13 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { // Try again and now the rollover should be successful because local data stream is no longer a replicated data stream: var rolloverRequest3 = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(client().performRequest(rolloverRequest3)); - backingIndexNames = verifyDataStream(client(), dataStreamName, 1, 2, 3); + backingIndexNames = verifyDataStream( + client(), + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); // TODO: verify that following a backing index for logs-tomcat-prod data stream in remote cluster fails, // because local data stream isn't a replicated data stream anymore.
@@ -745,8 +826,17 @@ public void testRolloverDataStreamInFollowClusterForbidden() throws Exception { } } finally { - cleanUpFollower(List.of(), List.of(dataStreamName), List.of(autoFollowPatternName)); - cleanUpLeader(List.of(), List.of(dataStreamName), List.of()); + if (backingIndexNames == null) { + // the test failed before the actual backing index names could be computed, so guess them on a + // best-effort basis + backingIndexNames = List.of( + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2), + backingIndexName(dataStreamName, 3) + ); + } + cleanUpFollower(backingIndexNames, List.of(dataStreamName), List.of(autoFollowPatternName)); + cleanUpLeader(backingIndexNames.subList(0, 2), List.of(dataStreamName), List.of()); } } @@ -910,12 +1000,12 @@ public void testDataStreamsBiDirectionalReplication() throws Exception { indexRequest.setJsonEntity("{\"@timestamp\": \"" + DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, leaderDataStreamName, 1); + verifyDataStream(leaderClient, leaderDataStreamName, backingIndexName(leaderDataStreamName, 1)); verifyDocuments(leaderClient, leaderDataStreamName, numDocs); } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndicesInFollowCluster + 1)); - verifyDataStream(client(), leaderDataStreamName, 1); + verifyDataStream(client(), leaderDataStreamName, backingIndexName(leaderDataStreamName, 1)); ensureYellow(leaderDataStreamName); verifyDocuments(client(), leaderDataStreamName, numDocs); }); @@ -954,7 +1044,7 @@ public void testDataStreamsBiDirectionalReplication() throws Exception { getNumberOfSuccessfulFollowedIndices(leaderClient), equalTo(initialNumberOfSuccessfulFollowedIndicesInLeaderCluster + 1) ); - verifyDataStream(leaderClient, followerDataStreamName, 1); + verifyDataStream(leaderClient, followerDataStreamName, backingIndexName(followerDataStreamName, 1)); ensureYellow(followerDataStreamName); verifyDocuments(leaderClient, followerDataStreamName, numDocs); }); @@ -1006,8 +1096,16 @@ public void testDataStreamsBiDirectionalReplication() throws Exception { verifyDocuments(leaderClient, aliasName, (numDocs + moreDocs) * 2); } } finally { - cleanUpFollower(List.of(), List.of(followerDataStreamName, leaderDataStreamName), List.of("id1")); - cleanUpLeader(List.of(), List.of(leaderDataStreamName, followerDataStreamName), List.of("id2")); + cleanUpFollower( + List.of(backingIndexName(followerDataStreamName, 1), backingIndexName(leaderDataStreamName, 1)), + List.of(followerDataStreamName, leaderDataStreamName), + List.of("id1") + ); + cleanUpLeader( + List.of(backingIndexName(leaderDataStreamName, 1), backingIndexName(followerDataStreamName, 1)), + List.of(leaderDataStreamName, followerDataStreamName), + List.of("id2") + ); } } @@ -1118,6 +1216,7 @@ private void testDataStreamPromotionWarnings(Boolean createFollowerTemplate) thr final String autoFollowPatternName = getTestName().toLowerCase(Locale.ROOT); int initialNumberOfSuccessfulFollowedIndices = getNumberOfSuccessfulFollowedIndices(); + List<String> backingIndexNames = null; try { // Create index template Request putComposableIndexTemplateRequest = new Request("POST", "/_index_template/" + getTestName().toLowerCase(Locale.ROOT)); @@ -1140,12 +1239,12 @@ private void testDataStreamPromotionWarnings(Boolean createFollowerTemplate) thr indexRequest.setJsonEntity("{\"@timestamp\": \"" +
DATE_FORMAT.format(new Date()) + "\",\"message\":\"abc\"}"); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, numDocs); } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 1)); - verifyDataStream(client(), dataStreamName, 1); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs); }); @@ -1154,7 +1253,7 @@ private void testDataStreamPromotionWarnings(Boolean createFollowerTemplate) thr try (var leaderClient = buildLeaderClient()) { var rolloverRequest = new Request("POST", "/" + dataStreamName + "/_rollover"); assertOK(leaderClient.performRequest(rolloverRequest)); - verifyDataStream(leaderClient, dataStreamName, 1, 2); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); var indexRequest = new Request("POST", "/" + dataStreamName + "/_doc"); indexRequest.addParameter("refresh", "true"); @@ -1164,22 +1263,33 @@ private void testDataStreamPromotionWarnings(Boolean createFollowerTemplate) thr } assertBusy(() -> { assertThat(getNumberOfSuccessfulFollowedIndices(), equalTo(initialNumberOfSuccessfulFollowedIndices + 2)); - verifyDataStream(client(), dataStreamName, 1, 2); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs + 1); }); - verifyDataStream(client(), dataStreamName, 1, 2); + backingIndexNames = verifyDataStream( + client(), + dataStreamName, + backingIndexName(dataStreamName, 1), + backingIndexName(dataStreamName, 2) + ); // Promote local data stream var promoteRequest = new Request("POST", "/_data_stream/_promote/" + dataStreamName); Response response = client().performRequest(promoteRequest); assertOK(response); } finally { + if (backingIndexNames == null) { + // the test failed before the actual backing index names could be computed, so guess them on a + // best-effort basis + backingIndexNames = List.of(backingIndexName(dataStreamName, 1), backingIndexName(dataStreamName, 2)); + } + // These cleanup methods are copied from the finally block of other Data Stream tests in this class; they may // no longer be required but have been included for completeness - cleanUpFollower(List.of(), List.of(dataStreamName), List.of(autoFollowPatternName)); - cleanUpLeader(List.of(), List.of(dataStreamName), List.of()); + cleanUpFollower(backingIndexNames, List.of(dataStreamName), List.of(autoFollowPatternName)); + cleanUpLeader(backingIndexNames.subList(0, 1), List.of(dataStreamName), List.of()); Request deleteTemplateRequest = new Request("DELETE", "/_index_template/" + getTestName().toLowerCase(Locale.ROOT)); if (createFollowerTemplate) { assertOK(client().performRequest(deleteTemplateRequest)); diff --git a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/DowngradeLicenseFollowIndexIT.java b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/DowngradeLicenseFollowIndexIT.java index 76dfdccb60106..930f546e4f681 100644 --- a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/DowngradeLicenseFollowIndexIT.java +++
b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/DowngradeLicenseFollowIndexIT.java @@ -127,7 +127,7 @@ public void testDowngradeRemoteClusterToBasic() throws Exception { // Index2 was created in leader cluster after the downgrade and therefore the auto follow coordinator in // follow cluster should not pick that index up: - assertThat(indexExists(adminClient(), index2), is(false)); + assertThat(indexExists(index2), is(false)); // parse the logs and ensure that the auto-coordinator skipped coordination on the leader cluster assertBusy(() -> { diff --git a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexIT.java b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexIT.java index 46c4d64d5b9ad..1f89d316a4e3d 100644 --- a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexIT.java +++ b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; +import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; @@ -230,7 +231,7 @@ public void testFollowDataStreamFails() throws Exception { try (RestClient leaderClient = buildLeaderClient()) { Request request = new Request("PUT", "/_data_stream/" + dataStreamName); assertOK(leaderClient.performRequest(request)); - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, DataStream.getDefaultBackingIndexName("logs-syslog-prod", 1)); } ResponseException failure = expectThrows(ResponseException.class, () -> followIndex(dataStreamName, dataStreamName)); diff --git a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexSecurityIT.java b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexSecurityIT.java index 13e31785f430b..0adf0b31b4ebd 100644 --- a/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexSecurityIT.java +++ b/x-pack/plugin/ccr/src/javaRestTest/java/org/elasticsearch/xpack/ccr/FollowIndexSecurityIT.java @@ -161,7 +161,7 @@ public void testFollowIndex() throws Exception { e = expectThrows(ResponseException.class, () -> followIndex(client(), "leader_cluster", unallowedIndex, unallowedIndex)); assertThat(e.getMessage(), containsString("action [indices:admin/xpack/ccr/put_follow] is unauthorized for user [test_ccr]")); // Verify that the follow index has not been created and no node tasks are running - assertThat(indexExists(adminClient(), unallowedIndex), is(false)); + assertThat(indexExists(unallowedIndex), is(false)); assertBusy(() -> assertThat(getCcrNodeTasks(), empty())); // User does have manage_follow_index index privilege on 'allowed' index, @@ -176,7 +176,7 @@ public void testFollowIndex() throws Exception { ) ); // Verify that the follow index has not been created and no node tasks are running - assertThat(indexExists(adminClient(), unallowedIndex), is(false)); + assertThat(indexExists(unallowedIndex), is(false)); assertBusy(() -> assertThat(getCcrNodeTasks(), empty())); followIndex(adminClient(), "leader_cluster", unallowedIndex, unallowedIndex); @@ -242,7 +242,7 @@ public void testAutoFollowPatterns() throws Exception { try { assertBusy(() -> ensureYellow(allowedIndex), 30,
TimeUnit.SECONDS); assertBusy(() -> verifyDocuments(allowedIndex, 5, "*:*"), 30, TimeUnit.SECONDS); - assertThat(indexExists(adminClient(), disallowedIndex), is(false)); + assertThat(indexExists(disallowedIndex), is(false)); withMonitoring(logger, () -> { assertBusy(() -> verifyCcrMonitoring(allowedIndex, allowedIndex), 120L, TimeUnit.SECONDS); assertBusy(AbstractCCRRestTestCase::verifyAutoFollowMonitoring, 120L, TimeUnit.SECONDS); @@ -350,11 +350,11 @@ public void testUnPromoteAndFollowDataStream() throws Exception { """, dateFormat.format(new Date()))); assertOK(leaderClient.performRequest(indexRequest)); } - verifyDataStream(leaderClient, dataStreamName, 1); + verifyDataStream(leaderClient, dataStreamName, backingIndexName(dataStreamName, 1)); verifyDocuments(leaderClient, dataStreamName, numDocs); } assertBusy(() -> { - verifyDataStream(client(), dataStreamName, 1); + verifyDataStream(client(), dataStreamName, backingIndexName(dataStreamName, 1)); ensureYellow(dataStreamName); verifyDocuments(client(), dataStreamName, numDocs); }); @@ -366,10 +366,9 @@ public void testUnPromoteAndFollowDataStream() throws Exception { assertOK(client().performRequest(new Request("POST", "/" + dataStreamName + "/_rollover"))); // Unfollow .ds-logs-eu-monitor1-000001, // which is now possible because this index can now be closed as it is no longer the write index. - final String writeIndex = getDataStreamBackingIndexNames(dataStreamName).getFirst(); - pauseFollow(writeIndex); - closeIndex(writeIndex); - unfollow(writeIndex); + pauseFollow(backingIndexName(dataStreamName, 1)); + closeIndex(backingIndexName(dataStreamName, 1)); + unfollow(backingIndexName(dataStreamName, 1)); } } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java index 18e125a7ae1ce..8b80935ca4df5 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java @@ -38,7 +38,6 @@ import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.engine.EngineConfig; import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesRequestCache; import org.elasticsearch.indices.IndicesService; @@ -530,8 +529,7 @@ static String[] extractLeaderShardHistoryUUIDs(Map ccrIndexMetad EngineConfig.INDEX_CODEC_SETTING, DataTier.TIER_PREFERENCE_SETTING, IndexSettings.BLOOM_FILTER_ID_FIELD_ENABLED_SETTING, - MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING, - DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC + MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING ); public static Settings filter(Settings originalSettings) { diff --git a/x-pack/plugin/ccr/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/ccr/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..16701ab74d8c9 --- /dev/null +++ b/x-pack/plugin/ccr/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git 
a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistryRolloverIT.java b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistryRolloverIT.java index 921749ef426a6..04f9c540f949c 100644 --- a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistryRolloverIT.java +++ b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistryRolloverIT.java @@ -17,7 +17,6 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.DataStream; -import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.project.TestProjectResolvers; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.datastreams.DataStreamsPlugin; @@ -69,8 +68,7 @@ public void setup() { public void testRollover() throws Exception { ClusterState state = clusterService.state(); registry.clusterChanged(new ClusterChangedEvent(IndexTemplateRegistryRolloverIT.class.getName(), state, state)); - final var projectId = ProjectId.DEFAULT; - awaitClusterState(s -> s.metadata().getProject(projectId).templatesV2().containsKey(TEST_INDEX_TEMPLATE_ID)); + assertBusy(() -> { assertTrue(clusterService.state().metadata().getProject().templatesV2().containsKey(TEST_INDEX_TEMPLATE_ID)); }); String dsName = TEST_INDEX_PATTERN.replace('*', '1'); CreateDataStreamAction.Request createDataStreamRequest = new CreateDataStreamAction.Request( TEST_REQUEST_TIMEOUT, @@ -82,7 +80,7 @@ public void testRollover() throws Exception { assertNumberOfBackingIndices(1); registry.incrementVersion(); registry.clusterChanged(new ClusterChangedEvent(IndexTemplateRegistryRolloverIT.class.getName(), clusterService.state(), state)); - awaitClusterState(s -> s.metadata().getProject(projectId).dataStreams().get(dsName).rolloverOnWrite()); + assertBusy(() -> assertTrue(getDataStream().rolloverOnWrite())); assertNumberOfBackingIndices(1); String timestampValue = DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.formatMillis(System.currentTimeMillis()); @@ -91,7 +89,7 @@ public void testRollover() throws Exception { .source(String.format(Locale.ROOT, "{\"%s\":\"%s\"}", DEFAULT_TIMESTAMP_FIELD, timestampValue), XContentType.JSON) ).actionGet(); assertThat(docWriteResponse.status().getStatus(), equalTo(201)); - awaitClusterState(s -> s.metadata().getProject(projectId).dataStreams().get(dsName).getIndices().size() == 2); + assertBusy(() -> assertNumberOfBackingIndices(2)); } private void assertNumberOfBackingIndices(final int expected) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java index cded88c36388c..a7f65c60a06c4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.inference.action; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -16,7 +15,6 @@ import 
org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -43,15 +41,13 @@ public static class Request extends AcknowledgedRequest { private final String inferenceEntityId; private final BytesReference content; private final XContentType contentType; - private final TimeValue timeout; - public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) { + public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.content = content; this.contentType = contentType; - this.timeout = timeout; } public Request(StreamInput in) throws IOException { @@ -60,13 +56,6 @@ public Request(StreamInput in) throws IOException { this.taskType = TaskType.fromStream(in); this.content = in.readBytesReference(); this.contentType = in.readEnum(XContentType.class); - - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) - || in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { - this.timeout = in.readTimeValue(); - } else { - this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT; - } } public TaskType getTaskType() { @@ -85,10 +74,6 @@ public XContentType getContentType() { return contentType; } - public TimeValue getTimeout() { - return timeout; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -96,11 +81,6 @@ public void writeTo(StreamOutput out) throws IOException { taskType.writeTo(out); out.writeBytesReference(content); XContentHelper.writeTo(out, contentType); - - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) - || out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { - out.writeTimeValue(timeout); - } } @Override @@ -125,13 +105,12 @@ public boolean equals(Object o) { return taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(content, request.content) - && contentType == request.contentType - && Objects.equals(timeout, request.timeout); + && contentType == request.contentType; } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout); + return Objects.hash(taskType, inferenceEntityId, content, contentType); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index a15f78a29e269..99404d9ce66b0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.inference.InferenceServiceResults; import 
org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xpack.core.inference.DequeUtils; import java.io.IOException; import java.util.Collections; @@ -24,8 +23,6 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.Flow; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk; import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals; @@ -35,7 +32,9 @@ /** * Chat Completion results that only contain a Flow.Publisher. */ -public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults { +public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) + implements + InferenceServiceResults { public static final String NAME = "chat_completion_chunk"; public static final String MODEL_FIELD = "model"; @@ -58,63 +57,6 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publ public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; public static final String TYPE_FIELD = "type"; - /** - * OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible. - * So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time. - */ - public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) { - Deque<ChatCompletionChunk> buffer = new LinkedBlockingDeque<>(); - AtomicBoolean onComplete = new AtomicBoolean(); - this.publisher = downstream -> { - publisher.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - downstream.onSubscribe(new Flow.Subscription() { - @Override - public void request(long n) { - var nextItem = buffer.poll(); - if (nextItem != null) { - downstream.onNext(new Results(DequeUtils.of(nextItem))); - } else if (onComplete.get()) { - downstream.onComplete(); - } else { - subscription.request(n); - } - } - - @Override - public void cancel() { - subscription.cancel(); - } - }); - } - - @Override - public void onNext(Results item) { - var chunks = item.chunks(); - var firstItem = chunks.poll(); - chunks.forEach(buffer::offer); - downstream.onNext(new Results(DequeUtils.of(firstItem))); - } - - @Override - public void onError(Throwable throwable) { - downstream.onError(throwable); - } - - @Override - public void onComplete() { - // only complete if the buffer is empty, so that the client has a chance to drain the buffer - if (onComplete.compareAndSet(false, true)) { - if (buffer.isEmpty()) { - downstream.onComplete(); - } - } - } - }); - }; - } - @Override public boolean isStreaming() { return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java index 9226a4148900d..609822a4bd20b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java @@ -16,8 +16,6 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.security.authc.AuthenticationFailureHandler; import org.elasticsearch.xpack.core.security.authc.Realm; -import org.elasticsearch.xpack.core.security.authc.service.NodeLocalServiceAccountTokenStore; -import
org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; @@ -116,18 +114,6 @@ default List<BiConsumer<Set<String>, ActionListener<RoleRetrievalResult>>> getRo return Collections.emptyList(); } - /** - * Returns a {@link NodeLocalServiceAccountTokenStore} used to authenticate service account tokens. - * If {@code null} is returned, the default service account token stores will be used. - * - * Providing a custom {@link NodeLocalServiceAccountTokenStore} here overrides the default implementation. - * - * @param components Access to components that can be used to authenticate service account tokens - */ - default ServiceAccountTokenStore getServiceAccountTokenStore(SecurityComponents components) { - return null; - } - /** * Returns an authorization engine for authorizing requests, or null to use the default authorization mechanism. * diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/NodeLocalServiceAccountTokenStore.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/NodeLocalServiceAccountTokenStore.java deleted file mode 100644 index 688d5a98f3972..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/NodeLocalServiceAccountTokenStore.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.security.authc.service; - -import org.elasticsearch.xpack.core.security.action.service.TokenInfo; - -import java.util.List; - -public interface NodeLocalServiceAccountTokenStore extends ServiceAccountTokenStore { - default List<TokenInfo> findNodeLocalTokensFor(ServiceAccount.ServiceAccountId accountId) { - throw new IllegalStateException("Find node local tokens not supported by [" + this.getClass() + "]"); - } -} diff --git a/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.codebases b/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.codebases new file mode 100644 index 0000000000000..6abfadf6f744c --- /dev/null +++ b/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.codebases @@ -0,0 +1 @@ +httpasyncclient: org.apache.http.nio.client.HttpAsyncClient diff --git a/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..753667c37cd95 --- /dev/null +++ b/x-pack/plugin/core/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,21 @@ +grant { + // CommandLineHttpClient + permission java.lang.RuntimePermission "setFactory"; + // bouncy castle + permission java.security.SecurityPermission "putProviderProperty.BC"; + + // needed in (cf.
o.e.x.c.s.s.RestorableContextClassLoader) + permission java.lang.RuntimePermission "getClassLoader"; + permission java.lang.RuntimePermission "setContextClassLoader"; + + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; + + // needed because of problems in unbound LDAP library + permission java.util.PropertyPermission "*", "read,write"; +}; + +grant codeBase "${codebase.httpasyncclient}" { + // rest client uses system properties which gets the default proxy + permission java.net.NetPermission "getProxySelector"; +}; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java index f9f67167a12b1..e0b04c6fe8769 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java @@ -34,25 +34,13 @@ public void setup() throws Exception { public void testValidate() { // valid model ID - var request = new PutInferenceModelAction.Request( - TASK_TYPE, - MODEL_ID + "_-0", - BYTES, - X_CONTENT_TYPE, - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE); ActionRequestValidationException validationException = request.validate(); assertNull(validationException); // invalid model IDs - var invalidRequest = new PutInferenceModelAction.Request( - TASK_TYPE, - "", - BYTES, - X_CONTENT_TYPE, - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE); validationException = invalidRequest.validate(); assertNotNull(validationException); @@ -60,19 +48,12 @@ public void testValidate() { TASK_TYPE, randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS), BYTES, - X_CONTENT_TYPE, - InferenceAction.Request.DEFAULT_TIMEOUT + X_CONTENT_TYPE ); validationException = invalidRequest2.validate(); assertNotNull(validationException); - var invalidRequest3 = new PutInferenceModelAction.Request( - TASK_TYPE, - null, - BYTES, - X_CONTENT_TYPE, - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE); validationException = invalidRequest3.validate(); assertNotNull(validationException); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index ccd67fe4029d9..669ba2f881fe7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -19,17 +19,8 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; -import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; -import static 
org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase< StreamingUnifiedChatCompletionResults.Results> { @@ -207,66 +198,6 @@ public void testToolCallToXContentChunked() throws IOException { assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); } - public void testBufferedPublishing() { - var results = new ArrayDeque(); - results.offer(randomChatCompletionChunk()); - results.offer(randomChatCompletionChunk()); - var completed = new AtomicBoolean(); - var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> { - downstream.onSubscribe(new Flow.Subscription() { - @Override - public void request(long n) { - if (completed.compareAndSet(false, true)) { - downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results)); - } else { - downstream.onComplete(); - } - } - - @Override - public void cancel() { - fail("Cancel should never be called."); - } - }); - }); - - AtomicInteger counter = new AtomicInteger(0); - AtomicReference upstream = new AtomicReference<>(null); - Flow.Subscriber subscriber = spy( - new Flow.Subscriber() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - if (upstream.compareAndSet(null, subscription) == false) { - fail("Upstream already set?!"); - } - subscription.request(1); - } - - @Override - public void onNext(StreamingUnifiedChatCompletionResults.Results item) { - assertNotNull(item); - counter.incrementAndGet(); - var sub = upstream.get(); - if (sub != null) { - sub.request(1); - } else { - fail("Upstream not yet set?!"); - } - } - - @Override - public void onError(Throwable throwable) { - fail(throwable); - } - - @Override - public void onComplete() {} - } - ); - streamingResults.publisher().subscribe(subscriber); - verify(subscriber, times(2)).onNext(any()); - } - @Override protected Writeable.Reader instanceReader() { return StreamingUnifiedChatCompletionResults.Results::new; diff --git a/x-pack/plugin/deprecation/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/deprecation/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..16701ab74d8c9 --- /dev/null +++ b/x-pack/plugin/deprecation/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git a/x-pack/plugin/ent-search/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/ent-search/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..4de6d5924521d --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // needed for Jackson ObjectMapper to parse floats + permission java.lang.RuntimePermission "accessDeclaredMembers"; +}; diff --git a/x-pack/plugin/eql/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/eql/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java index 613f5b0ae76c2..82c31c0dbdd7e 100644 --- 
a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java @@ -184,19 +184,16 @@ protected void doCollectFirst(Predicate predicate, List matches) { public T transformDown(Function rule) { T root = rule.apply((T) this); Node node = this.equals(root) ? this : root; + return node.transformChildren(child -> child.transformDown(rule)); } @SuppressWarnings("unchecked") public T transformDown(Class typeToken, Function rule) { + // type filtering function return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } - @SuppressWarnings("unchecked") - public T transformDown(Predicate> nodePredicate, Function rule) { - return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t)); - } - @SuppressWarnings("unchecked") public T transformUp(Function rule) { T transformed = transformChildren(child -> child.transformUp(rule)); @@ -206,14 +203,10 @@ public T transformUp(Function rule) { @SuppressWarnings("unchecked") public T transformUp(Class typeToken, Function rule) { + // type filtering function return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } - @SuppressWarnings("unchecked") - public T transformUp(Predicate> nodePredicate, Function rule) { - return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t)); - } - @SuppressWarnings("unchecked") protected > T transformChildren(Function traversalOperation) { boolean childrenChanged = false; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index a330ac076d2c4..3de94bb90e5e3 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -560,10 +560,6 @@ public static boolean isSpatialPoint(DataType t) { return t == GEO_POINT || t == CARTESIAN_POINT; } - public static boolean isSpatialShape(DataType t) { - return t == GEO_SHAPE || t == CARTESIAN_SHAPE; - } - public static boolean isSpatialGeo(DataType t) { return t == GEO_POINT || t == GEO_SHAPE; } diff --git a/x-pack/plugin/esql/arrow/src/test/resources/plugin-security.policy b/x-pack/plugin/esql/arrow/src/test/resources/plugin-security.policy new file mode 100644 index 0000000000000..c5da65410d3da --- /dev/null +++ b/x-pack/plugin/esql/arrow/src/test/resources/plugin-security.policy @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +// Needed by the Arrow memory manager +grant { + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + permission java.lang.RuntimePermission "accessClassInPackage.sun.misc"; +}; diff --git a/x-pack/plugin/esql/build.gradle b/x-pack/plugin/esql/build.gradle index 28c3e4d2b20cb..539da924e6504 100644 --- a/x-pack/plugin/esql/build.gradle +++ b/x-pack/plugin/esql/build.gradle @@ -53,7 +53,6 @@ dependencies { testImplementation project(path: xpackModule('enrich')) testImplementation project(path: xpackModule('spatial')) testImplementation project(path: xpackModule('kql')) - testImplementation project(path: xpackModule('mapper-unsigned-long')) testImplementation project(path: ':modules:reindex') testImplementation project(path: ':modules:parent-join') @@ -251,7 +250,6 @@ pluginManager.withPlugin('com.diffplug.spotless') { "src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser*.java", "src/main/generated/**/*.java", "src/main/generated-src/generated/**/*.java" - toggleOffOn('begin generated imports', 'end generated imports') } } } diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 66867ae668fcc..e0a3bc34acc55 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -49,7 +49,6 @@ spotless { /* * Generated files go here. */ - toggleOffOn('begin generated imports', 'end generated imports') targetExclude "src/main/generated/**/*.java" } } @@ -101,7 +100,7 @@ tasks.named('stringTemplates').configure { "", "BYTES_REF", "org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF", - "BytesRefArray", + "", "BytesRefHash" ) var ipProperties = prop("Ip", "BytesRef", "BytesRef", "", "BYTES_REF", "16", "", "") diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 1704f4cbeb1fe..5f9c3ffed7064 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -55,8 +55,7 @@ import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION; import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT; import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; -import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK; -import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK; +import static org.elasticsearch.compute.gen.Types.INT_BLOCK; import static org.elasticsearch.compute.gen.Types.INT_VECTOR; import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; @@ -78,8 +77,6 @@ * and break-point-able as possible. 
*/ public class GroupingAggregatorImplementer { - private static final List GROUP_IDS_CLASSES = List.of(INT_ARRAY_BLOCK, INT_BIG_ARRAY_BLOCK, INT_VECTOR); - private final TypeElement declarationType; private final List warnExceptions; private final ExecutableElement init; @@ -200,10 +197,10 @@ private TypeSpec type() { builder.addMethod(intermediateStateDesc()); builder.addMethod(intermediateBlockCount()); builder.addMethod(prepareProcessPage()); - for (ClassName groupIdClass : GROUP_IDS_CLASSES) { - builder.addMethod(addRawInputLoop(groupIdClass, blockType(aggParam.type()))); - builder.addMethod(addRawInputLoop(groupIdClass, vectorType(aggParam.type()))); - } + builder.addMethod(addRawInputLoop(INT_VECTOR, blockType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_VECTOR, vectorType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_BLOCK, blockType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_BLOCK, vectorType(aggParam.type()))); builder.addMethod(selectedMayContainUnseenGroups()); builder.addMethod(addIntermediateInput()); builder.addMethod(addIntermediateRowInput()); @@ -373,12 +370,15 @@ private TypeSpec addInput(Consumer addBlock) { TypeSpec.Builder builder = TypeSpec.anonymousClassBuilder(""); builder.addSuperinterface(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT); - for (ClassName groupIdsType : GROUP_IDS_CLASSES) { - MethodSpec.Builder vector = MethodSpec.methodBuilder("add").addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); - vector.addParameter(TypeName.INT, "positionOffset").addParameter(groupIdsType, "groupIds"); - addBlock.accept(vector); - builder.addMethod(vector.build()); - } + MethodSpec.Builder block = MethodSpec.methodBuilder("add").addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); + block.addParameter(TypeName.INT, "positionOffset").addParameter(INT_BLOCK, "groupIds"); + addBlock.accept(block); + builder.addMethod(block.build()); + + MethodSpec.Builder vector = MethodSpec.methodBuilder("add").addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); + vector.addParameter(TypeName.INT, "positionOffset").addParameter(INT_VECTOR, "groupIds"); + addBlock.accept(vector); + builder.addMethod(vector.build()); MethodSpec.Builder close = MethodSpec.methodBuilder("close").addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); builder.addMethod(close.build()); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java index a9f4eef521716..62ecee6b5c6e9 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java @@ -46,8 +46,6 @@ public class Types { static final ClassName BOOLEAN_BLOCK = ClassName.get(DATA_PACKAGE, "BooleanBlock"); static final ClassName BYTES_REF_BLOCK = ClassName.get(DATA_PACKAGE, "BytesRefBlock"); static final ClassName INT_BLOCK = ClassName.get(DATA_PACKAGE, "IntBlock"); - static final ClassName INT_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "IntArrayBlock"); - static final ClassName INT_BIG_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "IntBigArrayBlock"); static final ClassName LONG_BLOCK = ClassName.get(DATA_PACKAGE, "LongBlock"); static final ClassName DOUBLE_BLOCK = ClassName.get(DATA_PACKAGE, "DoubleBlock"); static final ClassName FLOAT_BLOCK = ClassName.get(DATA_PACKAGE, "FloatBlock"); diff --git 
a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java index 72cb193fcc46f..47d386d0bd690 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java @@ -7,23 +7,20 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link BooleanArrayVector}. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class BooleanArrayBlock extends AbstractArrayBlock implements BooleanBlock { +final class BooleanArrayBlock extends AbstractArrayBlock implements BooleanBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(BooleanArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayVector.java index 18a89f29655b6..cde163a2d3bc5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayVector.java @@ -7,20 +7,15 @@ package org.elasticsearch.compute.data; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of boolean values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlock.java index f5b80450a84fb..140e1991bb199 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -15,7 +14,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores boolean values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlockBuilder.java index 1f7c59bbea153..1fe75bff6e1a5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanBlockBuilder.java @@ -7,16 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.core.Releasables; import java.util.Arrays; -// end generated imports /** * Block build of BooleanBlocks. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanLookup.java index c53b585a44336..10de16af922f6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanLookup.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link BooleanBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVector.java index 39f0e98d76fe6..813f7cd757207 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +14,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores boolean values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVectorBlock.java index ead14568ed7a8..56cfc725801ab 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanVectorBlock.java @@ -7,11 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link BooleanVector}. Cannot represent multi-values or nulls. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java index 0f4b24ee91f2c..a85b75d8fdc2a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; @@ -18,14 +17,13 @@ import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link BytesRefArrayVector}. * Does not take ownership of the given {@link BytesRefArray} and does not adjust circuit breakers to account for it. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class BytesRefArrayBlock extends AbstractArrayBlock implements BytesRefBlock { +final class BytesRefArrayBlock extends AbstractArrayBlock implements BytesRefBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(BytesRefArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayVector.java index 679019755840c..509ee7e583e4c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; @@ -18,9 +17,6 @@ import org.elasticsearch.core.Releasables; import java.io.IOException; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of BytesRef values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlock.java index 93dbff0b6997f..918ff1a1b1ca1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,7 +15,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores BytesRef values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlockBuilder.java index 50e2fa0f10c4c..2d724df2d3275 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefBlockBuilder.java @@ -7,17 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.Releasables; -import java.util.Arrays; -// end generated imports - /** * Block build of BytesRefBlocks. * This class is generated. Edit {@code X-BlockBuilder.java.st} instead. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefLookup.java index e9ccdce2bfa13..98967fdac3fbe 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefLookup.java @@ -7,13 +7,11 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link BytesRefBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVector.java index f5f9acb503196..1bca89f531c14 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; @@ -16,7 +15,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores BytesRef values. 
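The builder hunks trim imports per element type: BooleanBlockBuilder keeps BitArray, BytesRefBlockBuilder keeps BigArrays and BytesRefArray (BytesRef values live out-of-line in a BytesRefArray), and each numeric builder later in the patch keeps only its own DoubleArray / FloatArray / IntArray / LongArray. All of them retain CircuitBreakingException, since appends are charged against a memory budget. As a toy illustration of that budgeted-append idea only (invented names; the real breaker wiring lives behind BigArrays and the block factory and is not shown here):

/** Toy builder that grows its backing array under a fixed byte budget. */
public final class BudgetedBuilderSketch {
    private final long budgetBytes;
    private int[] values = new int[8];
    private int count;

    public BudgetedBuilderSketch(long budgetBytes) {
        this.budgetBytes = budgetBytes;
    }

    public BudgetedBuilderSketch append(int value) {
        if (count == values.length) {
            int newLen = values.length * 2;
            if ((long) newLen * Integer.BYTES > budgetBytes) {
                // Stand-in for throwing CircuitBreakingException when the breaker trips.
                throw new IllegalStateException("over memory budget");
            }
            values = java.util.Arrays.copyOf(values, newLen);
        }
        values[count++] = value;
        return this;
    }
}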
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVectorBlock.java index 8167bfef8dc11..10cc1b5503a64 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefVectorBlock.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link BytesRefVector}. Cannot represent multi-values or nulls. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBooleanVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBooleanVector.java index 89bbb0dd871f3..6ef344b8cc40d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBooleanVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBooleanVector.java @@ -7,13 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant boolean value. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBytesRefVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBytesRefVector.java index 5b4ccfdb12bf4..4bb8ee4a5a392 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBytesRefVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantBytesRefVector.java @@ -7,14 +7,11 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant BytesRef value. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantDoubleVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantDoubleVector.java index a73dc895ab2cf..b2f145e6918e1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantDoubleVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantDoubleVector.java @@ -7,13 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant double value. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantFloatVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantFloatVector.java index b4d978b464e3a..09b34f0b57494 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantFloatVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantFloatVector.java @@ -7,13 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant float value. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantIntVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantIntVector.java index 2938f98b97bdc..1131096edf036 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantIntVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantIntVector.java @@ -7,13 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant int value. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantLongVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantLongVector.java index 953761e29cd9e..a7e22ee58526b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantLongVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/ConstantLongVector.java @@ -7,13 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.ReleasableIterator; -// end generated imports /** * Vector implementation that stores a constant long value. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java index 68ab08ad59c3e..83c7b85a7ff5a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java @@ -7,23 +7,20 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link DoubleArrayVector}. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class DoubleArrayBlock extends AbstractArrayBlock implements DoubleBlock { +final class DoubleArrayBlock extends AbstractArrayBlock implements DoubleBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DoubleArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayVector.java index 14525399ff75b..5c375634011c6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayVector.java @@ -7,20 +7,15 @@ package org.elasticsearch.compute.data; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of double values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlock.java index d3700d18f172d..b27bc64146760 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -15,7 +14,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores double values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlockBuilder.java index 7a69367d85db6..5896bbd2c51e5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleBlockBuilder.java @@ -7,16 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.DoubleArray; -import org.elasticsearch.core.Releasables; import java.util.Arrays; -// end generated imports /** * Block build of DoubleBlocks. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleLookup.java index 8a74d41ce62be..e8d69edb92c20 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleLookup.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link DoubleBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVector.java index edfdb9aae7f47..b478c5ffbe043 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +14,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores double values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVectorBlock.java index 0486ac56513d3..f6350bd4586ca 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleVectorBlock.java @@ -7,11 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link DoubleVector}. Cannot represent multi-values or nulls. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java index 5c794ef0e74e9..749041d80d668 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java @@ -7,23 +7,20 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link FloatArrayVector}. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class FloatArrayBlock extends AbstractArrayBlock implements FloatBlock { +final class FloatArrayBlock extends AbstractArrayBlock implements FloatBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FloatArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayVector.java index ec36c23d7cd3c..f10e9dc39bbfd 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayVector.java @@ -7,20 +7,15 @@ package org.elasticsearch.compute.data; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of float values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlock.java index ef9231b45848c..c2ba0260ff4d9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -15,7 +14,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores float values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlockBuilder.java index e0101df2b42de..809f74899c9c2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatBlockBuilder.java @@ -7,16 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.FloatArray; -import org.elasticsearch.core.Releasables; import java.util.Arrays; -// end generated imports /** * Block build of FloatBlocks. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatLookup.java index 26dd18406b942..25e39a649e948 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatLookup.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link FloatBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVector.java index 652845a03d2a9..30fd4d69f221f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +14,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores float values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVectorBlock.java index 28cfbe83ce4b1..6c2846183cd2d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatVectorBlock.java @@ -7,11 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link FloatVector}. Cannot represent multi-values or nulls. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java index 3296ef9dbe3c1..0be8b6db78343 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java @@ -7,23 +7,20 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link IntArrayVector}. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class IntArrayBlock extends AbstractArrayBlock implements IntBlock { +final class IntArrayBlock extends AbstractArrayBlock implements IntBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IntArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayVector.java index ac72a7b0e3e0a..9db51c61bbf1d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayVector.java @@ -7,20 +7,15 @@ package org.elasticsearch.compute.data; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of int values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java index 949a34ab02379..7e4d8f2801d22 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -15,7 +14,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores int values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java index d957993ce10a2..cf8f84d7449ee 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntBlockBuilder.java @@ -7,16 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.IntArray; -import org.elasticsearch.core.Releasables; import java.util.Arrays; -// end generated imports /** * Block build of IntBlocks. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntLookup.java index ac7531c8893ef..83a6d92f43586 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntLookup.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link IntBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVector.java index 4a00399306ef2..afd7aea269772 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +14,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores int values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java index 911cd8a1744e3..a18b2e8ab2384 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntVectorBlock.java @@ -7,11 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link IntVector}. Cannot represent multi-values or nulls. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java index 62f7914441ef9..9b9b7a694ebb2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java @@ -7,23 +7,20 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link LongArrayVector}. * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class LongArrayBlock extends AbstractArrayBlock implements LongBlock { +final class LongArrayBlock extends AbstractArrayBlock implements LongBlock { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(LongArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayVector.java index cbd81f8bc848c..ff9179343536e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayVector.java @@ -7,20 +7,15 @@ package org.elasticsearch.compute.data; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports /** * Vector implementation that stores an array of long values. 
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlock.java index b61e092e73b1b..05d2b3e24c214 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlock.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -15,7 +14,6 @@ import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores long values. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlockBuilder.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlockBuilder.java index cab565d47f58a..58d3dbfe0cb38 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlockBuilder.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongBlockBuilder.java @@ -7,16 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.LongArray; -import org.elasticsearch.core.Releasables; import java.util.Arrays; -// end generated imports /** * Block build of LongBlocks. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongLookup.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongLookup.java index 9de5db0d4293e..3422784c4df60 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongLookup.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongLookup.java @@ -7,12 +7,10 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link LongBlock}s. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVector.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVector.java index 622ee98eaa5a6..3b3badab91a40 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVector.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVector.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +14,6 @@ import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores long values. 
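The grouping-aggregator hunks that follow (CountDistinctBoolean / BytesRef / Double / Float / Int, and so on) all apply one refactor: the anonymous AddInput classes replace their separate add(int, IntArrayBlock) and add(int, IntBigArrayBlock) overrides with a single add(int, IntBlock) override, and the matching pairs of addRawInput methods collapse into one IntBlock variant, alongside the IntVector fast path for null-free, single-valued group ids. Below is a schematic reconstruction of the two surviving loop shapes, using stub types so it compiles stand-alone; the real generated methods also take the aggregator state and the value block or vector, which are elided here.

// Stub of the few IntBlock/IntVector accessors the loops below need.
interface GroupIdsStub {
    int getPositionCount();
    boolean isNull(int position);        // block path only
    int getFirstValueIndex(int position);
    int getValueCount(int position);
    int getInt(int index);
}

final class ConsolidatedAddRawInputSketch {

    // Block path: a position may hold no group id (null) or several group ids.
    void addRawInputBlock(int positionOffset, GroupIdsStub groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            if (groups.isNull(groupPosition)) {
                continue; // no group id at this row
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; g++) {
                combine(groups.getInt(g), groupPosition + positionOffset);
            }
        }
    }

    // Vector path: exactly one non-null group id per position, so no null or
    // multi-value bookkeeping is needed.
    void addRawInputVector(int positionOffset, GroupIdsStub groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            combine(groups.getInt(groupPosition), groupPosition + positionOffset);
        }
    }

    private void combine(int groupId, int valuePosition) {
        // Stand-in for CountDistinct*Aggregator.combine(state, groupId, value).
    }
}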
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVectorBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVectorBlock.java index 07df5fb707671..26a2cab5704b5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVectorBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongVectorBlock.java @@ -7,11 +7,9 @@ package org.elasticsearch.compute.data; -// begin generated imports import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link LongVector}. Cannot represent multi-values or nulls. diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java index 4fce90e84add6..d031450a77f56 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -66,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -87,12 +81,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -107,42 +96,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - 
continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -163,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -177,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } - private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - CountDistinctBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java index 2d005a17dd182..fec083927d5d6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java @@ -14,8 +14,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +84,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -110,44 +99,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -184,29 +159,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - CountDistinctBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java index 0f0dfd4fa5b2c..756e922913841 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java @@ -16,8 +16,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import 
org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -92,12 +86,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -112,42 +101,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; 
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - CountDistinctDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java index 8e2fa1d71419a..1462deb1aab91 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java @@ -16,8 +16,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -92,12 +86,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { 
        addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -112,42 +101,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      CountDistinctFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java
index 08768acfa5261..2145489c67096 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java
@@ -14,8 +14,6 @@
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
@@ -70,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -91,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -111,42 +99,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          CountDistinctIntAggregator.combine(state, groupId, values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        CountDistinctIntAggregator.combine(state, groupId, values.getInt(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        CountDistinctIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      CountDistinctIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -167,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -181,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        CountDistinctIntAggregator.combine(state, groupId, values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      CountDistinctIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java
index 0b1caa1c3727c..20ae39cdbcd19 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java
@@ -14,8 +14,7 @@
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -112,42 +101,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          CountDistinctLongAggregator.combine(state, groupId, values.getLong(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        CountDistinctLongAggregator.combine(state, groupId, values.getLong(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        CountDistinctLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      CountDistinctLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        CountDistinctLongAggregator.combine(state, groupId, values.getLong(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      CountDistinctLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java
index c0e299d57f6bb..b9ee302f45b24 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -94,12 +88,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
      @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -114,45 +103,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -174,7 +149,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -190,30 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java
index df4b6c843ff75..ad3f37cd22a00 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -94,12 +88,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -114,45 +103,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
      LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -174,7 +149,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -190,30 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java
index d0252f8b420d0..9253aa51831b2 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java
@@ -11,8 +11,6 @@
 import java.util.List;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -72,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -93,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -113,45 +101,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -173,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -189,30 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java
index 8506d1e8d527b..e5a372c767b73 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java
@@ -11,8 +11,7 @@
 import java.util.List;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -112,45 +101,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -172,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -188,30 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java
index 8a32e5552dd1c..2ca6ab02875a2 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -94,12 +88,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -114,45 +103,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -174,7 +149,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -190,30 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      LastOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java
index 250c5cd755a12..38a3b23ee8cc5 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -94,12 +88,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -114,45 +103,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -174,7 +149,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -190,30 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      LastOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java
index 9b118c7dea9be..f03728a905ac3 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java
@@ -11,8 +11,6 @@
 import java.util.List;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -72,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -93,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -113,45 +101,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -173,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
@@ -189,30 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      LastOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java
index 82bfc732969e5..c9ee5fbad3707 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java
@@ -11,8 +11,7 @@
 import java.util.List;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
         }
 
@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }
 
@@ -112,45 +101,31 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -172,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
@@ -188,30 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = 
groupPosition + positionOffset; - LastOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java index f7390f55bc52b..5e2684b85c8db 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -66,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -87,12 +81,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -107,42 +96,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, 
MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -163,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -177,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } - private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java index 41f98d962bd2f..52bc763449f59 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java @@ -16,8 +16,7 @@ import 
org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +84,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -110,44 +99,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MaxBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + MaxBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - MaxBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + MaxBytesRefAggregator.combine(state, groupId, 
values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -184,29 +159,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MaxBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - MaxBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java index 53273dad7c0f0..0b2e5cca5d244 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java @@ -15,8 +15,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void 
add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int 
positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java index 49afaf3c7265d..4ec8212a2da62 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java @@ -15,8 +15,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + 
positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, 
MaxFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java index 3d97bf9df5dd9..024d0db097b29 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java @@ -13,8 +13,6 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -67,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -88,12 +81,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -108,42 +96,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); } } } - private void addRawInput(int positionOffset, 
IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -164,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -178,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java index fd38873655edd..805fc77aa9306 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java @@ -16,8 +16,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; 
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -90,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -110,44 +99,30 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          MaxIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        MaxIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        MaxIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
-      }
+      int groupId = groups.getInt(groupPosition);
+      MaxIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -184,29 +159,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        MaxIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      MaxIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java
index fcaea869f84d4..5d6fa43723e7b 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -89,12 +83,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -109,42 +98,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
-      }
+      int groupId = groups.getInt(groupPosition);
+      state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java
index c380146094f44..9091515805dff 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java
@@ -16,8 +16,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -89,12 +83,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -109,42 +98,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
      if (groups.isNull(groupPosition)) {
        continue;
@@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      MedianAbsoluteDeviationDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java
index a895ebc9eda6b..1649e40d9045d 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java
@@ -16,8 +16,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -89,12 +83,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -109,42 +98,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }
 
-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      MedianAbsoluteDeviationFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java
index f9b9934520f06..5904bef3956d3 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java
@@ -14,8 +14,6 @@
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
@@ -67,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
 
@@ -88,12 +81,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }
 
@@ -108,42 +96,28 @@ public void close() {
     };
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(v));
      }
     }
   }
 
-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -164,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }
 
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntBlock
groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -178,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - MedianAbsoluteDeviationIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java index e1693d7475c6f..bb50db9998a59 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java @@ -14,8 +14,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, 
IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - MedianAbsoluteDeviationLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java index 4ca346913a25b..10bb3ca5c60bf 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -66,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -87,12 +81,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -107,42 +96,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = 
valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -163,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -177,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } - private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java index dc721573876ab..29d96b63a8e59 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java @@ 
-16,8 +16,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +84,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -110,44 +99,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MinBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + MinBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - MinBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + MinBytesRefAggregator.combine(state, groupId, 
values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -184,29 +159,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - MinBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - MinBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java index 3212ca644aee7..c1396235fef0c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java @@ -15,8 +15,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void 
add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int 
positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java index 2e7b089e7592a..daadf3d7dbb53 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java @@ -15,8 +15,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + 
positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, 
MinFloatAggregator.combine(state.getOrDefault(groupId), values.getFloat(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java index 50c5e80a55b0c..8f92c63096766 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java @@ -13,8 +13,6 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -67,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -88,12 +81,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -108,42 +96,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); } } } - private void addRawInput(int positionOffset, 
IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -164,7 +138,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -178,27 +152,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java index c89c1feb6790f..05a5c3b57e2a6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java @@ -16,8 +16,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; 
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
@@ -90,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
        addRawInput(positionOffset, groupIds, valuesVector);
       }
@@ -110,44 +99,30 @@ public void close() {
     };
   }
-  private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          MinIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        MinIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
       }
     }
   }
-  private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        MinIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
-      }
+      int groupId = groups.getInt(groupPosition);
+      MinIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
     }
   }
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl
     }
   }
-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -184,29 +159,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe
     }
   }
-  private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        MinIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      MinIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
index dc92d712ddb6a..c6421afa46211 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
       }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
-        public void add(int positionOffset, IntArrayBlock groupIds) {
-          addRawInput(positionOffset, groupIds, valuesBlock);
-        }
-
-        @Override
-        public void add(int positionOffset, IntBigArrayBlock groupIds) {
+        public void add(int positionOffset, IntBlock groupIds) {
           addRawInput(positionOffset, groupIds, valuesBlock);
         }
@@ -89,12 +83,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -109,42 +98,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
-      }
+      int groupId = groups.getInt(groupPosition);
+      state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java
index 1264bff20abf6..a2ba67333a05d 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java
@@ -16,8 +16,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -112,42 +101,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          PercentileDoubleAggregator.combine(state, groupId, values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        PercentileDoubleAggregator.combine(state, groupId, values.getDouble(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        PercentileDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      PercentileDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        PercentileDoubleAggregator.combine(state, groupId, values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      PercentileDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
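// ---- Editor's sketch (not part of the generated sources or of this patch) ----
// The hunks above and below all make the same change: the two block-shaped
// overloads add(int, IntArrayBlock) / add(int, IntBigArrayBlock) collapse back
// into a single add(int, IntBlock), while the IntVector overload remains the
// no-null, single-valued fast path. A minimal stand-alone illustration of why
// the two paths differ, using plain arrays instead of Block/Vector; every name
// here (MinSketch, groupIdsVector, groupIdsBlock) is hypothetical.
class MinSketch {
  // Vector path: exactly one group id per row, never null, so no per-row checks.
  static void addVector(long[] state, int[] groupIdsVector, long[] values) {
    for (int p = 0; p < groupIdsVector.length; p++) {
      int groupId = groupIdsVector[p];
      state[groupId] = Math.min(state[groupId], values[p]);
    }
  }

  // Block path: a row may have no group at all, so test for null first.
  static void addBlock(long[] state, Integer[] groupIdsBlock, long[] values) {
    for (int p = 0; p < groupIdsBlock.length; p++) {
      Integer groupId = groupIdsBlock[p];
      if (groupId == null) {
        continue; // null position: this row belongs to no group
      }
      state[groupId] = Math.min(state[groupId], values[p]);
    }
  }
}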
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java
index f844efae8d218..4c24b1b4221c6 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java
@@ -16,8 +16,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -112,42 +101,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          PercentileFloatAggregator.combine(state, groupId, values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        PercentileFloatAggregator.combine(state, groupId, values.getFloat(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        PercentileFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      PercentileFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        PercentileFloatAggregator.combine(state, groupId, values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      PercentileFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java
index e0dd21ecc80d1..97ce87021f4f8 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java
@@ -14,8 +14,6 @@
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
@@ -70,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -91,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -111,42 +99,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          PercentileIntAggregator.combine(state, groupId, values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        PercentileIntAggregator.combine(state, groupId, values.getInt(v));
      }
    }
  }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        PercentileIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      PercentileIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -167,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -181,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        PercentileIntAggregator.combine(state, groupId, values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      PercentileIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java
index 1baa4a662175c..f2680fd7b5bef 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java
@@ -14,8 +14,7 @@
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -112,42 +101,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          PercentileLongAggregator.combine(state, groupId, values.getLong(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        PercentileLongAggregator.combine(state, groupId, values.getLong(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        PercentileLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      PercentileLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        PercentileLongAggregator.combine(state, groupId, values.getLong(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      PercentileLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
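// ---- Editor's sketch (not part of the generated sources or of this patch) ----
// Every Block-valued addRawInput above walks a possibly multi-valued input the
// same way: getFirstValueIndex(p) and getValueCount(p) delimit the run of
// values packed for position p, and each value in that run is combined into the
// row's group. Plain-array stand-ins below; packedValues, firstValueIndex and
// valueCount are hypothetical names for the Block accessors:
class MultiValueWalkSketch {
  static void combineRun(double[] state, int groupId, double[] packedValues,
      int[] firstValueIndex, int[] valueCount, int position) {
    int valuesStart = firstValueIndex[position];        // first value of this row
    int valuesEnd = valuesStart + valueCount[position]; // one past its last value
    for (int v = valuesStart; v < valuesEnd; v++) {
      state[groupId] += packedValues[v]; // stand-in for XxxAggregator.combine(...)
    }
  }
}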
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
index 25923bf02a761..e12686f2a66fa 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java
@@ -13,8 +13,6 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -76,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
       }

@@ -97,12 +90,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }

@@ -117,45 +105,31 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          RateDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        RateDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        RateDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      RateDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -177,7 +151,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -193,30 +167,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        RateDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      RateDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java
index 7dbe1a2de02bd..5e2aced928554 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java
@@ -15,8 +15,6 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -78,12 +76,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
       }

@@ -99,12 +92,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }

@@ -119,45 +107,31 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          RateFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        RateFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        RateFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      RateFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -179,7 +153,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -195,30 +169,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        RateFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      RateFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
index 4650ebf0c5bb2..c85a0ac5b9fe0 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java
@@ -13,8 +13,6 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -76,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
       }

@@ -97,12 +90,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }

@@ -117,45 +105,31 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          RateIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        RateIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        RateIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      RateIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -177,7 +151,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values,
      LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -193,30 +167,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        RateIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      RateIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
index a219a58068ea0..98996069fe554 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java
@@ -13,8 +13,6 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -76,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector);
       }

@@ -97,12 +90,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector, timestampsVector);
       }

@@ -117,45 +105,31 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          RateLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        RateLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        var valuePosition = groupPosition + positionOffset;
-        RateLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
-      }
+      int groupId = groups.getInt(groupPosition);
+      var valuePosition = groupPosition + positionOffset;
+      RateLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongBlock values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -177,7 +151,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values,
+  private void addRawInput(int positionOffset, IntBlock groups, LongVector values,
       LongVector timestamps) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -193,30 +167,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, LongBlock values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        RateLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, LongVector values,
-      LongVector timestamps) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      var valuePosition = groupPosition + positionOffset;
-      RateLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
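// ---- Editor's sketch (not part of the generated sources or of this patch) ----
// The four Rate* files above differ from the other aggregators only in that each
// value is combined together with the timestamp stored at the same value index,
// so the per-group state can later derive a rate from the collected
// (timestamp, value) samples. Plain-array illustration; combine() here is a
// hypothetical stand-in for RateLongAggregator.combine(state, ...):
class RatePairingSketch {
  static void addVector(int[] groupIds, long[] timestamps, long[] values, int positionOffset) {
    for (int groupPosition = 0; groupPosition < groupIds.length; groupPosition++) {
      int valuePosition = groupPosition + positionOffset; // values and timestamps share this index
      combine(groupIds[groupPosition], timestamps[valuePosition], values[valuePosition]);
    }
  }

  static void combine(int groupId, long timestamp, long value) {
    // a real implementation would buffer (timestamp, value) per group and derive
    // the rate from the first and last samples when the bucket is finished
  }
}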
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java
index 7cf0ab3e7b148..bbf1930cf0524 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -90,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -110,42 +99,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v));
      }
    }
  }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -166,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -180,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        StdDevDoubleAggregator.combine(state, groupId, values.getDouble(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      StdDevDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java
index e3bbbb5d4d624..818dd28386b9f 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java
@@ -15,8 +15,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.FloatBlock;
 import org.elasticsearch.compute.data.FloatVector;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -92,12 +86,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -112,42 +101,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          StdDevFloatAggregator.combine(state, groupId, values.getFloat(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        StdDevFloatAggregator.combine(state, groupId, values.getFloat(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        StdDevFloatAggregator.combine(state, groupId, values.getFloat(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java
index b0c780b232fe7..d8417d71dc784 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java
@@ -13,8 +13,6 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -70,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -91,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -111,42 +99,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          StdDevIntAggregator.combine(state, groupId, values.getInt(v));
-        }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        StdDevIntAggregator.combine(state, groupId, values.getInt(v));
       }
     }
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-      }
+      int groupId = groups.getInt(groupPosition);
+      StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -167,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v
     }
   }

-  private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) {
+  private void addRawInput(int positionOffset, IntBlock groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
         continue;
@@ -181,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }

-  private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      if (values.isNull(groupPosition + positionOffset)) {
-        continue;
-      }
-      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-      for (int v = valuesStart; v < valuesEnd; v++) {
-        StdDevIntAggregator.combine(state, groupId, values.getInt(v));
-      }
-    }
-  }
-
-  private void addRawInput(int positionOffset, IntVector groups, IntVector values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      StdDevIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset));
-    }
-  }
-
   @Override
   public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
     state.enableGroupIdTracking(seenGroupIds);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java
index 7e33a0c70c145..fc5f061a04620 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java
@@ -13,8 +13,7 @@
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntArrayBlock;
-import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
@@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesBlock);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesBlock);
       }

@@ -90,12 +84,7 @@ public void close() {
     }
     return new GroupingAggregatorFunction.AddInput() {
       @Override
-      public void add(int positionOffset, IntArrayBlock groupIds) {
-        addRawInput(positionOffset, groupIds, valuesVector);
-      }
-
-      @Override
-      public void add(int positionOffset, IntBigArrayBlock groupIds) {
+      public void add(int positionOffset, IntBlock groupIds) {
         addRawInput(positionOffset, groupIds, valuesVector);
       }

@@ -110,42 +99,28 @@ public void close() {
     };
   }

-  private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) {
+  private void addRawInput(int positionOffset, IntVector groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
         continue;
       }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        if (values.isNull(groupPosition + positionOffset)) {
-          continue;
-        }
-        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
-        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
-        for (int v = valuesStart; v < valuesEnd; v++) {
-          StdDevLongAggregator.combine(state, groupId, values.getLong(v));
-        }
+      int valuesStart =
values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + StdDevLongAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -166,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -180,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - StdDevLongAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - StdDevLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java index 303bb3d0ff5dc..d566756ca3282 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java @@ -15,8 +15,7 @@ import 
org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -69,12 +68,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +84,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -110,42 +99,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SumDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SumDoubleAggregator.combine(state, groupId, values.getDouble(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SumDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + SumDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock 
values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -166,7 +141,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -180,27 +155,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SumDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SumDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java index 154057db5f462..e12f91f4451a8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java @@ -17,8 +17,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -71,12 +70,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -92,12 +86,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, 
valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -112,42 +101,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SumFloatAggregator.combine(state, groupId, values.getFloat(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SumFloatAggregator.combine(state, groupId, values.getFloat(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SumFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + SumFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -182,27 +157,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart 
+ values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SumFloatAggregator.combine(state, groupId, values.getFloat(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SumFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java index 9b5cba8cd5a89..91c04d9370060 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java @@ -13,8 +13,6 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -69,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -110,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, 
SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -166,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -180,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index a2969a4dddaa8..e53dfa3857753 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -68,12 +67,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -89,12 +83,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,42 +98,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - 
state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); - } + int groupId = groups.getInt(groupPosition); + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -165,7 +140,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -179,27 +154,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java index 1fa211364cfcc..dc0e2831ee3ea 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -72,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - 
addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -93,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -113,42 +102,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopBooleanAggregator.combine(state, groupId, values.getBoolean(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + TopBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -183,27 +158,6 @@ 
private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } - private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java index 4ab5bb9875107..aa779dfc4afad 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java @@ -14,8 +14,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -94,12 +88,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -114,44 +103,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if 
(values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -173,7 +148,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -188,29 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java index 8a2f4aef9cf35..68bdee76e48aa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -72,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -93,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -113,42 +102,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopDoubleAggregator.combine(state, 
groupId, values.getDouble(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + TopDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -183,27 +158,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java index d09bf60c82aca..cf1063a02fcbf 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import 
org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -72,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -93,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -113,42 +102,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopFloatAggregator.combine(state, groupId, values.getFloat(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopFloatAggregator.combine(state, groupId, values.getFloat(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + TopFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { 
continue; @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -183,27 +158,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopFloatAggregator.combine(state, groupId, values.getFloat(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java index 786f0660ea06f..67fa432b87995 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java @@ -11,8 +11,6 @@ import java.util.List; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -71,12 +69,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -92,12 +85,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -112,42 +100,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, 
IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopIntAggregator.combine(state, groupId, values.getInt(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopIntAggregator.combine(state, groupId, values.getInt(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + TopIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -168,7 +142,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -182,27 +156,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopIntAggregator.combine(state, groupId, values.getInt(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int 
groupId = groups.getInt(groupPosition); - TopIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java index 3d1137486fb75..eb5631c68d173 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java @@ -14,8 +14,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -73,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -94,12 +88,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -114,44 +103,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopIpAggregator.combine(state, groupId, 
values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + TopIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -173,7 +148,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -188,29 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopIpAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopIpAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java index 820aa3c6c63e1..58b62c211a0b9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java @@ -11,8 +11,7 @@ import java.util.List; import 
org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -72,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -93,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -113,42 +102,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopLongAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopLongAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - TopLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + TopLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) 
{ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -169,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -183,27 +158,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - TopLongAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - TopLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java index a928d0908eb8e..772de15468ef3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -65,12 +64,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -86,12 +80,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) 
{ + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -106,42 +95,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -162,7 +137,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlo } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -176,27 +151,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } - private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + 
positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 28843942b73cb..f9790deb190f3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -14,8 +14,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -66,12 +65,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } var addInput = new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -88,12 +82,7 @@ public void close() { } var addInput = new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -109,44 +98,30 @@ public void close() { return ValuesBytesRefAggregator.wrapAddInput(addInput, state, valuesVector); } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - 
int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -168,7 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -183,29 +158,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index 76c865b33fd09..04984d5c0640a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -65,12 +64,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -86,12 +80,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -106,42 +95,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesDoubleAggregator.combine(state, groupId, values.getDouble(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - 
if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + ValuesDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -162,7 +137,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBloc } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -176,27 +151,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesDoubleAggregator.combine(state, groupId, values.getDouble(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index bed9a884ccd10..1848b9e4f141b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -13,8 +13,7 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -65,12 +64,7 @@ public GroupingAggregatorFunction.AddInput 
prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -86,12 +80,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -106,42 +95,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesFloatAggregator.combine(state, groupId, values.getFloat(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesFloatAggregator.combine(state, groupId, values.getFloat(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + ValuesFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -162,7 +137,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values) { + private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) { for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -176,27 +151,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesFloatAggregator.combine(state, groupId, values.getFloat(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, FloatVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index fb801eadcf5cd..f0878fb085e6a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -11,8 +11,6 @@ import java.util.List; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -64,12 +62,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -85,12 +78,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -105,42 +93,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if 
(values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesIntAggregator.combine(state, groupId, values.getInt(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesIntAggregator.combine(state, groupId, values.getInt(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + ValuesIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -161,7 +135,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -175,27 +149,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesIntAggregator.combine(state, groupId, values.getInt(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); 
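The edit repeated across these generated classes is the same in every file: the two AddInput.add overloads that took the concrete group-id types IntArrayBlock and IntBigArrayBlock are collapsed into a single overload taking the common IntBlock interface (which both concrete blocks implement), and the IntVector-based addRawInput fast paths are restored for the dense, non-null case. Below is a minimal sketch of that dispatch shape, under stated assumptions: IntBlock, IntVector, and Combine here are reduced stand-ins, not the real org.elasticsearch.compute.data classes, and the values side is simplified to a plain int array (the real generated code also walks null and multi-valued value blocks).

// Reduced stand-in for the dense vector case: one non-null value per position.
interface IntVector {
    int getPositionCount();
    int getInt(int position);
}

// Reduced stand-in for the general block case: positions may be null or multi-valued.
interface IntBlock {
    int getPositionCount();
    boolean isNull(int position);
    int getFirstValueIndex(int position);
    int getValueCount(int position);
    int getInt(int valueIndex);
}

class GroupingDispatchSketch {
    // Stand-in for the per-type combine call, e.g. ValuesIntAggregator.combine(state, groupId, value).
    interface Combine {
        void apply(int groupId, int value);
    }

    // Fast path: a dense IntVector of group ids needs no null or multi-value handling.
    static void addRawInput(int positionOffset, IntVector groups, int[] values, Combine combine) {
        for (int p = 0; p < groups.getPositionCount(); p++) {
            combine.apply(groups.getInt(p), values[p + positionOffset]);
        }
    }

    // General path: an IntBlock of group ids must skip null positions and walk
    // every group id stored at a position before combining the matching value.
    static void addRawInput(int positionOffset, IntBlock groups, int[] values, Combine combine) {
        for (int p = 0; p < groups.getPositionCount(); p++) {
            if (groups.isNull(p)) {
                continue;
            }
            int start = groups.getFirstValueIndex(p);
            int end = start + groups.getValueCount(p);
            for (int g = start; g < end; g++) {
                combine.apply(groups.getInt(g), values[p + positionOffset]);
            }
        }
    }
}

Typing the overloads against the IntBlock interface rather than the two concrete blocks means array-backed and big-array-backed group ids share one code path, while the separate IntVector overload keeps a branch-free loop for the common dense case.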
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index 061af9fcc9213..00bd36cda2523 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -11,8 +11,7 @@ import java.util.List; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -65,12 +64,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -86,12 +80,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -106,42 +95,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesLongAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + ValuesLongAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - 
} - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - ValuesLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + ValuesLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -162,7 +137,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -176,27 +151,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - ValuesLongAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - ValuesLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java index a959f808e438b..9a6328f075bca 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java @@ -17,8 +17,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -75,12 +74,7 @@ 
public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -96,12 +90,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -116,42 +105,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -172,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, 
IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -186,27 +161,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialCentroidCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java index a3593b8152dd7..1681b1a210d3c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -20,8 +20,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -78,12 +77,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -99,12 +93,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { 
addRawInput(positionOffset, groupIds, valuesVector); } @@ -119,44 +108,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -178,7 +153,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -193,29 +168,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = 
groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialCentroidCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java index 77a959e654862..36413308e967f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java @@ -17,8 +17,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -75,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -96,12 +90,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -116,42 +105,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = 
groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -172,7 +147,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -186,27 +161,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialCentroidGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java index fc05c0932f50c..935dd8f56887a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java @@ -20,8 +20,7 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -78,12 +77,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -99,12 +93,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -119,44 +108,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + 
SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -178,7 +153,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -193,29 +168,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialCentroidGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java index 76f66cf41d569..8a9807be22ef5 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java @@ -15,8 +15,6 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -73,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -94,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -114,42 +102,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = 
groups.getInt(g); - SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -170,7 +144,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -184,27 +158,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentCartesianPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java index 3c1159eb0de11..49198bbd74c69 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -18,8 +18,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -74,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { 
@Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -95,12 +88,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -115,44 +103,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -174,7 +148,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - 
private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -189,29 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentCartesianPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java index 7057281c2ec6f..ca6b567810ae7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java @@ -15,8 +15,6 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -71,12 +69,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -92,12 +85,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, 
IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -112,34 +100,27 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - int[] valuesArray = new int[valuesEnd - valuesStart]; - for (int v = valuesStart; v < valuesEnd; v++) { - valuesArray[v-valuesStart] = values.getInt(v); - } - SpatialExtentCartesianShapeDocValuesAggregator.combine(state, groupId, valuesArray); + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + int[] valuesArray = new int[valuesEnd - valuesStart]; + for (int v = valuesStart; v < valuesEnd; v++) { + valuesArray[v-valuesStart] = values.getInt(v); } + SpatialExtentCartesianShapeDocValuesAggregator.combine(state, groupId, valuesArray); } } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { // This type does not support vectors because all values are multi-valued } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -162,27 +143,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { - // This type does not support vectors because all values are multi-valued - } - - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - int[] valuesArray = new int[valuesEnd - valuesStart]; - for (int v = valuesStart; v < valuesEnd; v++) { - valuesArray[v-valuesStart] = values.getInt(v); - } - SpatialExtentCartesianShapeDocValuesAggregator.combine(state, groupId, valuesArray); - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { // This type does not support vectors because all values are multi-valued } diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java index 21241efbf3198..bf7e95c4f040f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java @@ -18,8 +18,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -74,12 +72,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -95,12 +88,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -115,44 +103,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, 
BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -174,7 +148,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -189,29 +163,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentCartesianShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java index 387ed0abc34bb..71abb7296232a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java @@ -15,8 +15,6 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -75,12 +73,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -96,12 +89,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -116,42 +104,28 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } + int groupId = groups.getInt(groupPosition); + 
SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -172,7 +146,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -186,27 +160,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(v)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, LongVector values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentGeoPointDocValuesAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java index 9d9c10902ada6..437c1f017ebc9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java @@ -18,8 +18,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -76,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int 
positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -97,12 +90,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -117,44 +105,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -176,7 +150,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -191,29 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentGeoPointSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java index 82553910e1587..b0a01d268280c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java @@ -15,8 +15,6 @@ import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -73,12 +71,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -94,12 +87,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -114,34 +102,27 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + private void 
addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - int[] valuesArray = new int[valuesEnd - valuesStart]; - for (int v = valuesStart; v < valuesEnd; v++) { - valuesArray[v-valuesStart] = values.getInt(v); - } - SpatialExtentGeoShapeDocValuesAggregator.combine(state, groupId, valuesArray); + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + int[] valuesArray = new int[valuesEnd - valuesStart]; + for (int v = valuesStart; v < valuesEnd; v++) { + valuesArray[v-valuesStart] = values.getInt(v); } + SpatialExtentGeoShapeDocValuesAggregator.combine(state, groupId, valuesArray); } } - private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { // This type does not support vectors because all values are multi-valued } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -164,27 +145,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock v } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { - // This type does not support vectors because all values are multi-valued - } - - private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - int[] valuesArray = new int[valuesEnd - valuesStart]; - for (int v = valuesStart; v < valuesEnd; v++) { - valuesArray[v-valuesStart] = values.getInt(v); - } - SpatialExtentGeoShapeDocValuesAggregator.combine(state, groupId, valuesArray); - } - } - - private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + private void addRawInput(int positionOffset, IntBlock groups, IntVector values) { // This type does not support vectors because all values are multi-valued } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java index 
ccab0870e206d..029e935f4765e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java @@ -18,8 +18,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -76,12 +74,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -97,12 +90,7 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesVector); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); } @@ -117,44 +105,30 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { continue; } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); } } } - private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int 
groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = groups.getInt(g); - SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } + int groupId = groups.getInt(groupPosition); + SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -176,7 +150,7 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBl } } - private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -191,29 +165,6 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } - private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { - continue; - } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); - } - } - } - - private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { - BytesRef scratch = new BytesRef(); - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - SpatialExtentGeoShapeSourceValuesAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); - } - } - @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java index 611118d03872b..124fb5a1745bd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java @@ -11,8 +11,7 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import 
org.elasticsearch.compute.data.LongVector; @@ -69,12 +68,7 @@ public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { } return new AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(positionOffset, groupIds, valuesBlock); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); } @@ -90,12 +84,7 @@ public void close() {} } return new AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addRawInput(groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { addRawInput(groupIds); } @@ -120,25 +109,7 @@ private void addRawInput(int positionOffset, IntVector groups, Block values) { } } - private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) { - int position = positionOffset; - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); - if (values.isNull(position)) { - continue; - } - state.increment(groupId, values.getValueCount(position)); - } - } - } - - private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block values) { + private void addRawInput(int positionOffset, IntBlock groups, Block values) { int position = positionOffset; for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) { if (groups.isNull(groupPosition)) { @@ -169,24 +140,7 @@ private void addRawInput(IntVector groups) { /** * This method is called for count all. */ - private void addRawInput(IntArrayBlock groups) { - for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - if (groups.isNull(groupPosition)) { - continue; - } - int groupStart = groups.getFirstValueIndex(groupPosition); - int groupEnd = groupStart + groups.getValueCount(groupPosition); - for (int g = groupStart; g < groupEnd; g++) { - int groupId = Math.toIntExact(groups.getInt(g)); - state.increment(groupId, 1); - } - } - } - - /** - * This method is called for count all. 
- */ - private void addRawInput(IntBigArrayBlock groups) { + private void addRawInput(IntBlock groups) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java index 8b7734fe33ab7..f34129c1116e4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java @@ -10,8 +10,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BooleanVector; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -58,21 +56,7 @@ public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { private record FilteredAddInput(BooleanVector mask, AddInput nextAdd, int positionCount) implements AddInput { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); - } - - private void addBlock(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { if (positionOffset == 0) { try (IntBlock filtered = groupIds.keepMask(mask)) { nextAdd.add(positionOffset, filtered); @@ -89,6 +73,11 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } + @Override + public void add(int positionOffset, IntVector groupIds) { + add(positionOffset, groupIds.asBlock()); + } + @Override public void close() { Releasables.close(mask, nextAdd); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java index 19012cabce5a1..1706b8c023995 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java @@ -10,8 +10,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.CompositeBlock; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -48,18 +46,6 @@ public void add(int positionOffset, IntBlock groupIds) { throw new IllegalStateException("Intermediate group id must not have nulls"); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - assert false : "Intermediate group id must not 
have nulls"; - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - assert false : "Intermediate group id must not have nulls"; - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - @Override public void add(int positionOffset, IntVector groupIds) { addIntermediateInput(positionOffset, groupIds, page); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java index e0d82b1f145b8..268ac3ba32678 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java @@ -9,8 +9,6 @@ import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -47,16 +45,6 @@ public void add(int positionOffset, IntBlock groupIds) { throw new IllegalStateException("Intermediate group id must not have nulls"); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - @Override public void add(int positionOffset, IntVector groupIds) { aggregatorFunction.addIntermediateInput(positionOffset, groupIds, page); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index 556902174f213..d5d432655620f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -8,12 +8,8 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.ConstantNullBlock; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.IntVectorBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.data.Vector; import org.elasticsearch.core.Releasable; @@ -52,40 +48,12 @@ interface AddInput extends Releasable { * be skipped entirely or the groupIds block could contain a * {@code null} value at that position. *
<p> - *     This method delegates the processing to the other overloads for specific groupIds block types. - * </p>
* @param positionOffset offset into the {@link Page} used to build this * {@link AddInput} of these ids * @param groupIds {@link Block} of group id, some of which may be null * or multivalued */ - default void add(int positionOffset, IntBlock groupIds) { - switch (groupIds) { - case ConstantNullBlock ignored: - // No-op - break; - case IntVectorBlock b: - add(positionOffset, b.asVector()); - break; - case IntArrayBlock b: - add(positionOffset, b); - break; - case IntBigArrayBlock b: - add(positionOffset, b); - break; - } - } - - /** - * Implementation of {@link #add(int, IntBlock)} for a specific type of block. - */ - void add(int positionOffset, IntArrayBlock groupIds); - - /** - * Implementation of {@link #add(int, IntBlock)} for a specific type of block. - */ - void add(int positionOffset, IntBigArrayBlock groupIds); + void add(int positionOffset, IntBlock groupIds); /** * Send a batch of group ids to the aggregator. The {@code groupIds} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/table/BlockHashRowInTableLookup.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/table/BlockHashRowInTableLookup.java index f00606f67548c..c198853bb36ad 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/table/BlockHashRowInTableLookup.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/table/BlockHashRowInTableLookup.java @@ -12,8 +12,6 @@ import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; @@ -42,23 +40,7 @@ final class BlockHashRowInTableLookup extends RowInTableLookup { private int lastOrd = -1; @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - for (int p = 0; p < groupIds.getPositionCount(); p++) { - int first = groupIds.getFirstValueIndex(p); - int end = groupIds.getValueCount(p) + first; - for (int i = first; i < end; i++) { - int ord = groupIds.getInt(i); - if (ord != lastOrd + 1) { - // TODO double check these errors over REST once we have LOOKUP - throw new IllegalArgumentException("found a duplicate row"); - } - lastOrd = ord; - } - } - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { for (int p = 0; p < groupIds.getPositionCount(); p++) { int first = groupIds.getFirstValueIndex(p); int end = groupIds.getValueCount(p) + first; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/ConstantNullBlock.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/ConstantNullBlock.java index 2ed905f4299ca..94cd0bf9ddd22 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/ConstantNullBlock.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/ConstantNullBlock.java @@ -19,7 +19,7 @@ /** * Block implementation representing a constant null value. 
*/ -public final class ConstantNullBlock extends AbstractNonThreadSafeRefCounted +final class ConstantNullBlock extends AbstractNonThreadSafeRefCounted implements BooleanBlock, IntBlock, diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st index 5ebfc7bd4f9f6..707b19165bb3b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st @@ -7,20 +7,22 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; -$endif$ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BytesRefArray; +$else$ +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.ByteSizeValue; +$endif$ import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.BitSet; -// end generated imports /** * Block implementation that stores values in a {@link $Type$ArrayVector}. @@ -29,7 +31,7 @@ $if(BytesRef)$ $endif$ * This class is generated. Edit {@code X-ArrayBlock.java.st} instead. */ -public final class $Type$ArrayBlock extends AbstractArrayBlock implements $Type$Block { +final class $Type$ArrayBlock extends AbstractArrayBlock implements $Type$Block { static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance($Type$ArrayBlock.class); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayVector.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayVector.java.st index cb44f2df1732d..521e09d909a1c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayVector.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayVector.java.st @@ -7,7 +7,7 @@ package org.elasticsearch.compute.data; -// begin generated imports +$if(BytesRef)$ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.io.stream.StreamInput; @@ -17,10 +17,19 @@ import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; +import java.io.IOException; + +$else$ +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.ReleasableIterator; + import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -// end generated imports +$endif$ /** * Vector implementation that stores an array of $type$ values. 
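Reviewer note on the visibility changes above: with ConstantNullBlock and the generated $Type$ArrayBlock classes now package-private, code outside org.elasticsearch.compute.data can no longer special-case concrete block classes — which is why the IntArrayBlock/IntBigArrayBlock overloads disappear throughout this diff. A minimal sketch of the interface-only access pattern the remaining IntBlock overloads rely on; the wrapper class name is hypothetical, the IntBlock accessors are the real API:

    import org.elasticsearch.compute.data.IntBlock;

    // Hypothetical helper: walks an IntBlock that may contain nulls and
    // multi-valued positions, mirroring the addRawInput(...) bodies above.
    final class IntBlockWalkSketch {
        static void consume(IntBlock groupIds) {
            for (int p = 0; p < groupIds.getPositionCount(); p++) {
                if (groupIds.isNull(p)) {
                    continue; // null positions (e.g. a constant-null block) contribute nothing
                }
                int first = groupIds.getFirstValueIndex(p);
                int end = first + groupIds.getValueCount(p);
                for (int i = first; i < end; i++) {
                    int groupId = groupIds.getInt(i);
                    // feed groupId into per-group state here
                }
            }
        }
    }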
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st index 70154ff682a44..19175b6a7b284 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Block.java.st @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ @@ -18,7 +17,6 @@ import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.index.mapper.BlockLoader; import java.io.IOException; -// end generated imports /** * Block that stores $type$ values. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st index 97adf2871909b..6553011e5b413 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-BlockBuilder.java.st @@ -7,16 +7,21 @@ package org.elasticsearch.compute.data; -// begin generated imports +$if(BytesRef)$ import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.$Array$; +import org.elasticsearch.common.util.BytesRefArray; import org.elasticsearch.core.Releasables; +$else$ +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.util.$Array$; + import java.util.Arrays; -// end generated imports +$endif$ /** * Block build of $Type$Blocks. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ConstantVector.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ConstantVector.java.st index 37cd9e4a82b14..ebac760031678 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ConstantVector.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ConstantVector.java.st @@ -7,17 +7,19 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.unit.ByteSizeValue; +$if(BytesRef)$ import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; + +$else$ import org.elasticsearch.core.ReleasableIterator; -// end generated imports +$endif$ /** * Vector implementation that stores a constant $type$ value. * This class is generated. Edit {@code X-ConstantVector.java.st} instead. 
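For readers unfamiliar with the .java.st templates: dropping the // begin/end generated imports markers means each template now spells out its import section per $if(BytesRef)$ branch. Assuming the generator substitutes those branches verbatim, the X-ConstantVector template above would render to the following for the int specialization (the BytesRef branch additionally pulls in BytesRef and Releasables):

    import org.apache.lucene.util.RamUsageEstimator;
    import org.elasticsearch.common.unit.ByteSizeValue;
    import org.elasticsearch.core.ReleasableIterator;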
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Lookup.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Lookup.java.st index 743f4c8b4ea57..ad3d93a76ad40 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Lookup.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Lookup.java.st @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ @@ -15,7 +14,6 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Generic {@link Block#lookup} implementation {@link $Type$Block}s. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st index 2996e83133017..47a7dc5735fd2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-Vector.java.st @@ -7,7 +7,6 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ @@ -18,7 +17,6 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import java.io.IOException; -// end generated imports /** * Vector that stores $type$ values. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st index ee4daca6edf45..5ab410e843eca 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-VectorBlock.java.st @@ -7,14 +7,12 @@ package org.elasticsearch.compute.data; -// begin generated imports $if(BytesRef)$ import org.apache.lucene.util.BytesRef; $endif$ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.ReleasableIterator; import org.elasticsearch.core.Releasables; -// end generated imports /** * Block view of a {@link $Type$Vector}. Cannot represent multi-values or nulls. 
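The HashAggregationOperator hunk that follows replaces the two typed overloads with a single IntBlock entry point that first tries the vector fast path. A sketch of that dispatch idiom, assuming (as elsewhere in this codebase) that asVector() returns a non-null IntVector only when the block is a plain vector view — no nulls, no multivalues; the class and method names here are illustrative:

    import org.elasticsearch.compute.data.IntBlock;
    import org.elasticsearch.compute.data.IntVector;

    // Illustrative dispatch: try the dense vector path, fall back to the
    // general block path that must handle nulls and multi-valued positions.
    final class VectorFastPathSketch {
        void add(int positionOffset, IntBlock groupIds) {
            IntVector vector = groupIds.asVector();
            if (vector != null) {
                addVector(positionOffset, vector); // dense fast path
            } else {
                addBlock(positionOffset, groupIds); // general path
            }
        }

        private void addVector(int positionOffset, IntVector groupIds) { /* ... */ }

        private void addBlock(int positionOffset, IntBlock groupIds) { /* ... */ }
    }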
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 2c6f9312e64bc..75d5a5bc51323 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -20,8 +20,7 @@ import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; @@ -146,21 +145,17 @@ class AddInput implements GroupingAggregatorFunction.AddInput { long aggStart; @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - startAggEndHash(); - for (GroupingAggregatorFunction.AddInput p : prepared) { - p.add(positionOffset, groupIds); - } - end(); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - startAggEndHash(); - for (GroupingAggregatorFunction.AddInput p : prepared) { - p.add(positionOffset, groupIds); + public void add(int positionOffset, IntBlock groupIds) { + IntVector groupIdsVector = groupIds.asVector(); + if (groupIdsVector != null) { + add(positionOffset, groupIdsVector); + } else { + startAggEndHash(); + for (GroupingAggregatorFunction.AddInput p : prepared) { + p.add(positionOffset, groupIds); + } + end(); } - end(); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index 9b8246be49799..0dccce6040e64 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -12,17 +12,13 @@ import org.elasticsearch.common.util.BitArray; import org.elasticsearch.compute.ConstantBooleanExpressionEvaluator; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; -import org.elasticsearch.compute.aggregation.blockhash.BlockHashWrapper; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BlockTypeRandomizer; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -41,7 +37,6 @@ import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasables; -import org.elasticsearch.index.analysis.AnalysisRegistry; import 
org.elasticsearch.xpack.esql.core.type.DataType; import org.hamcrest.Matcher; @@ -50,7 +45,6 @@ import java.util.SortedSet; import java.util.TreeSet; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.DoubleStream; import java.util.stream.IntStream; import java.util.stream.LongStream; @@ -108,7 +102,7 @@ private Operator.OperatorFactory simpleWithMode( if (randomBoolean()) { supplier = chunkGroups(emitChunkSize, supplier); } - return new RandomizingHashAggregationOperatorFactory( + return new HashAggregationOperator.HashAggregationOperatorFactory( List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), mode, List.of(supplier.groupingAggregatorFactory(mode, channels(mode))), @@ -654,7 +648,8 @@ public AddInput prepareProcessPage(SeenGroupIds ignoredSeenGroupIds, Page page) return seen; }, page); - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { for (int offset = 0; offset < groupIds.getPositionCount(); offset += emitChunkSize) { try (IntBlock.Builder builder = blockFactory().newIntBlockBuilder(emitChunkSize)) { int endP = Math.min(groupIds.getPositionCount(), offset + emitChunkSize); @@ -687,16 +682,6 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { int[] chunk = new int[emitChunkSize]; @@ -781,84 +766,4 @@ public String describe() { }; } - /** - * Custom {@link HashAggregationOperator.HashAggregationOperatorFactory} implementation that - * randomizes the GroupIds block type passed to AddInput. - *
<p> - *     This helps testing the different overloads of - *     {@link org.elasticsearch.compute.aggregation.GroupingAggregatorFunction.AddInput#add} - * </p>
- */ - private record RandomizingHashAggregationOperatorFactory( - List groups, - AggregatorMode aggregatorMode, - List aggregators, - int maxPageSize, - AnalysisRegistry analysisRegistry - ) implements Operator.OperatorFactory { - - @Override - public Operator get(DriverContext driverContext) { - Supplier blockHashSupplier = () -> { - BlockHash blockHash = groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize) - ? BlockHash.buildCategorizeBlockHash( - groups, - aggregatorMode, - driverContext.blockFactory(), - analysisRegistry, - maxPageSize - ) - : BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false); - - return new BlockHashWrapper(driverContext.blockFactory(), blockHash) { - @Override - public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { - blockHash.add(page, new GroupingAggregatorFunction.AddInput() { - @Override - public void add(int positionOffset, IntBlock groupIds) { - IntBlock newGroupIds = aggregatorMode.isInputPartial() - ? groupIds - : BlockTypeRandomizer.randomizeBlockType(groupIds); - addInput.add(positionOffset, newGroupIds); - } - - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - add(positionOffset, (IntBlock) groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - add(positionOffset, (IntBlock) groupIds); - } - - @Override - public void add(int positionOffset, IntVector groupIds) { - add(positionOffset, groupIds.asBlock()); - } - - @Override - public void close() { - addInput.close(); - } - }); - } - }; - }; - - return new HashAggregationOperator(aggregators, blockHashSupplier, driverContext); - } - - @Override - public String describe() { - return new HashAggregationOperator.HashAggregationOperatorFactory( - groups, - aggregatorMode, - aggregators, - maxPageSize, - analysisRegistry - ).describe(); - } - } - } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/AddPageTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/AddPageTests.java index c9628cc8074de..fb8b01e68c6cc 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/AddPageTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/AddPageTests.java @@ -11,8 +11,6 @@ import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockFactoryTests; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.test.ESTestCase; @@ -158,7 +156,8 @@ Added added(int positionOffset, int... 
ords) { private class TestAddInput implements GroupingAggregatorFunction.AddInput { private final List added = new ArrayList<>(); - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { List> result = new ArrayList<>(groupIds.getPositionCount()); for (int p = 0; p < groupIds.getPositionCount(); p++) { int valueCount = groupIds.getValueCount(p); @@ -173,19 +172,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { added.add(new Added(positionOffset, result)); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -198,12 +187,7 @@ private class CountingAddInput implements GroupingAggregatorFunction.AddInput { private int count; @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - count++; - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { count++; } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java index 34cd299811470..ed86969a62227 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java @@ -19,8 +19,6 @@ import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -46,6 +44,7 @@ import java.util.stream.LongStream; import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -1277,13 +1276,7 @@ public void close() { ) { hash1.add(page, new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - groupIds.incRef(); - output1.add(new Output(positionOffset, groupIds, null)); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { groupIds.incRef(); output1.add(new Output(positionOffset, groupIds, null)); } @@ -1301,13 +1294,7 @@ public void close() { }); hash2.add(page, new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - groupIds.incRef(); - output2.add(new Output(positionOffset, groupIds, null)); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { groupIds.incRef(); output2.add(new Output(positionOffset, groupIds, null)); } @@ -1329,8 +1316,7 @@ 
public void close() { Output o2 = output2.get(i); assertThat(o1.offset, equalTo(o2.offset)); if (o1.vector != null) { - assertNull(o1.block); - assertThat(o1.vector, equalTo(o2.vector != null ? o2.vector : o2.block.asVector())); + assertThat(o1.vector, either(equalTo(o2.vector)).or(equalTo(o2.block.asVector()))); } else { assertNull(o2.vector); assertThat(o1.block, equalTo(o2.block)); @@ -1394,12 +1380,7 @@ public void testTimeSeriesBlockHash() throws Exception { Holder ords1 = new Holder<>(); hash1.add(page, new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - throw new AssertionError("time-series block hash should emit a vector"); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { + public void add(int positionOffset, IntBlock groupIds) { throw new AssertionError("time-series block hash should emit a vector"); } @@ -1416,7 +1397,8 @@ public void close() { }); Holder ords2 = new Holder<>(); hash2.add(page, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { // TODO: check why PackedValuesBlockHash doesn't emit a vector? IntVector vector = groupIds.asVector(); assertNotNull("should emit a vector", vector); @@ -1424,16 +1406,6 @@ private void addBlock(int positionOffset, IntBlock groupIds) { ords2.set(vector); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { groupIds.incRef(); @@ -1526,7 +1498,8 @@ private BlockHash buildBlockHash(int emitBatchSize, Block... values) { static void hash(boolean collectKeys, BlockHash blockHash, Consumer callback, Block... values) { blockHash.add(new Page(values), new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { OrdsAndKeys result = new OrdsAndKeys( blockHash.toString(), positionOffset, @@ -1559,19 +1532,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashWrapper.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashWrapper.java deleted file mode 100644 index 0c93b6ff111cc..0000000000000 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashWrapper.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.compute.aggregation.blockhash; - -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.core.ReleasableIterator; - -/** - * A test BlockHash that wraps another one. - *
<p> - *     Its methods can be overridden to implement custom behaviours or checks. - * </p>
- */ -public abstract class BlockHashWrapper extends BlockHash { - protected BlockHash blockHash; - - public BlockHashWrapper(BlockFactory blockFactory, BlockHash blockHash) { - super(blockFactory); - this.blockHash = blockHash; - } - - @Override - public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { - blockHash.add(page, addInput); - } - - @Override - public ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { - return blockHash.lookup(page, targetBlockSize); - } - - @Override - public Block[] getKeys() { - return blockHash.getKeys(); - } - - @Override - public IntVector nonEmpty() { - return blockHash.nonEmpty(); - } - - @Override - public BitArray seenGroupIds(BigArrays bigArrays) { - return blockHash.seenGroupIds(bigArrays); - } - - @Override - public void close() { - blockHash.close(); - } - - @Override - public String toString() { - return blockHash.toString(); - } -} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 842952f9ef8bd..b1319e65e0989 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -25,8 +25,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntArrayBlock; -import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -101,7 +99,8 @@ public void testCategorizeRaw() { try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { assertEquals(groupIds.getPositionCount(), positions); assertEquals(1, groupIds.getInt(0)); @@ -116,19 +115,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -173,7 +162,8 @@ public void testCategorizeRawMultivalue() { try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) { for (int i = randomInt(2); i < 3; i++) { hash.add(page, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { assertEquals(groupIds.getPositionCount(), positions); assertThat(groupIds.getFirstValueIndex(0), equalTo(0)); @@ -195,19 +185,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock 
groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -263,7 +243,8 @@ public void testCategorizeIntermediate() { BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); ) { rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { assertEquals(groupIds.getPositionCount(), positions1); assertEquals(1, groupIds.getInt(0)); assertEquals(2, groupIds.getInt(1)); @@ -277,19 +258,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -300,7 +271,8 @@ public void close() { intermediatePage1 = new Page(rawHash1.getKeys()[0]); rawHash2.add(page2, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { assertEquals(groupIds.getPositionCount(), positions2); assertEquals(1, groupIds.getInt(0)); assertEquals(2, groupIds.getInt(1)); @@ -309,19 +281,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { assertEquals(3, groupIds.getInt(4)); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -337,7 +299,8 @@ public void close() { try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) { intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, IntBlock groupIds) { List values = IntStream.range(0, groupIds.getPositionCount()) .map(groupIds::getInt) .boxed() @@ -349,19 +312,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { } } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override @@ -372,7 +325,8 @@ public void close() { for (int i = randomInt(2); i < 3; i++) { intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() { - private void addBlock(int positionOffset, IntBlock groupIds) { + @Override + public void add(int positionOffset, 
IntBlock groupIds) { List values = IntStream.range(0, groupIds.getPositionCount()) .map(groupIds::getInt) .boxed() @@ -382,19 +336,9 @@ private void addBlock(int positionOffset, IntBlock groupIds) { assertEquals(List.of(3, 1, 4), values); } - @Override - public void add(int positionOffset, IntArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - - @Override - public void add(int positionOffset, IntBigArrayBlock groupIds) { - addBlock(positionOffset, groupIds); - } - @Override public void add(int positionOffset, IntVector groupIds) { - addBlock(positionOffset, groupIds.asBlock()); + add(positionOffset, groupIds.asBlock()); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java deleted file mode 100644 index 9640cbb2b44ea..0000000000000 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTypeRandomizer.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.data; - -import org.elasticsearch.compute.test.TestBlockFactory; - -import java.util.BitSet; - -import static org.elasticsearch.test.ESTestCase.randomIntBetween; - -public class BlockTypeRandomizer { - private BlockTypeRandomizer() {} - - /** - * Returns a block with the same contents, but with a randomized type (Constant, vector, big-array...). - *
<p> - *     The new block uses a non-breaking block builder, and doesn't increment the circuit breaking. - *     This is done to avoid randomly using more memory in tests that expect a deterministic memory usage. - * </p>
- */ - public static IntBlock randomizeBlockType(IntBlock block) { - // Just to track the randomization - int classCount = 4; - - BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance(); - - // - // ConstantNullBlock. It requires all positions to be null - // - if (randomIntBetween(0, --classCount) == 0 && block.areAllValuesNull()) { - if (block instanceof ConstantNullBlock) { - return block; - } - return new ConstantNullBlock(block.getPositionCount(), blockFactory); - } - - // - // IntVectorBlock. It doesn't allow nulls or multivalues - // - if (randomIntBetween(0, --classCount) == 0 && block.doesHaveMultivaluedFields() == false && block.mayHaveNulls() == false) { - if (block instanceof IntVectorBlock) { - return block; - } - - int[] values = new int[block.getPositionCount()]; - for (int i = 0; i < values.length; i++) { - values[i] = block.getInt(i); - } - - return new IntVectorBlock(new IntArrayVector(values, block.getPositionCount(), blockFactory)); - } - - // Both IntArrayBlock and IntBigArrayBlock need a nullsBitSet and a firstValueIndexes int[] - int[] firstValueIndexes = new int[block.getPositionCount() + 1]; - BitSet nullsMask = new BitSet(block.getPositionCount()); - for (int i = 0; i < block.getPositionCount(); i++) { - firstValueIndexes[i] = block.getFirstValueIndex(i); - - if (block.isNull(i)) { - nullsMask.set(i); - } - } - int totalValues = block.getFirstValueIndex(block.getPositionCount() - 1) + block.getValueCount(block.getPositionCount() - 1); - firstValueIndexes[firstValueIndexes.length - 1] = totalValues; - - // - // IntArrayBlock - // - if (randomIntBetween(0, --classCount) == 0) { - if (block instanceof IntArrayBlock) { - return block; - } - - int[] values = new int[totalValues]; - for (int i = 0; i < values.length; i++) { - values[i] = block.getInt(i); - } - - return new IntArrayBlock(values, block.getPositionCount(), firstValueIndexes, nullsMask, block.mvOrdering(), blockFactory); - } - assert classCount == 1; - - // - // IntBigArrayBlock - // - if (block instanceof IntBigArrayBlock) { - return block; - } - - var intArray = blockFactory.bigArrays().newIntArray(totalValues); - for (int i = 0; i < block.getPositionCount(); i++) { - intArray.set(i, block.getInt(i)); - } - - return new IntBigArrayBlock(intArray, block.getPositionCount(), firstValueIndexes, nullsMask, block.mvOrdering(), blockFactory); - } -} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec index 49b16baf30f58..f860b1518750c 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec @@ -850,37 +850,3 @@ c:long | b:date 11 | 1984-05-01T00:00:00.000Z 11 | 1991-01-01T00:00:00.000Z ; - -resolveGroupingsBeforeResolvingImplicitReferencesToGroupings -required_capability: resolve_groupings_before_resolving_references_to_groupings_in_aggregations - -FROM employees -| STATS c = count(emp_no), b = BUCKET(hire_date, "1 year") + 1 year BY yr = BUCKET(hire_date, "1 year") -| SORT yr -| LIMIT 5 -; - -c:long | b:datetime | yr:datetime -11 | 1986-01-01T00:00:00.000Z | 1985-01-01T00:00:00.000Z -11 | 1987-01-01T00:00:00.000Z | 1986-01-01T00:00:00.000Z -15 | 1988-01-01T00:00:00.000Z | 1987-01-01T00:00:00.000Z -9 | 1989-01-01T00:00:00.000Z | 1988-01-01T00:00:00.000Z -13 | 1990-01-01T00:00:00.000Z | 1989-01-01T00:00:00.000Z -; - -resolveGroupingsBeforeResolvingExplicitReferencesToGroupings -required_capability: 
resolve_groupings_before_resolving_references_to_groupings_in_aggregations - -FROM employees -| STATS c = count(emp_no), b = yr + 1 year BY yr = BUCKET(hire_date, "1 year") -| SORT yr -| LIMIT 5 -; - -c:long | b:datetime | yr:datetime -11 | 1986-01-01T00:00:00.000Z | 1985-01-01T00:00:00.000Z -11 | 1987-01-01T00:00:00.000Z | 1986-01-01T00:00:00.000Z -15 | 1988-01-01T00:00:00.000Z | 1987-01-01T00:00:00.000Z -9 | 1989-01-01T00:00:00.000Z | 1988-01-01T00:00:00.000Z -13 | 1990-01-01T00:00:00.000Z | 1989-01-01T00:00:00.000Z -; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/sample.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/sample.csv-spec index 237eee40e60b7..9505001415f12 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/sample.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/sample.csv-spec @@ -4,7 +4,7 @@ // range. These stats should be correctly adjusted for the sampling. Furthermore, // they also assert the value of MV_COUNT(VALUES(...)), which is not adjusted for // the sampling and therefore gives the size of the sample. -// All ranges are very loose, so that the tests should practically never fail. +// All ranges are very loose, so that the tests should fail less than 1 in a billion. // The range checks are done in ES|QL, resulting in one boolean value (is_expected), // because the CSV tests don't support such assertions. @@ -40,10 +40,10 @@ required_capability: sample FROM employees | SAMPLE 0.5 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no), sum_emp_no = SUM(emp_no) - | EVAL is_expected = count >= 20 AND count <= 180 AND - values_count >= 10 AND values_count <= 90 AND + | EVAL is_expected = count >= 40 AND count <= 160 AND + values_count >= 20 AND values_count <= 80 AND avg_emp_no > 10010 AND avg_emp_no < 10090 AND - sum_emp_no > 20*10010 AND sum_emp_no < 180*10090 + sum_emp_no > 40*10010 AND sum_emp_no < 160*10090 | KEEP is_expected ; @@ -59,8 +59,8 @@ FROM employees | SAMPLE 0.5 | WHERE emp_no > 10050 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no) - | EVAL is_expected = count >= 5 AND count <= 95 AND - values_count >= 2 AND values_count <= 48 AND + | EVAL is_expected = count >= 10 AND count <= 90 AND + values_count >= 5 AND values_count <= 45 AND avg_emp_no > 10055 AND avg_emp_no < 10095 | KEEP is_expected ; @@ -77,8 +77,8 @@ FROM employees | WHERE emp_no <= 10050 | SAMPLE 0.5 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no) - | EVAL is_expected = count >= 5 AND count <= 95 AND - values_count >= 2 AND values_count <= 48 AND + | EVAL is_expected = count >= 10 AND count <= 90 AND + values_count >= 5 AND values_count <= 45 AND avg_emp_no > 10005 AND avg_emp_no < 10045 | KEEP is_expected ; @@ -95,8 +95,8 @@ FROM employees | SAMPLE 0.5 | SORT emp_no | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no) - | EVAL is_expected = count >= 20 AND count <= 180 AND - values_count >= 10 AND values_count <= 90 AND + | EVAL is_expected = count >= 40 AND count <= 160 AND + values_count >= 20 AND values_count <= 80 AND avg_emp_no > 10010 AND avg_emp_no < 10090 | KEEP is_expected ; @@ -113,8 +113,8 @@ FROM employees | SORT emp_no | SAMPLE 0.5 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no) - | EVAL is_expected = count >= 20 AND count <= 180 AND - values_count >= 10 AND values_count <= 90 AND + | EVAL is_expected = count 
>= 40 AND count <= 160 AND + values_count >= 20 AND values_count <= 80 AND avg_emp_no > 10010 AND avg_emp_no < 10090 | KEEP is_expected ; @@ -147,8 +147,8 @@ FROM employees | LIMIT 50 | SAMPLE 0.5 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)) - | EVAL is_expected = count >= 5 AND count <= 95 AND - values_count >= 2 AND values_count <= 48 + | EVAL is_expected = count >= 10 AND count <= 90 AND + values_count >= 5 AND values_count <= 45 | KEEP is_expected ; @@ -201,8 +201,8 @@ FROM employees | SAMPLE 0.8 | SAMPLE 0.9 | STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no) - | EVAL is_expected = count >= 20 AND count <= 180 AND - values_count >= 10 AND values_count <= 90 AND + | EVAL is_expected = count >= 40 AND count <= 160 AND + values_count >= 20 AND values_count <= 80 AND avg_emp_no > 10010 AND avg_emp_no < 10090 | KEEP is_expected ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 86dbf4734da02..610baeca91c51 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -3097,27 +3097,3 @@ ROW a = [1,2,3], b = 5 STD_DEV(a):double | STD_DEV(b):double 0.816496580927726 | 0.0 ; - -resolveGroupingsBeforeResolvingImplicitReferencesToGroupings -required_capability: resolve_groupings_before_resolving_references_to_groupings_in_aggregations - -FROM employees -| EVAL date = "2025-01-01"::datetime -| stats m = MAX(hire_date) BY d = (date == "2025-01-01") -; - -m:datetime | d:boolean -1999-04-30T00:00:00.000Z | true -; - -resolveGroupingsBeforeResolvingExplicitReferencesToGroupings -required_capability: resolve_groupings_before_resolving_references_to_groupings_in_aggregations - -FROM employees -| EVAL date = "2025-01-01"::datetime -| stats m = MAX(hire_date), x = d::int + 1 BY d = (date == "2025-01-01") -; - -m:datetime | x:integer | d:boolean -1999-04-30T00:00:00.000Z | 2 | true -; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupJoinTypesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupJoinTypesIT.java deleted file mode 100644 index 52c41e4056a8e..0000000000000 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupJoinTypesIT.java +++ /dev/null @@ -1,542 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */
-
-package org.elasticsearch.xpack.esql.action;
-
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
-import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.test.ESIntegTestCase;
-import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
-import org.elasticsearch.xpack.core.esql.action.ColumnInfo;
-import org.elasticsearch.xpack.esql.VerificationException;
-import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.plan.logical.join.Join;
-import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
-import org.elasticsearch.xpack.spatial.SpatialPlugin;
-import org.elasticsearch.xpack.unsignedlong.UnsignedLongMapperPlugin;
-import org.elasticsearch.xpack.versionfield.VersionFieldPlugin;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
-import java.util.stream.Collectors;
-
-import static org.elasticsearch.test.ESIntegTestCase.Scope.SUITE;
-import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
-import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
-import static org.elasticsearch.xpack.esql.core.type.DataType.BYTE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DOC_DATA_TYPE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.HALF_FLOAT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
-import static org.elasticsearch.xpack.esql.core.type.DataType.IP;
-import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
-import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
-import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;
-import static org.elasticsearch.xpack.esql.core.type.DataType.SCALED_FLOAT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.SHORT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.TSID_DATA_TYPE;
-import static org.hamcrest.Matchers.containsString;
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.nullValue;
-
-/**
- * This test suite tests the lookup join functionality in ESQL with various data types.
- * For each pair of types being tested, it builds a main index called "index" containing a single document with as many fields as
- * types being tested on the left of the pair, and then creates that many other lookup indexes, each with a single document containing
- * exactly two fields: the field to join on, and a field to return.
- * The assertion is that for valid combinations, the return result should exist, and for invalid combinations an exception should be thrown.
- * If no exception is thrown, and no result is returned, our validation rules are not aligned with the internal behaviour (ie. a bug).
- * Since the `LOOKUP JOIN` command requires the match field name to be the same between the main index and the lookup index,
- * we will have field names that correctly represent the type of the field in the main index, but not the type of the field
- * in the lookup index. This can be confusing, but it is important to remember that the field names are not the same as the types.
- * For example, if we are testing the pairs (double, double), (double, float), (float, double) and (float, float),
- * we will create the following indexes:
- * <ul>
- *     <li>index_double_double
- *     <pre>
- *         Index containing a single document with a field of type 'double' like:
- *         {
- *             "field_double": 1.0,  // this is mapped as type 'double'
- *             "other": "value"
- *         }
- *     </pre>
- *     </li>
- *     <li>index_double_float
- *     <pre>
- *         Index containing a single document with a field of type 'float' like:
- *         {
- *             "field_double": 1.0,  // this is mapped as type 'float' (a float with the name of the main index field)
- *             "other": "value"
- *         }
- *     </pre>
- *     </li>
- *     <li>index_float_double
- *     <pre>
- *         Index containing a single document with a field of type 'double' like:
- *         {
- *             "field_float": 1.0,  // this is mapped as type 'double' (a double with the name of the main index field)
- *             "other": "value"
- *         }
- *     </pre>
- *     </li>
- *     <li>index_float_float
- *     <pre>
- *         Index containing single document with a field of type 'float' like:
- *         {
- *             "field_float": 1.0,  // this is mapped as type 'float'
- *             "other": "value"
- *         }
- *     </pre>
- *     </li>
- *     <li>index
- *     <pre>
- *         Index containing document like:
- *         {
- *             "field_double": 1.0,  // this is mapped as type 'double'
- *             "field_float": 1.0    // this is mapped as type 'float'
- *         }
- *     </pre>
- *     </li>
- * </ul>
- * Note that the lookup indexes have fields with a name that matches the type in the main index, and not the type actually used in the
- * lookup index. Instead, the mapped type should be the type of the right-hand side of the pair being tested.
- * Then we can run queries like:
- * <pre>
- *     FROM index | LOOKUP JOIN index_double_float ON field_double | KEEP other
- * </pre>
- * And assert that the result exists and is equal to "value".
- */
-@ClusterScope(scope = SUITE, numClientNodes = 1, numDataNodes = 1)
-public class LookupJoinTypesIT extends ESIntegTestCase {
-    protected Collection<Class<? extends Plugin>> nodePlugins() {
-        return List.of(
-            EsqlPlugin.class,
-            MapperExtrasPlugin.class,
-            VersionFieldPlugin.class,
-            UnsignedLongMapperPlugin.class,
-            SpatialPlugin.class
-        );
-    }
-
-    private static final Map<String, TestConfigs> testConfigurations = new HashMap<>();
-    static {
-        // Initialize the test configurations for string tests
-        {
-            TestConfigs configs = testConfigurations.computeIfAbsent("strings", TestConfigs::new);
-            configs.addPasses(KEYWORD, KEYWORD);
-            configs.addPasses(TEXT, KEYWORD);
-            configs.addFailsUnsupported(KEYWORD, TEXT);
-        }
-
-        // Test integer types
-        var integerTypes = List.of(BYTE, SHORT, INTEGER);
-        {
-            TestConfigs configs = testConfigurations.computeIfAbsent("integers", TestConfigs::new);
-            for (DataType mainType : integerTypes) {
-                for (DataType lookupType : integerTypes) {
-                    configs.addPasses(mainType, lookupType);
-                }
-                // Long is currently treated differently in the validation, but we could consider changing that
-                configs.addFails(mainType, LONG);
-                configs.addFails(LONG, mainType);
-            }
-        }
-
-        // Test float and double
-        var floatTypes = List.of(HALF_FLOAT, FLOAT, DOUBLE, SCALED_FLOAT);
-        {
-            TestConfigs configs = testConfigurations.computeIfAbsent("floats", TestConfigs::new);
-            for (DataType mainType : floatTypes) {
-                for (DataType lookupType : floatTypes) {
-                    configs.addPasses(mainType, lookupType);
-                }
-            }
-        }
-
-        // Tests for mixed-numerical types
-        {
-            TestConfigs configs = testConfigurations.computeIfAbsent("mixed-numerical", TestConfigs::new);
-            for (DataType mainType : integerTypes) {
-                for (DataType lookupType : floatTypes) {
-                    // TODO: We should probably allow this, but we need to change the validation code in Join.java
-                    configs.addFails(mainType, lookupType);
-                    configs.addFails(lookupType, mainType);
-                }
-            }
-        }
-
-        // Tests for all unsupported types
-        DataType[] unsupported = Join.UNSUPPORTED_TYPES;
-        {
-            Collection<TestConfigs> existing = testConfigurations.values();
-            TestConfigs configs = testConfigurations.computeIfAbsent("unsupported", TestConfigs::new);
-            for (DataType type : unsupported) {
-                if (type == NULL
-                    || type == DOC_DATA_TYPE
-                    || type == TSID_DATA_TYPE
-                    || type == AGGREGATE_METRIC_DOUBLE
-                    || type.esType() == null
-                    || type.isCounter()
-                    || DataType.isRepresentable(type) == false) {
-                    // Skip unmappable types, or types not supported in ES|QL in general
-                    continue;
-                }
-                if (existingIndex(existing, type, type)) {
-                    // Skip existing configurations
-                    continue;
-                }
-                configs.addFailsUnsupported(type, type);
-            }
-        }
-
-        // Tests for all types where left and right are the same type
-        DataType[] supported = { BOOLEAN, LONG, INTEGER, DOUBLE, SHORT, BYTE, FLOAT, HALF_FLOAT, DATETIME, IP, KEYWORD, SCALED_FLOAT };
-        {
-            Collection<TestConfigs> existing = testConfigurations.values();
-            TestConfigs configs = testConfigurations.computeIfAbsent("same", TestConfigs::new);
-            for (DataType type : supported) {
-                assertThat("Claiming supported for unsupported type: " + type, List.of(unsupported).contains(type), is(false));
-                if (existingIndex(existing, type, type) == false) {
-                    // Only add the configuration if it doesn't already exist
-                    configs.addPasses(type, type);
-                }
-            }
-        }
-
-        // Assert that unsupported types are not in the supported list
-        for (DataType type : unsupported) {
-            assertThat("Claiming supported for unsupported type: " + type, List.of(supported).contains(type), is(false));
-        }
-
-        // Assert that unsupported+supported covers all types:
-        List<DataType> missing = new ArrayList<>();
-        for (DataType type : DataType.values()) {
-            boolean isUnsupported = List.of(unsupported).contains(type);
-            boolean isSupported = List.of(supported).contains(type);
-            if (isUnsupported == false && isSupported == false) {
-                missing.add(type);
-            }
-        }
-        assertThat(missing + " are not in the supported or unsupported list", missing.size(), is(0));
-
-        // Tests for all other type combinations
-        {
-            Collection<TestConfigs> existing = testConfigurations.values();
-            TestConfigs configs = testConfigurations.computeIfAbsent("others", TestConfigs::new);
-            for (DataType mainType : supported) {
-                for (DataType lookupType : supported) {
-                    if (existingIndex(existing, mainType, lookupType) == false) {
-                        // Only add the configuration if it doesn't already exist
-                        configs.addFails(mainType, lookupType);
-                    }
-                }
-            }
-        }
-
-        // Make sure we have never added two configurations with the same index name
-        Set<String> knownTypes = new HashSet<>();
-        for (TestConfigs configs : testConfigurations.values()) {
-            for (TestConfig config : configs.configs.values()) {
-                if (knownTypes.contains(config.indexName())) {
-                    throw new IllegalArgumentException("Duplicate index name: " + config.indexName());
-                }
-                knownTypes.add(config.indexName());
-            }
-        }
-    }
-
-    private static boolean existingIndex(Collection<TestConfigs> existing, DataType mainType, DataType lookupType) {
-        String indexName = "index_" + mainType.esType() + "_" + lookupType.esType();
-        return existing.stream().anyMatch(c -> c.exists(indexName));
-    }
-
-    public void testLookupJoinStrings() {
-        testLookupJoinTypes("strings");
-    }
-
-    public void testLookupJoinIntegers() {
-        testLookupJoinTypes("integers");
-    }
-
-    public void testLookupJoinFloats() {
-        testLookupJoinTypes("floats");
-    }
-
-    public void testLookupJoinMixedNumerical() {
-        testLookupJoinTypes("mixed-numerical");
-    }
-
-    public void testLookupJoinSame() {
-        testLookupJoinTypes("same");
-    }
-
-    public void testLookupJoinUnsupported() {
-        testLookupJoinTypes("unsupported");
-    }
-
-    public void testLookupJoinOthers() {
-        testLookupJoinTypes("others");
-    }
-
-    private void testLookupJoinTypes(String group) {
-        initIndexes(group);
-        initData(group);
-        for (TestConfig config : testConfigurations.get(group).configs.values()) {
-            String query = String.format(
-                Locale.ROOT,
-                "FROM index | LOOKUP JOIN %s ON %s | KEEP other",
-                config.indexName(),
-                config.fieldName()
-            );
-            config.validateMainIndex();
-            config.validateLookupIndex();
-            config.testQuery(query);
-        }
-    }
-
-    private void initIndexes(String group) {
-        Collection<TestConfig> configs = testConfigurations.get(group).configs.values();
-        String propertyPrefix = "{\n \"properties\" : {\n";
-        String propertySuffix = " }\n}\n";
-        // The main index will have many fields, one of each type to use in later type specific joins
-        String mainFields = propertyPrefix + configs.stream()
-            .map(TestConfig::mainPropertySpec)
-            .distinct()
-            .collect(Collectors.joining(",\n ")) + propertySuffix;
-        assertAcked(prepareCreate("index").setMapping(mainFields));
-
-        Settings.Builder settings = Settings.builder()
-            .put("index.number_of_shards", 1)
-            .put("index.number_of_replicas", 0)
-            .put("index.mode", "lookup");
-        configs.forEach(
-            // Each lookup index will get a document with a field to join on, and a results field to get back
-            (c) -> assertAcked(
-                prepareCreate(c.indexName()).setSettings(settings.build())
-                    .setMapping(propertyPrefix + c.lookupPropertySpec() + propertySuffix)
-            )
-        );
-    }
-
-    private void initData(String group) {
-        Collection<TestConfig> configs = testConfigurations.get(group).configs.values();
-        int docId = 0;
-        for (TestConfig config : configs) {
-            String doc = String.format(Locale.ROOT, """
-                {
-                    %s,
-                    "other": "value"
-                }
-                """, lookupPropertyFor(config));
-            index(config.indexName(), "" + (++docId), doc);
-            refresh(config.indexName());
-        }
-        List<String> mainProperties = configs.stream().map(this::mainPropertyFor).distinct().collect(Collectors.toList());
-        index("index", "1", String.format(Locale.ROOT, """
-            {
-                %s
-            }
-            """, String.join(",\n ", mainProperties)));
-        refresh("index");
-    }
-
-    private String lookupPropertyFor(TestConfig config) {
-        return String.format(Locale.ROOT, "\"%s\": %s", config.fieldName(), sampleDataTextFor(config.lookupType()));
-    }
-
-    private String mainPropertyFor(TestConfig config) {
-        return String.format(Locale.ROOT, "\"%s\": %s", config.fieldName(), sampleDataTextFor(config.mainType()));
-    }
-
-    private static String sampleDataTextFor(DataType type) {
-        var value = sampleDataFor(type);
-        if (value instanceof String) {
-            return "\"" + value + "\"";
-        }
-        return String.valueOf(value);
-    }
-
-    private static final double SCALING_FACTOR = 10.0;
-
-    private static Object sampleDataFor(DataType type) {
-        return switch (type) {
-            case BOOLEAN -> true;
-            case DATETIME, DATE_NANOS -> "2025-04-02T12:00:00.000Z";
-            case IP -> "127.0.0.1";
-            case KEYWORD, TEXT -> "key";
-            case BYTE, SHORT, INTEGER -> 1;
-            case LONG, UNSIGNED_LONG -> 1L;
-            case HALF_FLOAT, FLOAT, DOUBLE, SCALED_FLOAT -> 1.0;
-            case VERSION -> "1.2.19";
-            case GEO_POINT, CARTESIAN_POINT -> "POINT (1.0 2.0)";
-            case GEO_SHAPE, CARTESIAN_SHAPE -> "POLYGON ((0.0 0.0, 1.0 0.0, 1.0 1.0, 0.0 1.0, 0.0 0.0))";
-            default -> throw new IllegalArgumentException("Unsupported type: " + type);
-        };
-    }
-
-    private static class TestConfigs {
-        final String group;
-        final Map<String, TestConfig> configs;
-
-        TestConfigs(String group) {
-            this.group = group;
-            this.configs = new LinkedHashMap<>();
-        }
-
-        private boolean exists(String indexName) {
-            return configs.containsKey(indexName);
-        }
-
-        private void add(TestConfig config) {
-            if (configs.containsKey(config.indexName())) {
-                throw new IllegalArgumentException("Duplicate index name: " + config.indexName());
-            }
-            configs.put(config.indexName(), config);
-        }
-
-        private void addPasses(DataType mainType, DataType lookupType) {
-            add(new TestConfigPasses(mainType, lookupType, true));
-        }
-
-        private void addFails(DataType mainType, DataType lookupType) {
-            String fieldName = "field_" + mainType.esType();
-            String errorMessage = String.format(
-                Locale.ROOT,
-                "JOIN left field [%s] of type [%s] is incompatible with right field [%s] of type [%s]",
-                fieldName,
-                mainType.widenSmallNumeric(),
-                fieldName,
-                lookupType.widenSmallNumeric()
-            );
-            add(
-                new TestConfigFails<>(
-                    mainType,
-                    lookupType,
-                    VerificationException.class,
-                    e -> assertThat(e.getMessage(), containsString(errorMessage))
-                )
-            );
-        }
-
-        private void addFailsUnsupported(DataType mainType, DataType lookupType) {
-            String fieldName = "field_" + mainType.esType();
-            String errorMessage = String.format(
-                Locale.ROOT,
-                "JOIN with right field [%s] of type [%s] is not supported",
-                fieldName,
-                lookupType
-            );
-            add(
-                new TestConfigFails<>(
-                    mainType,
-                    lookupType,
-                    VerificationException.class,
-                    e -> assertThat(e.getMessage(), containsString(errorMessage))
-                )
-            );
-        }
-    }
-
-    interface TestConfig {
-        DataType mainType();
-
-        DataType lookupType();
-
-        default String indexName() {
-            return "index_" + mainType().esType() + "_" + lookupType().esType();
-        }
-
-        default String fieldName() {
-            return "field_" + mainType().esType();
-        }
-
-        default String mainPropertySpec() {
-            return propertySpecFor(fieldName(), mainType(), "");
-        }
-
-        default String lookupPropertySpec() {
-            return propertySpecFor(fieldName(), lookupType(), ", \"other\": { \"type\" : \"keyword\" }");
-        }
-
-        /** Make sure the left index has the expected fields and types */
-        default void validateMainIndex() {
-            validateIndex("index", fieldName(), sampleDataFor(mainType()));
-        }
-
-        /** Make sure the lookup index has the expected fields and types */
-        default void validateLookupIndex() {
-            validateIndex(indexName(), fieldName(), sampleDataFor(lookupType()));
-        }
-
-        void testQuery(String query);
-    }
-
-    private static String propertySpecFor(String fieldName, DataType type, String extra) {
-        if (type == SCALED_FLOAT) {
-            return String.format(
-                Locale.ROOT,
-                "\"%s\": { \"type\" : \"%s\", \"scaling_factor\": %f }",
-                fieldName,
-                type.esType(),
-                SCALING_FACTOR
-            ) + extra;
-        }
-        return String.format(Locale.ROOT, "\"%s\": { \"type\" : \"%s\" }", fieldName, type.esType().replaceAll("cartesian_", "")) + extra;
-    }
-
-    private static void validateIndex(String indexName, String fieldName, Object expectedValue) {
-        String query = String.format(Locale.ROOT, "FROM %s | KEEP %s", indexName, fieldName);
-        try (var response = EsqlQueryRequestBuilder.newRequestBuilder(client()).query(query).get()) {
-            ColumnInfo info = response.response().columns().getFirst();
-            assertThat("Expected index '" + indexName + "' to have column '" + fieldName + ": " + query, info.name(), is(fieldName));
-            Iterator<Object> results = response.response().column(0).iterator();
-            assertTrue("Expected at least one result for query: " + query, results.hasNext());
-            Object indexedResult = response.response().column(0).iterator().next();
-            assertThat("Expected valid result: " + query, indexedResult, is(expectedValue));
-        }
-    }
-
-    private record TestConfigPasses(DataType mainType, DataType lookupType, boolean hasResults) implements TestConfig {
-        @Override
-        public void testQuery(String query) {
-            try (var response = EsqlQueryRequestBuilder.newRequestBuilder(client()).query(query).get()) {
-                Iterator<Object> results = response.response().column(0).iterator();
-                assertTrue("Expected at least one result for query: " + query, results.hasNext());
-                Object indexedResult = response.response().column(0).iterator().next();
-                if (hasResults) {
-                    assertThat("Expected valid result: " + query, indexedResult, equalTo("value"));
-                } else {
-                    assertThat("Expected empty results for query: " + query, indexedResult, is(nullValue()));
-                }
-            }
-        }
-    }
-
-    private record TestConfigFails<E extends Exception>(DataType mainType, DataType lookupType, Class<E> exception, Consumer<E> assertion)
-        implements
-            TestConfig {
-        @Override
-        public void testQuery(String query) {
-            E e = expectThrows(
-                exception(),
-                "Expected exception " + exception().getSimpleName() + " but no exception was thrown: " + query,
-                () -> {
-                    // noinspection EmptyTryBlock
-                    try (var ignored = EsqlQueryRequestBuilder.newRequestBuilder(client()).query(query).get()) {
-                        // We use try-with-resources to ensure the request is closed if the exception is not thrown (less cluttered errors)
-                    }
-                }
-            );
-            assertion().accept(e);
-        }
-    }
-}
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBooleanEvaluator.java
b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBooleanEvaluator.java index 15f976d6e4090..97b4cba0d9938 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBooleanEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBooleanEvaluator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.nulls; -// begin generated imports import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.Page; @@ -21,7 +20,6 @@ import java.util.List; import java.util.stream.IntStream; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBytesRefEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBytesRefEvaluator.java index 547c325ccf132..7d6834e765a96 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBytesRefEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceBytesRefEvaluator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.nulls; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BytesRefBlock; @@ -22,7 +21,6 @@ import java.util.List; import java.util.stream.IntStream; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceDoubleEvaluator.java index a6c36ea2aac4a..4c01a961ecbee 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceDoubleEvaluator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.nulls; -// begin generated imports import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; @@ -21,7 +20,6 @@ import java.util.List; import java.util.stream.IntStream; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}. 
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceIntEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceIntEvaluator.java index b4be642f34f84..e90bd4b8e5e35 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceIntEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceIntEvaluator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.nulls; -// begin generated imports import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.Page; @@ -21,7 +20,6 @@ import java.util.List; import java.util.stream.IntStream; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceLongEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceLongEvaluator.java index 98a782abd1ed1..53a21ad1198f4 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceLongEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceLongEvaluator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.nulls; -// begin generated imports import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; @@ -21,7 +20,6 @@ import java.util.List; import java.util.stream.IntStream; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBooleanEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBooleanEvaluator.java index 5db804c9a4852..a49c191b90fa9 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBooleanEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBooleanEvaluator.java @@ -7,12 +7,8 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; -import org.elasticsearch.compute.data.BooleanBlock; -import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -24,7 +20,6 @@ import java.util.Arrays; import java.util.BitSet; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}. 
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBytesRefEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBytesRefEvaluator.java index 7113f004c17c9..15703773ca15a 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBytesRefEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InBytesRefEvaluator.java @@ -7,13 +7,11 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; -// begin generated imports import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; -import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; @@ -24,7 +22,6 @@ import java.util.Arrays; import java.util.BitSet; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InDoubleEvaluator.java index 99ffa891b9c7c..ce03061ce2729 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InDoubleEvaluator.java @@ -7,13 +7,10 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; @@ -24,7 +21,6 @@ import java.util.Arrays; import java.util.BitSet; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}. 
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InIntEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InIntEvaluator.java index d6c160c0e45d3..077ec4d473794 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InIntEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InIntEvaluator.java @@ -7,13 +7,10 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; @@ -24,7 +21,6 @@ import java.util.Arrays; import java.util.BitSet; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}. diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InLongEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InLongEvaluator.java index 9f9b05b4c9c54..e54a89865c41b 100644 --- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InLongEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InLongEvaluator.java @@ -7,13 +7,10 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; -// begin generated imports -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; -import org.elasticsearch.compute.data.BooleanVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; @@ -24,7 +21,6 @@ import java.util.Arrays; import java.util.BitSet; -// end generated imports /** * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}. 
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InMillisNanosEvaluator.java
index d95aa2f52550e..14ca1e9a768fb 100644
--- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InMillisNanosEvaluator.java
+++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InMillisNanosEvaluator.java
@@ -7,13 +7,10 @@
 package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison;
 
-// begin generated imports
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
-import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
@@ -24,7 +21,6 @@
 import java.util.Arrays;
 import java.util.BitSet;
-// end generated imports
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}.
diff --git a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InNanosMillisEvaluator.java
index 6461acde51187..278ecdba7d9e5 100644
--- a/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InNanosMillisEvaluator.java
+++ b/x-pack/plugin/esql/src/main/generated-src/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InNanosMillisEvaluator.java
@@ -7,13 +7,10 @@
 package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison;
 
-// begin generated imports
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
-import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
@@ -24,7 +21,6 @@
 import java.util.Arrays;
 import java.util.BitSet;
-// end generated imports
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 3e2ffa706b441..1e0577193cab2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -173,8 +173,13 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerContext> {
         new Batch<>(
             "Resolution",
-            new ResolveRefs(),
+            /*
+             * ImplicitCasting must be before ResolveRefs. Because a reference is created for a Bucket in Aggregate's aggregates,
+             * resolving this reference before implicit casting may cause this reference to have customMessage=true, which prevents
+             * further attempts to resolve this reference.
+             */
             new ImplicitCasting(),
+            new ResolveRefs(),
             new ResolveUnionTypes() // Must be after ResolveRefs, so union types can be found
         ),
         new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new AddImplicitForkLimit(), new UnionTypesCleanup())
@@ -569,7 +574,7 @@ private Aggregate resolveAggregate(Aggregate aggregate, List<LogicalPlan> children) {
                 }
             }
 
-            if (Resolvables.resolved(groupings) == false || Resolvables.resolved(aggregates) == false) {
+            if (Resolvables.resolved(groupings) == false || (Resolvables.resolved(aggregates) == false)) {
                 ArrayList<Attribute> resolved = new ArrayList<>();
                 for (Expression e : groupings) {
                     Attribute attr = Expressions.attribute(e);
@@ -580,29 +585,17 @@ private Aggregate resolveAggregate(Aggregate aggregate, List<LogicalPlan> children) {
                 List<Attribute> resolvedList = NamedExpressions.mergeOutputAttributes(resolved, childrenOutput);
                 List<NamedExpression> newAggregates = new ArrayList<>();
 
-                // If the groupings are not resolved, skip the resolution of the references to groupings in the aggregates, resolve the
-                // aggregations that do not reference to groupings, so that the fields/attributes referenced by the aggregations can be
-                // resolved, and verifier doesn't report field/reference/column not found errors for them.
-                boolean groupingResolved = Resolvables.resolved(groupings);
-                int size = groupingResolved ? aggregates.size() : aggregates.size() - groupings.size();
-                for (int i = 0; i < aggregates.size(); i++) {
-                    NamedExpression maybeResolvedAgg = aggregates.get(i);
-                    if (i < size) { // Skip resolving references to groupings in the aggregations if the groupings are not resolved yet.
-                        maybeResolvedAgg = (NamedExpression) maybeResolvedAgg.transformUp(UnresolvedAttribute.class, ua -> {
-                            Expression ne = ua;
-                            Attribute maybeResolved = maybeResolveAttribute(ua, resolvedList);
-                            // An item in aggregations can reference to groupings explicitly, if groupings are not resolved yet and
-                            // maybeResolved is not resolved, return the original UnresolvedAttribute, so that it has another chance
-                            // to get resolved in the next iteration.
-                            // For example STATS c = count(emp_no), x = d::int + 1 BY d = (date == "2025-01-01")
-                            if (groupingResolved || maybeResolved.resolved()) {
-                                changed.set(true);
-                                ne = maybeResolved;
-                            }
-                            return ne;
-                        });
-                    }
-                    newAggregates.add(maybeResolvedAgg);
+                for (NamedExpression ag : aggregate.aggregates()) {
+                    var agg = (NamedExpression) ag.transformUp(UnresolvedAttribute.class, ua -> {
+                        Expression ne = ua;
+                        Attribute maybeResolved = maybeResolveAttribute(ua, resolvedList);
+                        if (maybeResolved != null) {
+                            changed.set(true);
+                            ne = maybeResolved;
+                        }
+                        return ne;
+                    });
+                    newAggregates.add(agg);
                 }
 
                 // TODO: remove this when Stats interface is removed
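The query shape affected by this reordering references a grouping from within the aggregate list, as in this example adapted from the bucket.csv-spec and stats.csv-spec tests removed earlier in this diff (a sketch of the affected shape, not a statement of the analyzer's exact output):

    FROM employees
    | STATS c = count(emp_no), b = yr + 1 year BY yr = BUCKET(hire_date, "1 year")
    | SORT yr
    | LIMIT 5

Here `b = yr + 1 year` refers to the grouping `yr`, and BUCKET's `"1 year"` argument is a string that implicit casting must first convert to a date period; resolving the reference before that cast is what could mark it with customMessage=true and block later resolution attempts.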
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/X-CoalesceEvaluator.java.st b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/X-CoalesceEvaluator.java.st
index 83a32ad6d0f2b..33841f03f7803 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/X-CoalesceEvaluator.java.st
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/X-CoalesceEvaluator.java.st
@@ -7,7 +7,6 @@
 package org.elasticsearch.xpack.esql.expression.function.scalar.nulls;
 
-// begin generated imports
 $if(BytesRef)$
 import org.apache.lucene.util.BytesRef;
 $endif$
@@ -24,7 +23,6 @@ import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
 
 import java.util.List;
 import java.util.stream.IntStream;
-// end generated imports
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Coalesce}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/X-InEvaluator.java.st b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/X-InEvaluator.java.st
index 41b9d36cd4749..0cdea82119cae 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/X-InEvaluator.java.st
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/X-InEvaluator.java.st
@@ -7,13 +7,26 @@
 package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison;
 
-// begin generated imports
+$if(BytesRef)$
 import org.apache.lucene.util.BytesRef;
+$endif$
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
-import org.elasticsearch.compute.data.$Type$Block;
-import org.elasticsearch.compute.data.$Type$Vector;
+$if(int)$
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+$elseif(long)$
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
+$elseif(double)$
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.DoubleVector;
+$elseif(BytesRef)$
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+$elseif(boolean)$
 import org.elasticsearch.compute.data.BooleanVector;
+$endif$
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
@@ -24,7 +37,6 @@ import org.elasticsearch.xpack.esql.core.tree.Source;
 
 import java.util.Arrays;
 import java.util.BitSet;
-// end generated imports
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} implementation for {@link In}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java index a32bf3a720088..169ac2ac8c0fe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java @@ -7,13 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; -import org.elasticsearch.xpack.esql.plan.logical.EsRelation; -import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.rule.Rule; @@ -59,26 +55,12 @@ public OptimizerExpressionRule(TransformDirection direction) { @Override public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) { return direction == TransformDirection.DOWN - ? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx)) - : plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx)); + ? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx)) + : plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx)); } protected abstract Expression rule(E e, LogicalOptimizerContext ctx); - /** - * Defines if a node should be visited or not. - * Allows to skip nodes that are not applicable for the rule even if they contain expressions. - * By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses. 
- */
-    protected boolean shouldVisit(Node<?> node) {
-        return switch (node) {
-            case EsRelation relation -> false;
-            case Project project -> false;// this covers project, keep and drop
-            case Limit limit -> false;
-            default -> true;
-        };
-    }
-
     public Class<E> expressionToken() {
         return expressionTypeToken;
     }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java
index 81a89950b0a02..6f92351981f14 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java
@@ -18,7 +18,6 @@
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.function.Predicate;
 
 /**
  * There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@@ -110,36 +109,22 @@ public <E extends Expression> PlanType transformExpressionsOnlyUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
         return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
     }
 
+    public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
+        return transformExpressionsDown(Expression.class, rule);
+    }
+
     public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
         return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
     }
 
-    public <E extends Expression> PlanType transformExpressionsDown(
-        Predicate<Node<?>> shouldVisit,
-        Class<E> typeToken,
-        Function<E, ? extends Expression> rule
-    ) {
-        return transformDown(
-            shouldVisit,
-            t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
-        );
+    public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
+        return transformExpressionsUp(Expression.class, rule);
     }
 
     public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
         return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
     }
 
-    public <E extends Expression> PlanType transformExpressionsUp(
-        Predicate<Node<?>> shouldVisit,
-        Class<E> typeToken,
-        Function<E, ? extends Expression> rule
-    ) {
-        return transformUp(
-            shouldVisit,
-            t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
-        );
-    }
-
     @SuppressWarnings("unchecked")
     private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
         if (arg instanceof Expression exp) {
@@ -199,10 +184,18 @@ public <E extends Expression> void forEachExpression(Class<E> typeToken, Consumer<? super E> rule) {
         forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
     }
 
+    public void forEachExpressionDown(Consumer<? super Expression> rule) {
+        forEachExpressionDown(Expression.class, rule);
+    }
+
     public <E extends Expression> void forEachExpressionDown(Class<E> typeToken, Consumer<? super E> rule) {
         forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
     }
 
+    public void forEachExpressionUp(Consumer<? super Expression> rule) {
+        forEachExpressionUp(Expression.class, rule);
+    }
+
     public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
         forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
     }
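The no-token overloads added above simply default the type token to Expression.class. A minimal usage sketch in Java; the `plan` variable, the rule, and the name collection are illustrative assumptions, not part of this change:

    // Visit every expression in the plan without spelling out Expression.class.
    List<String> unresolvedNames = new ArrayList<>();
    plan.forEachExpressionDown(e -> {
        if (e instanceof UnresolvedAttribute ua) {
            unresolvedNames.add(ua.name()); // record each unresolved reference we encounter
        }
    });

    // Shorthand transform; equivalent to plan.transformExpressionsUp(Expression.class, rule).
    LogicalPlan rewritten = plan.transformExpressionsUp(rule);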
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java
index 8e887d1e92c25..5e1afe1452d99 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java
@@ -18,7 +18,6 @@
 import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
-import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.plan.logical.BinaryPlan;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -26,59 +25,16 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.esql.common.Failure.fail;
-import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.COUNTER_DOUBLE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.COUNTER_INTEGER;
-import static org.elasticsearch.xpack.esql.core.type.DataType.COUNTER_LONG;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
-import static org.elasticsearch.xpack.esql.core.type.DataType.DOC_DATA_TYPE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_SHAPE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;
-import static org.elasticsearch.xpack.esql.core.type.DataType.OBJECT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.PARTIAL_AGG;
-import static org.elasticsearch.xpack.esql.core.type.DataType.SOURCE;
 import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
-import static org.elasticsearch.xpack.esql.core.type.DataType.TIME_DURATION;
-import static org.elasticsearch.xpack.esql.core.type.DataType.TSID_DATA_TYPE;
-import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
-import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
-import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION;
 import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
 import static org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes.LEFT;
 
 public class Join extends BinaryPlan implements PostAnalysisVerificationAware, SortAgnostic {
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Join", Join::new);
 
-    public static final DataType[] UNSUPPORTED_TYPES = {
-        TEXT,
-        VERSION,
-        UNSIGNED_LONG,
-        GEO_POINT,
-        GEO_SHAPE,
-        CARTESIAN_POINT,
-        CARTESIAN_SHAPE,
-        UNSUPPORTED,
-        NULL,
-        COUNTER_LONG,
-        COUNTER_INTEGER,
-        COUNTER_DOUBLE,
-        DATE_NANOS,
-        OBJECT,
-        SOURCE,
-        DATE_PERIOD,
-        TIME_DURATION,
-        DOC_DATA_TYPE,
-        TSID_DATA_TYPE,
-        PARTIAL_AGG,
-        AGGREGATE_METRIC_DOUBLE };
 
     private final JoinConfig config;
     private List<Attribute> lazyOutput;
@@ -261,7 +217,7 @@ public void postAnalysisVerification(Failures failures) {
         for (int i = 0; i < config.leftFields().size(); i++) {
             Attribute leftField = config.leftFields().get(i);
             Attribute rightField = config.rightFields().get(i);
-            if (comparableTypes(leftField, rightField) == false) {
+            if (leftField.dataType().noText() != rightField.dataType().noText()) {
                 failures.add(
                     fail(
                         leftField,
@@ -273,18 +229,11 @@ public void postAnalysisVerification(Failures failures) {
                     )
                 );
             }
-            // TODO: Add support for VERSION by implementing QueryList.versionTermQueryList similar to ipTermQueryList
-            if (Arrays.stream(UNSUPPORTED_TYPES).anyMatch(t -> rightField.dataType().equals(t))) {
+            if (rightField.dataType().equals(TEXT)) {
                 failures.add(
                     fail(leftField, "JOIN with right field [{}] of type [{}] is not supported", rightField.name(), rightField.dataType())
                 );
             }
         }
     }
-
-    private static boolean comparableTypes(Attribute left, Attribute right) {
-        // TODO: Consider allowing more valid types
-        // return left.dataType().noText() == right.dataType().noText() || left.dataType().isNumeric() == right.dataType().isNumeric();
-        return left.dataType().noText() == right.dataType().noText();
-    }
 }
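After this change the verifier flags two conditions only: join keys whose types differ once text types are normalized via noText(), and a TEXT field on the right-hand side. An illustrative failing query, reusing the index-naming convention of the LookupJoinTypesIT suite deleted above (the exact rendering of the type names in the message may differ):

    FROM index | LOOKUP JOIN index_integer_keyword ON field_integer | KEEP other
    // expected to fail verification with a message like:
    // JOIN left field [field_integer] of type [integer] is incompatible with right field [field_integer] of type [keyword]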
trackShardLevelFailure(shardId, receivedData, e); pendingShardIds.add(shardId); + maybeScheduleRetry(shardId, receivedData, e); } onAfter(DriverCompletionInfo.EMPTY); } @@ -314,6 +317,14 @@ public void onSkip() { onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())); } } + + private void maybeScheduleRetry(ShardId shardId, boolean receivedData, Exception e) { + if (receivedData == false + && targetShards.getShard(shardId).remainingNodes.isEmpty() + && unwrapFailure(shardId, e) instanceof NoShardAvailableActionException) { + pendingRetries.add(shardId); + } + } }); } diff --git a/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.codebases b/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.codebases new file mode 100644 index 0000000000000..ecae5129b3563 --- /dev/null +++ b/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.codebases @@ -0,0 +1 @@ +arrow: org.elasticsearch.xpack.esql.arrow.AllocationManagerShim diff --git a/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..22884437add88 --- /dev/null +++ b/x-pack/plugin/esql/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +grant codeBase "${codebase.arrow}" { + // Needed for AllocationManagerShim + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; +}; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f3afa24969f33..3a5f133702f57 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -47,11 +47,8 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; -import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; -import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.index.EsIndex; @@ -83,7 +80,6 @@ import org.elasticsearch.xpack.esql.session.IndexResolver; import java.io.IOException; -import java.time.Period; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -114,9 +110,6 @@ import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomValueOtherThanTest; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; 
-import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; -import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; -import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToString; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -375,7 +368,7 @@ public void testNoProjection() { DataType.INTEGER, DataType.KEYWORD, DataType.TEXT, - DATETIME, + DataType.DATETIME, DataType.TEXT, DataType.KEYWORD, DataType.INTEGER, @@ -3796,147 +3789,6 @@ public void testResolveCompletionOutputField() { assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField()))); } - public void testResolveGroupingsBeforeResolvingImplicitReferencesToGroupings() { - var plan = analyze(""" - FROM test - | EVAL date = "2025-01-01"::datetime - | STATS c = count(emp_no) BY d = (date == "2025-01-01") - """, "mapping-default.json"); - - var limit = as(plan, Limit.class); - var agg = as(limit.child(), Aggregate.class); - var aggregates = agg.aggregates(); - assertThat(aggregates, hasSize(2)); - Alias a = as(aggregates.get(0), Alias.class); - assertEquals("c", a.name()); - Count c = as(a.child(), Count.class); - FieldAttribute fa = as(c.field(), FieldAttribute.class); - assertEquals("emp_no", fa.name()); - ReferenceAttribute ra = as(aggregates.get(1), ReferenceAttribute.class); // reference in aggregates is resolved - assertEquals("d", ra.name()); - List groupings = agg.groupings(); - assertEquals(1, groupings.size()); - a = as(groupings.get(0), Alias.class); // reference in groupings is resolved - assertEquals("d", ra.name()); - Equals equals = as(a.child(), Equals.class); - ra = as(equals.left(), ReferenceAttribute.class); - assertEquals("date", ra.name()); - Literal literal = as(equals.right(), Literal.class); - assertEquals("2025-01-01T00:00:00.000Z", dateTimeToString(Long.parseLong(literal.value().toString()))); - assertEquals(DATETIME, literal.dataType()); - } - - public void testResolveGroupingsBeforeResolvingExplicitReferencesToGroupings() { - var plan = analyze(""" - FROM test - | EVAL date = "2025-01-01"::datetime - | STATS c = count(emp_no), x = d::int + 1 BY d = (date == "2025-01-01") - """, "mapping-default.json"); - - var limit = as(plan, Limit.class); - var agg = as(limit.child(), Aggregate.class); - var aggregates = agg.aggregates(); - assertThat(aggregates, hasSize(3)); - Alias a = as(aggregates.get(0), Alias.class); - assertEquals("c", a.name()); - Count c = as(a.child(), Count.class); - FieldAttribute fa = as(c.field(), FieldAttribute.class); - assertEquals("emp_no", fa.name()); - a = as(aggregates.get(1), Alias.class); // explicit reference to groupings is resolved - assertEquals("x", a.name()); - Add add = as(a.child(), Add.class); - ToInteger toInteger = as(add.left(), ToInteger.class); - ReferenceAttribute ra = as(toInteger.field(), ReferenceAttribute.class); - assertEquals("d", ra.name()); - ra = as(aggregates.get(2), ReferenceAttribute.class); // reference in aggregates is resolved - assertEquals("d", ra.name()); - List groupings = agg.groupings(); - assertEquals(1, groupings.size()); - a = as(groupings.get(0), Alias.class); // reference in groupings is resolved - assertEquals("d", ra.name()); - Equals equals = as(a.child(), Equals.class); - ra = as(equals.left(), ReferenceAttribute.class); - assertEquals("date", ra.name()); - Literal literal = as(equals.right(), Literal.class); - 
assertEquals("2025-01-01T00:00:00.000Z", dateTimeToString(Long.parseLong(literal.value().toString()))); - assertEquals(DATETIME, literal.dataType()); - } - - public void testBucketWithIntervalInStringInBothAggregationAndGrouping() { - var plan = analyze(""" - FROM test - | STATS c = count(emp_no), b = BUCKET(hire_date, "1 year") + 1 year BY yr = BUCKET(hire_date, "1 year") - """, "mapping-default.json"); - - var limit = as(plan, Limit.class); - var agg = as(limit.child(), Aggregate.class); - var aggregates = agg.aggregates(); - assertThat(aggregates, hasSize(3)); - Alias a = as(aggregates.get(0), Alias.class); - assertEquals("c", a.name()); - Count c = as(a.child(), Count.class); - FieldAttribute fa = as(c.field(), FieldAttribute.class); - assertEquals("emp_no", fa.name()); - a = as(aggregates.get(1), Alias.class); // explicit reference to groupings is resolved - assertEquals("b", a.name()); - Add add = as(a.child(), Add.class); - Bucket bucket = as(add.left(), Bucket.class); - fa = as(bucket.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - Literal literal = as(bucket.buckets(), Literal.class); - Literal oneYear = new Literal(EMPTY, Period.ofYears(1), DATE_PERIOD); - assertEquals(oneYear, literal); - literal = as(add.right(), Literal.class); - assertEquals(oneYear, literal); - ReferenceAttribute ra = as(aggregates.get(2), ReferenceAttribute.class); // reference in aggregates is resolved - assertEquals("yr", ra.name()); - List groupings = agg.groupings(); - assertEquals(1, groupings.size()); - a = as(groupings.get(0), Alias.class); // reference in groupings is resolved - assertEquals("yr", ra.name()); - bucket = as(a.child(), Bucket.class); - fa = as(bucket.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - literal = as(bucket.buckets(), Literal.class); - assertEquals(oneYear, literal); - } - - public void testBucketWithIntervalInStringInGroupingReferencedInAggregation() { - var plan = analyze(""" - FROM test - | STATS c = count(emp_no), b = yr + 1 year BY yr = BUCKET(hire_date, "1 year") - """, "mapping-default.json"); - - var limit = as(plan, Limit.class); - var agg = as(limit.child(), Aggregate.class); - var aggregates = agg.aggregates(); - assertThat(aggregates, hasSize(3)); - Alias a = as(aggregates.get(0), Alias.class); - assertEquals("c", a.name()); - Count c = as(a.child(), Count.class); - FieldAttribute fa = as(c.field(), FieldAttribute.class); - assertEquals("emp_no", fa.name()); - a = as(aggregates.get(1), Alias.class); // explicit reference to groupings is resolved - assertEquals("b", a.name()); - Add add = as(a.child(), Add.class); - ReferenceAttribute ra = as(add.left(), ReferenceAttribute.class); - assertEquals("yr", ra.name()); - Literal oneYear = new Literal(EMPTY, Period.ofYears(1), DATE_PERIOD); - Literal literal = as(add.right(), Literal.class); - assertEquals(oneYear, literal); - ra = as(aggregates.get(2), ReferenceAttribute.class); // reference in aggregates is resolved - assertEquals("yr", ra.name()); - List groupings = agg.groupings(); - assertEquals(1, groupings.size()); - a = as(groupings.get(0), Alias.class); // reference in groupings is resolved - assertEquals("yr", ra.name()); - Bucket bucket = as(a.child(), Bucket.class); - fa = as(bucket.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - literal = as(bucket.buckets(), Literal.class); - assertEquals(oneYear, literal); - } - @Override protected IndexAnalyzers createDefaultIndexAnalyzers() { return super.createDefaultIndexAnalyzers(); diff 
--git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index f66133d10a8e2..e220eac5ab7a1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -26,13 +26,9 @@ import org.elasticsearch.compute.test.BlockTestUtils; import org.elasticsearch.compute.test.TestBlockFactory; import org.elasticsearch.indices.CrankyCircuitBreakerService; -import org.elasticsearch.license.License; -import org.elasticsearch.license.XPackLicenseState; -import org.elasticsearch.license.internal.XPackLicenseStatus; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.esql.LicenseAware; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; @@ -58,6 +54,7 @@ import org.junit.After; import org.junit.AfterClass; +import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -733,8 +730,7 @@ public void testSerializationOfSimple() { */ @AfterClass public static void testFunctionInfo() { - Class testClass = getTestClass(); - Logger log = LogManager.getLogger(testClass); + Logger log = LogManager.getLogger(getTestClass()); FunctionDefinition definition = definition(functionName()); if (definition == null) { log.info("Skipping function info checks because the function isn't registered"); @@ -757,7 +753,7 @@ public static void testFunctionInfo() { for (int i = 0; i < args.size(); i++) { typesFromSignature.add(new HashSet<>()); } - for (Map.Entry, DataType> entry : signatures(testClass).entrySet()) { + for (Map.Entry, DataType> entry : signatures(getTestClass()).entrySet()) { List types = entry.getKey(); for (int i = 0; i < args.size() && i < types.size(); i++) { typesFromSignature.get(i).add(types.get(i).esNameIfPossible()); @@ -800,101 +796,6 @@ public static void testFunctionInfo() { assertEquals(returnFromSignature, returnTypes); } - /** - * This test is meant to validate that the license checks documented match those enforced. - * The expectations are set in the test class using a method with this signature: - * - * public static License.OperationMode licenseRequirement(List<DataType> fieldTypes); - * - * License enforcement in the function class is achieved using the interface LicenseAware. - * This test will make sure the two are in agreement, and does not require that the function class actually - * report its license level. If we add license checks to any function, but fail to also add the expected - * license level to the test class, this test will fail. 
- */ - @AfterClass - public static void testFunctionLicenseChecks() throws Exception { - Class testClass = getTestClass(); - Logger log = LogManager.getLogger(testClass); - FunctionDefinition definition = definition(functionName()); - if (definition == null) { - log.info("Skipping function info checks because the function isn't registered"); - return; - } - log.info("Running function license checks"); - DocsV3Support.LicenseRequirementChecker licenseChecker = new DocsV3Support.LicenseRequirementChecker(testClass); - License.OperationMode functionLicense = licenseChecker.invoke(null); - Constructor ctor = constructorWithFunctionInfo(definition.clazz()); - if (LicenseAware.class.isAssignableFrom(definition.clazz()) == false) { - // Perform simpler no-signature tests - assertThat( - "Function " + definition.name() + " should be licensed under " + functionLicense, - functionLicense, - equalTo(License.OperationMode.BASIC) - ); - return; - } - // For classes with LicenseAware, we need to check that the license is correct - TestCheckLicense checkLicense = new TestCheckLicense(); - - // Go through all signatures and assert that the license is as expected - signatures(testClass).forEach((signature, returnType) -> { - try { - License.OperationMode license = licenseChecker.invoke(signature); - assertNotNull("License should not be null", license); - - // Construct an instance of the class and then call it's licenseCheck method, and compare the results - Object[] args = new Object[signature.size() + 1]; - args[0] = Source.EMPTY; - for (int i = 0; i < signature.size(); i++) { - args[i + 1] = new Literal(Source.EMPTY, null, signature.get(i)); - } - Object instance = ctor.newInstance(args); - // Check that object implements the LicenseAware interface - if (LicenseAware.class.isAssignableFrom(instance.getClass())) { - LicenseAware licenseAware = (LicenseAware) instance; - switch (license) { - case BASIC -> checkLicense.assertLicenseCheck(licenseAware, signature, true, true, true); - case PLATINUM -> checkLicense.assertLicenseCheck(licenseAware, signature, false, true, true); - case ENTERPRISE -> checkLicense.assertLicenseCheck(licenseAware, signature, false, false, true); - } - } else { - fail("Function " + definition.name() + " does not implement LicenseAware"); - } - } catch (Exception e) { - fail(e); - } - }); - } - - private static class TestCheckLicense { - XPackLicenseState basicLicense = makeLicenseState(License.OperationMode.BASIC); - XPackLicenseState platinumLicense = makeLicenseState(License.OperationMode.PLATINUM); - XPackLicenseState enterpriseLicense = makeLicenseState(License.OperationMode.ENTERPRISE); - - private void assertLicenseCheck( - LicenseAware licenseAware, - List signature, - boolean allowsBasic, - boolean allowsPlatinum, - boolean allowsEnterprise - ) { - boolean basic = licenseAware.licenseCheck(basicLicense); - boolean platinum = licenseAware.licenseCheck(platinumLicense); - boolean enterprise = licenseAware.licenseCheck(enterpriseLicense); - assertThat("Basic license should be accepted for " + signature, basic, equalTo(allowsBasic)); - assertThat("Platinum license should be accepted for " + signature, platinum, equalTo(allowsPlatinum)); - assertThat("Enterprise license should be accepted for " + signature, enterprise, equalTo(allowsEnterprise)); - } - - private void assertLicenseCheck(List signature, boolean allowed, boolean expected) { - assertThat("Basic license should " + (expected ? 
"" : "not ") + "be accepted for " + signature, allowed, equalTo(expected)); - } - } - - private static XPackLicenseState makeLicenseState(License.OperationMode mode) { - return new XPackLicenseState(System::currentTimeMillis, new XPackLicenseStatus(mode, true, "")); - } - /** * Asserts the result of a test case matches the expected result and warnings. *

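The removed harness boils down to one pattern: build an XPackLicenseState per license tier, exactly as makeLicenseState did, and ask a LicenseAware implementation whether it may run under it. A minimal sketch of that pattern, assuming some LicenseAware function instance with a PLATINUM-gated signature (the helper name assertPlatinumGated is illustrative, not from the diff):

    import org.elasticsearch.license.License;
    import org.elasticsearch.license.XPackLicenseState;
    import org.elasticsearch.license.internal.XPackLicenseStatus;
    import org.elasticsearch.xpack.esql.LicenseAware;

    // Mirrors the removed check: a PLATINUM-gated signature must reject BASIC and accept PLATINUM.
    static void assertPlatinumGated(LicenseAware fn) {
        XPackLicenseState basic = new XPackLicenseState(
            System::currentTimeMillis, new XPackLicenseStatus(License.OperationMode.BASIC, true, ""));
        XPackLicenseState platinum = new XPackLicenseState(
            System::currentTimeMillis, new XPackLicenseStatus(License.OperationMode.PLATINUM, true, ""));
        assert fn.licenseCheck(basic) == false : "BASIC license must be rejected";
        assert fn.licenseCheck(platinum) : "PLATINUM license must be accepted";
    }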
@@ -964,7 +865,7 @@ public static Map, DataType> signatures(Class testClass) { } @AfterClass - public static void renderDocs() throws Exception { + public static void renderDocs() throws IOException { if (System.getProperty("generateDocs") == null) { return; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/DocsV3SupportTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/DocsV3SupportTests.java index 86fe852ab9002..275849b9bb0fb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/DocsV3SupportTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/DocsV3SupportTests.java @@ -228,7 +228,7 @@ public void testRenderingExampleResultEmojis() throws IOException { assertThat(results, equalTo(expectedResults)); } - public void testRenderingExampleFromClass() throws Exception { + public void testRenderingExampleFromClass() throws IOException { String expected = """ % This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. @@ -306,7 +306,7 @@ public void testRenderingExampleFromClass() throws Exception { assertThat(rendered.trim(), equalTo(expected.trim())); } - public void testRenderingLayoutFromClass() throws Exception { + public void testRenderingLayoutFromClass() throws IOException { String expected = """ % This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. @@ -353,7 +353,7 @@ public void testRenderingLayoutFromClass() throws Exception { assertThat(rendered.trim(), equalTo(expected.trim())); } - private TestDocsFileWriter renderTestClassDocs() throws Exception { + private TestDocsFileWriter renderTestClassDocs() throws IOException { FunctionInfo info = functionInfo(TestClass.class); assert info != null; FunctionDefinition definition = EsqlFunctionRegistry.def(TestClass.class, TestClass::new, "count"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregationTestCase.java deleted file mode 100644 index 3ac6179cd1a3a..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregationTestCase.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.expression.function.aggregate; - -import org.elasticsearch.license.License; -import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; - -import java.util.List; - -public abstract class SpatialAggregationTestCase extends AbstractAggregationTestCase { - - /** - * All spatial aggregations have the same licensing requirements, which is that the function itself is not licensed, but - * the field types are. Aggregations over shapes are licensed under platinum, while aggregations over points are licensed under basic. 
-     * @param fieldTypes (null for the function itself, otherwise a map of field names to types)
-     * @return The license requirement for the function with that type signature
-     */
-    protected static License.OperationMode licenseRequirement(List<DataType> fieldTypes) {
-        if (fieldTypes == null || fieldTypes.isEmpty()) {
-            // The function itself is not licensed, but the field types are.
-            return License.OperationMode.BASIC;
-        }
-        if (fieldTypes.stream().anyMatch(DataType::isSpatialShape)) {
-            // Only aggregations over shapes are licensed under platinum.
-            return License.OperationMode.PLATINUM;
-        }
-        // All other field types are licensed under basic.
-        return License.OperationMode.BASIC;
-    }
-}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidTests.java
index 25bb6e242cf8c..a99cb8f60e3fa 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidTests.java
@@ -14,7 +14,6 @@ import org.elasticsearch.geometry.Point;
 import org.elasticsearch.geometry.utils.GeometryValidator;
 import org.elasticsearch.geometry.utils.WellKnownBinary;
-import org.elasticsearch.license.License;
 import org.elasticsearch.search.aggregations.metrics.CompensatedSum;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -40,10 +39,6 @@ public SpatialCentroidTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
         this.testCase = testCaseSupplier.get();
     }
 
-    public static License.OperationMode licenseRequirement(List<DataType> fieldTypes) {
-        return SpatialAggregationTestCase.licenseRequirement(fieldTypes);
-    }
-
     @ParametersFactory
     public static Iterable<Object[]> parameters() {
         var suppliers = Stream.of(
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialExtentTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialExtentTests.java
index a73a00741e3c5..9a0a62ce2d06e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialExtentTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialExtentTests.java
@@ -17,12 +17,12 @@ import org.elasticsearch.geometry.utils.SpatialEnvelopeVisitor;
 import org.elasticsearch.geometry.utils.SpatialEnvelopeVisitor.WrapLongitude;
 import org.elasticsearch.geometry.utils.WellKnownBinary;
-import org.elasticsearch.license.License;
 import org.elasticsearch.test.hamcrest.RectangleMatcher;
 import org.elasticsearch.test.hamcrest.WellKnownBinaryBytesRefMatcher;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase;
 import org.elasticsearch.xpack.esql.expression.function.FunctionName;
 import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier;
 import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier.IncludingAltitude;
@@ -33,15 +33,11 @@ import java.util.stream.Stream;
 
 @FunctionName("st_extent_agg")
-public class SpatialExtentTests extends SpatialAggregationTestCase {
+public class SpatialExtentTests extends AbstractAggregationTestCase {
    public 
SpatialExtentTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } - public static License.OperationMode licenseRequirement(List fieldTypes) { - return SpatialAggregationTestCase.licenseRequirement(fieldTypes); - } - @ParametersFactory public static Iterable parameters() { var suppliers = Stream.of( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java index f69bb7eb3e7bb..dfdfd82d08afe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java @@ -11,7 +11,6 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.license.License; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -35,10 +34,6 @@ public CategorizeTests(@Name("TestCase") Supplier tes this.testCase = testCaseSupplier.get(); } - public static License.OperationMode licenseRequirement(List fieldTypes) { - return License.OperationMode.PLATINUM; - } - @ParametersFactory public static Iterable parameters() { List suppliers = new ArrayList<>(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java index c2397f0340e67..5578b2bfb1919 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.junit.AfterClass; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.function.Function; @@ -154,7 +155,7 @@ static Expression buildRLike(Logger logger, Source source, List args } @AfterClass - public static void renderNotRLike() throws Exception { + public static void renderNotRLike() throws IOException { renderNegatedOperator(constructorWithFunctionInfo(RLike.class), "RLIKE", d -> d, getTestClass()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java index bb59f9e501efb..3bf14fb8c3475 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.junit.AfterClass; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; @@ -91,7 +92,7 @@ static Expression buildWildcardLike(Source source, List args) { } @AfterClass - public static void renderNotLike() throws Exception { + 
public static void renderNotLike() throws IOException { renderNegatedOperator(constructorWithFunctionInfo(WildcardLike.class), "LIKE", d -> d, getTestClass()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/CastOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/CastOperatorTests.java index ff1f1f69a6590..c1f5b9faed8e4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/CastOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/CastOperatorTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.junit.AfterClass; +import java.io.IOException; import java.util.List; import java.util.Map; @@ -25,7 +26,7 @@ public void testDummy() { } @AfterClass - public static void renderDocs() throws Exception { + public static void renderDocs() throws IOException { if (System.getProperty("generateDocs") == null) { return; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/NullPredicatesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/NullPredicatesTests.java index 69bfcc99a21ea..7af61532fed4c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/NullPredicatesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/NullPredicatesTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToStringTests; import org.junit.AfterClass; +import java.io.IOException; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -31,7 +32,7 @@ public void testDummy() { } @AfterClass - public static void renderDocs() throws Exception { + public static void renderDocs() throws IOException { if (System.getProperty("generateDocs") == null) { return; } @@ -61,7 +62,7 @@ public static void renderDocs() throws Exception { ); } - private static void renderNullPredicate(DocsV3Support.OperatorConfig op) throws Exception { + private static void renderNullPredicate(DocsV3Support.OperatorConfig op) throws IOException { var docs = new DocsV3Support.OperatorsDocsSupport(op.name(), NullPredicatesTests.class, op, NullPredicatesTests::signatures); docs.renderSignature(); docs.renderDocs(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java index 449389accc37b..aed08b32bb6d3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.junit.AfterClass; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -318,7 +319,7 @@ protected Expression build(Source source, List args) { } @AfterClass - public static void renderNotIn() throws Exception { + public static void renderNotIn() throws IOException { renderNegatedOperator( constructorWithFunctionInfo(In.class), "IN", 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index b36cb3f6c6a42..e163d082249b4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -8,38 +8,31 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; -import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.TestUtils; import org.elasticsearch.xpack.esql.expression.predicate.Range; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; -import org.elasticsearch.xpack.esql.parser.EsqlParser; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; -import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.core.util.TestUtils.of; -import static org.hamcrest.Matchers.containsInAnyOrder; public class OptimizerRulesTests extends ESTestCase { - private static final Literal FIVE = of(5); - private static final Literal SIX = of(6); + private static final Literal FIVE = L(5); + private static final Literal SIX = L(6); - public static final class DummyBooleanExpression extends Expression { + public static class DummyBooleanExpression extends Expression { private final int id; @@ -94,13 +87,21 @@ public boolean equals(Object obj) { } } + private static Literal L(Object value) { + return of(value); + } + + private static FieldAttribute getFieldAttribute() { + return TestUtils.getFieldAttribute("a"); + } + // // Range optimization // // 6 < a <= 5 -> FALSE public void testFoldExcludingRangeToFalse() { - FieldAttribute fa = getFieldAttribute("a"); + FieldAttribute fa = getFieldAttribute(); Range r = rangeOf(fa, SIX, false, FIVE, true); assertTrue(r.foldable()); @@ -109,35 +110,13 @@ public void testFoldExcludingRangeToFalse() { // 6 < a <= 5.5 -> FALSE public void testFoldExcludingRangeWithDifferentTypesToFalse() { - FieldAttribute fa = getFieldAttribute("a"); + FieldAttribute fa = getFieldAttribute(); - Range r = rangeOf(fa, SIX, false, of(5.5d), true); + Range r = rangeOf(fa, SIX, false, L(5.5d), true); assertTrue(r.foldable()); assertEquals(Boolean.FALSE, r.fold(FoldContext.small())); } - public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() { - var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) { - private final List appliedTo = new ArrayList<>(); + // 
Conjunction - @Override - protected Expression rule(Expression e, LogicalOptimizerContext ctx) { - appliedTo.add(e); - return e; - } - }; - - rule.apply( - new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"), - new LogicalOptimizerContext(null, FoldContext.small()) - ); - - var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER); - var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1"); - var add = new Add(new Source(1, 20, "f1+1"), attribute, literal); - var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add); - - // contains expressions only from EVAL - assertThat(rule.appliedTo, containsInAnyOrder(alias, add, attribute, literal)); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java index d10185d95a913..ce0ef53eadb3f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java @@ -501,46 +501,6 @@ public void testRetryOnlyMovedShards() { assertThat("Must retry only affected shards", resolvedShards, contains(shard2)); } - public void testRetryUnassignedShardWithoutPartialResults() { - var attempt = new AtomicInteger(0); - var future = sendRequests(false, -1, List.of(targetShard(shard1, node1), targetShard(shard2, node2)), shardIds -> { - attempt.incrementAndGet(); - return Map.of(shard1, List.of()); - }, - (node, shardIds, aliasFilters, listener) -> runWithDelay( - () -> listener.onResponse( - Objects.equals(shardIds, List.of(shard2)) - ? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) - : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) - ) - ) - - ); - expectThrows(NoShardAvailableActionException.class, containsString("no such shard"), future::actionGet); - } - - public void testRetryUnassignedShardWithPartialResults() { - var response = safeGet( - sendRequests( - true, - -1, - List.of(targetShard(shard1, node1), targetShard(shard2, node2)), - shardIds -> Map.of(shard1, List.of()), - (node, shardIds, aliasFilters, listener) -> runWithDelay( - () -> listener.onResponse( - Objects.equals(shardIds, List.of(shard2)) - ? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) - : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) - ) - ) - ) - ); - assertThat(response.totalShards, equalTo(2)); - assertThat(response.successfulShards, equalTo(1)); - assertThat(response.skippedShards, equalTo(0)); - assertThat(response.failedShards, equalTo(1)); - } - static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... 
nodes) { return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null); } diff --git a/x-pack/plugin/graph/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/graph/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..16701ab74d8c9 --- /dev/null +++ b/x-pack/plugin/graph/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git a/x-pack/plugin/identity-provider/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/identity-provider/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..0310ce4542dbb --- /dev/null +++ b/x-pack/plugin/identity-provider/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,16 @@ +grant { + permission java.lang.RuntimePermission "setFactory"; + + // ApacheXMLSecurityInitializer + permission java.util.PropertyPermission "org.apache.xml.security.ignoreLineBreaks", "read,write"; + permission java.security.SecurityPermission "org.apache.xml.security.register"; + + // needed during initialization of OpenSAML library where xml security algorithms are registered + // see https://github.com/apache/santuario-java/blob/e79f1fe4192de73a975bc7246aee58ed0703343d/src/main/java/org/apache/xml/security/utils/JavaUtils.java#L205-L220 + // and https://git.shibboleth.net/view/?p=java-opensaml.git;a=blob;f=opensaml-xmlsec-impl/src/main/java/org/opensaml/xmlsec/signature/impl/SignatureMarshaller.java;hb=db0eaa64210f0e32d359cd6c57bedd57902bf811#l52 + // which uses it in the opensaml-xmlsec-impl + permission java.security.SecurityPermission "org.apache.xml.security.register"; + + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 58aa9b29f8565..fba8d9e61f0c4 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -8,7 +8,6 @@ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-yaml-rest-test' -apply plugin: 'elasticsearch.internal-test-artifact' restResources { restApi { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 682eebd0fa69b..71a4d20d3ad4c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(22)); + assertThat(services.size(), equalTo(23)); var providers = providers(services); @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "deepseek", "elastic", "elasticsearch", @@ -70,7 +71,7 @@ private Iterable providers(List 
services) { public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(16)); + assertThat(services.size(), equalTo(17)); var providers = providers(services); @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "elasticsearch", "googleaistudio", "googlevertexai", @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(7)); + assertThat(services.size(), equalTo(8)); var providers = providers(services); @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { List.of( "alibabacloud-ai-search", "cohere", + "custom", "elasticsearch", "googlevertexai", "jinaai", @@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(10)); + assertThat(services.size(), equalTo(11)); var providers = providers(services); @@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "custom", "deepseek", "googleaistudio", "openai", @@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(6)); + assertThat(services.size(), equalTo(7)); var providers = providers(services); @@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { containsInAnyOrder( List.of( "alibabacloud-ai-search", + "custom", "elastic", "elasticsearch", "hugging_face", diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index b2f8ba5475eb8..e34018c5b8df1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -34,7 +34,6 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.DequeUtils; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; @@ -257,24 +256,37 @@ public void cancel() {} "object": "chat.completion.chunk" } */ - private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) { - return new StreamingUnifiedChatCompletionResults.Results( - DequeUtils.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - "id", - 
List.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null), - null, - 0 - ) - ), - "gpt-4o-2024-08-06", - "chat.completion.chunk", - null - ) - ) - ); + private InferenceServiceResults.Result unifiedCompletionChunk(String delta) { + return new InferenceServiceResults.Result() { + @Override + public String getWriteableName() { + return "test_unifiedCompletionChunk"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(delta); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return ChunkedToXContentHelper.chunk( + (b, p) -> b.startObject() + .field("id", "id") + .startArray("choices") + .startObject() + .startObject("delta") + .field("content", delta) + .endObject() + .field("index", 0) + .endObject() + .endArray() + .field("model", "gpt-4o-2024-08-06") + .field("object", "chat.completion.chunk") + .endObject() + ); + } + }; } @Override diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 8405fba22460f..074678bbea095 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -16,7 +17,6 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; @@ -242,10 +242,12 @@ public void testRestart() throws Exception { private void assertRandomBulkOperations(String indexName, Function> sourceSupplier) throws Exception { int numHits = numHits(indexName); - int totalBulkReqs = randomIntBetween(2, 10); + int totalBulkReqs = randomIntBetween(2, 100); + long totalDocs = numHits; Set ids = new HashSet<>(); - for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) { - BulkRequestBuilder bulkReqBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) { + BulkRequestBuilder bulkReqBuilder = client().prepareBulk(); int totalBulkSize = randomIntBetween(1, 100); for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) { if (ids.size() > 0 && rarely(random())) { @@ -255,15 +257,24 @@ private void assertRandomBulkOperations(String indexName, Function source = sourceSupplier.apply(isIndexRequest); if (isIndexRequest) { - String id = randomAlphaOfLength(20); bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source)); ids.add(id); } else { - 
String id = randomFrom(ids); - bulkReqBuilder.add(new UpdateRequestBuilder(client()).setIndex(indexName).setId(id).setDoc(source)); + boolean isUpsert = randomBoolean(); + UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source); + if (isUpsert || ids.size() == 0) { + request.setDocAsUpsert(true); + } else { + // Update already existing document + id = randomFrom(ids); + } + request.setId(id); + bulkReqBuilder.add(request); + ids.add(id); } } BulkResponse bulkResponse = bulkReqBuilder.get(); @@ -282,7 +293,8 @@ private void assertRandomBulkOperations(String indexName, Function getNamedWriteables() { addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); + addCustomNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -165,6 +175,38 @@ public static List getNamedWriteables() { return namedWriteables; } + private static void addCustomNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new) + ); + + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new)); + + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new)); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + CustomResponseParser.class, + SparseEmbeddingResponseParser.NAME, + SparseEmbeddingResponseParser::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new) + ); + + namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new)); + + namedWriteables.add( + new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new) + ); + } + private static void addUnifiedNamedWriteables(List namedWriteables) { var writeables = UnifiedCompletionRequest.getNamedWriteables(); namedWriteables.addAll(writeables); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 87256494a60e0..a8d783eacbca0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -119,6 +119,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.custom.CustomService; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; @@ -276,17 +277,13 @@ public Collection createComponents(PluginServices services) { var inferenceServices = new 
ArrayList<>(inferenceServiceExtensions); inferenceServices.add(this::getInferenceServiceFactories); - var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); - inferenceServiceSettings.init(services.clusterService()); - // Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`). var elasticInferenceServiceHttpClientManager = HttpClientManager.create( settings, services.threadPool(), services.clusterService(), throttlerManager, - getSslService(), - inferenceServiceSettings.getConnectionTtl() + getSslService() ); var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory( @@ -296,6 +293,9 @@ public Collection createComponents(PluginServices services) { ); elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); + var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); + inferenceServiceSettings.init(services.clusterService()); + var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool() @@ -396,6 +396,7 @@ public List getInferenceServiceFactories() { context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new CustomService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index bc9d87f43ada0..eeea8a28df486 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -177,7 +177,7 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener); } private void parseAndStoreModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java index ddf19ff0dc96f..6d09c9e67b363 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java @@ -32,7 +32,6 @@ import java.io.Closeable; import java.io.IOException; import java.util.List; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX; @@ -113,15 +112,14 @@ public static HttpClientManager create( ThreadPool threadPool, ClusterService clusterService, ThrottlerManager throttlerManager, - SSLService sslService, - TimeValue connectionTtl + SSLService sslService ) { // Set the sslStrategy to 
ensure an encrypted connection, as Elastic Inference Service requires it. SSLIOSessionStrategy sslioSessionStrategy = sslService.sslIOSessionStrategy( sslService.getSSLConfiguration(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX) ); - PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy, connectionTtl); + PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy); return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager); } @@ -148,7 +146,7 @@ public static HttpClientManager create( this.addSettingsUpdateConsumers(clusterService); } - private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy, TimeValue connectionTtl) { + private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy) { ConnectingIOReactor ioReactor; try { var configBuilder = IOReactorConfig.custom().setSoKeepAlive(true); @@ -164,15 +162,7 @@ private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIO .register("https", sslStrategy) .build(); - return new PoolingNHttpClientConnectionManager( - ioReactor, - null, - registry, - null, - null, - Math.toIntExact(connectionTtl.getMillis()), - TimeUnit.MILLISECONDS - ); + return new PoolingNHttpClientConnectionManager(ioReactor, registry); } private static PoolingNHttpClientConnectionManager createConnectionManager() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java index 1786ee98fcd87..f384d79adae3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java @@ -11,6 +11,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import java.io.ByteArrayOutputStream; import java.util.concurrent.Flow; @@ -21,7 +22,7 @@ public boolean isSuccessfulResponse() { return RestStatus.isSuccessful(response.getStatusLine().getStatusCode()); } - public Flow.Publisher toHttpResult() { + public Flow.Publisher toHttpResult(HttpRequest httpRequest) { return subscriber -> body().subscribe(new Flow.Subscriber<>() { @Override public void onSubscribe(Flow.Subscription subscription) { @@ -45,7 +46,7 @@ public void onComplete() { }); } - public void readFullResponse(ActionListener fullResponse) { + public void readFullResponse(HttpRequest httpRequest, ActionListener fullResponse) { var stream = new ByteArrayOutputStream(); AtomicReference upstream = new AtomicReference<>(null); body.subscribe(new Flow.Subscriber<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 3dac8d849ba6f..56e994be86eb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; protected final String requestType; - private final ResponseParser parseFunction; + protected final ResponseParser parseFunction; private final Function errorParseFunction; private final boolean canHandleStreamingResponses; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java index d009ec87d5776..e8cb5d3ad16d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java @@ -115,11 +115,12 @@ public void tryAction(ActionListener listener) { try { if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) { - httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> { + var httpRequest = request.createHttpRequest(); + httpClient.stream(httpRequest, context, retryableListener.delegateFailure((l, r) -> { if (r.isSuccessfulResponse()) { - l.onResponse(responseHandler.parseResult(request, r.toHttpResult())); + l.onResponse(responseHandler.parseResult(request, r.toHttpResult(httpRequest))); } else { - r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> { + r.readFullResponse(httpRequest, l.delegateFailureAndWrap((ll, httpResult) -> { try { responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true); InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index d15414e34aef1..548f65d4f93fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -32,7 +32,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -99,7 +98,6 @@ import java.util.function.Supplier; import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ; -import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; @@ -1079,8 +1077,7 @@ private static Mapper.Builder createEmbeddingsField( denseVectorMapperBuilder.elementType(modelSettings.elementType()); DenseVectorFieldMapper.IndexOptions defaultIndexOptions = null; - if 
(indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ) - || indexVersionCreated.between(SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0)) { + if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ)) { defaultIndexOptions = defaultSemanticDenseIndexOptions(); } if (defaultIndexOptions != null diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index 838e6512d805f..655e11996d522 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -20,7 +20,6 @@ import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; -import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; @@ -50,15 +49,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient taskType = TaskType.ANY; // task type must be defined in the body } - var inferTimeout = parseTimeout(restRequest); var content = restRequest.requiredContent(); - var request = new PutInferenceModelAction.Request( - taskType, - inferenceEntityId, - content, - restRequest.getXContentType(), - inferTimeout - ); + var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType()); return channel -> client.execute( PutInferenceModelAction.INSTANCE, request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index bdcadb2277c2b..428c266379f65 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -21,9 +22,11 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; @@ -31,6 +34,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; @@ -80,6 +84,11 @@ public static T removeAsType(Map sourceMap, String key, Clas */ @SuppressWarnings("unchecked") public static T removeAsType(Map sourceMap, String key, Class type, 
ValidationException validationException) { + if (sourceMap == null) { + validationException.addValidationError(Strings.format("Encountered a null input map while parsing field [%s]", key)); + return null; + } + Object o = sourceMap.remove(key); if (o == null) { return null; @@ -188,6 +197,12 @@ public static void throwIfNotEmptyMap(Map settingsMap, String se } } + public static void throwIfNotEmptyMap(Map settingsMap, String field, String scope) { + if (settingsMap != null && settingsMap.isEmpty() == false) { + throw ServiceUtils.unknownSettingsError(settingsMap, field, scope); + } + } + public static ElasticsearchStatusException unknownSettingsError(Map config, String serviceName) { // TODO map as JSON return new ElasticsearchStatusException( @@ -198,6 +213,16 @@ public static ElasticsearchStatusException unknownSettingsError(Map config, String field, String scope) { + return new ElasticsearchStatusException( + "Model configuration contains unknown settings [{}] while parsing field [{}] for settings [{}]", + RestStatus.BAD_REQUEST, + config, + field, + scope + ); + } + public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithEmbeddingDetails(Class invalidModelType) { throw new ElasticsearchStatusException( Strings.format("Can't update embedding details for model with unexpected type %s", invalidModelType), @@ -249,6 +274,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } + public static String mustBeNonEmptyMap(String settingName, String scope) { + return Strings.format("[%s] Invalid value empty map. [%s] must be a non-empty map", scope, settingName); + } + public static String invalidTimeValueMsg(String timeValueStr, String settingName, String scope, String exceptionMsg) { return Strings.format( "[%s] Invalid time value [%s]. 
[%s] must be a valid time value string: %s",
@@ -422,6 +451,236 @@ public static Integer extractRequiredPositiveInteger(
return field;
}
+ @SuppressWarnings("unchecked")
+ public static Map<String, Object> extractRequiredMap(
+ Map<String, Object> map,
+ String settingName,
+ String scope,
+ ValidationException validationException
+ ) {
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ Map<String, Object> requiredField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException);
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ if (requiredField == null) {
+ validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope));
+ } else if (requiredField.isEmpty()) {
+ validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope));
+ }
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ return requiredField;
+ }
+
+ @SuppressWarnings("unchecked")
+ public static Map<String, Object> extractOptionalMap(
+ Map<String, Object> map,
+ String settingName,
+ String scope,
+ ValidationException validationException
+ ) {
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ Map<String, Object> optionalField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException);
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ return optionalField;
+ }
+
+ public static List<Tuple<String, String>> extractOptionalListOfStringTuples(
+ Map<String, Object> map,
+ String settingName,
+ String scope,
+ ValidationException validationException
+ ) {
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ List<?> optionalField = ServiceUtils.removeAsType(map, settingName, List.class, validationException);
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ if (optionalField == null) {
+ return null;
+ }
+
+ var tuples = new ArrayList<Tuple<String, String>>();
+ for (int tuplesIndex = 0; tuplesIndex < optionalField.size(); tuplesIndex++) {
+
+ var tupleEntry = optionalField.get(tuplesIndex);
+ if (tupleEntry instanceof List == false) {
+ validationException.addValidationError(
+ Strings.format(
+ "[%s] failed to parse tuple list entry [%d] for setting [%s], expected a list but the entry is [%s]",
+ scope,
+ tuplesIndex,
+ settingName,
+ tupleEntry.getClass().getSimpleName()
+ )
+ );
+ throw validationException;
+ }
+
+ var listEntry = (List<?>) tupleEntry;
+ if (listEntry.size() != 2) {
+ validationException.addValidationError(
+ Strings.format(
+ "[%s] failed to parse tuple list entry [%d] for setting [%s], the tuple list size must be two, but was [%d]",
+ scope,
+ tuplesIndex,
+ settingName,
+ listEntry.size()
+ )
+ );
+ throw validationException;
+ }
+
+ var firstElement = listEntry.get(0);
+ var secondElement = listEntry.get(1);
+ validateString(firstElement, settingName, scope, "the first element", tuplesIndex, validationException);
+ validateString(secondElement, settingName, scope, "the second element", tuplesIndex, validationException);
+ tuples.add(new Tuple<>((String) firstElement, (String) secondElement));
+ }
+
+ return tuples;
+ }
+
+ private static void validateString(
+ Object tupleValue,
+ String settingName,
+ String scope,
+ String elementDescription,
+ int index,
+ ValidationException validationException
+ ) {
+ if (tupleValue instanceof String == false) {
+ validationException.addValidationError(
+ Strings.format(
+ "[%s] failed to parse tuple
list entry [%d] for setting [%s], %s must be a string but was [%s]",
+ scope,
+ index,
+ settingName,
+ elementDescription,
+ tupleValue.getClass().getSimpleName()
+ )
+ );
+ throw validationException;
+ }
+ }
+
+ /**
+ * Validates that each value in the map is a {@link String} and returns a new map of {@code Map<String, String>}.
+ */
+ public static Map<String, String> validateMapStringValues(
+ Map<String, ?> map,
+ String settingName,
+ ValidationException validationException,
+ boolean censorValue
+ ) {
+ if (map == null) {
+ return Map.of();
+ }
+
+ validateMapValues(map, List.of(String.class), settingName, validationException, censorValue);
+
+ return map.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (String) e.getValue()));
+ }
+
+ /**
+ * Ensures the values of the map match one of the supplied types.
+ * @param map Map to validate
+ * @param allowedTypes List of {@link Class} to accept
+ * @param settingName the setting name for the field
+ * @param validationException the validation exception to throw if one of the values is invalid
+ * @param censorValue if true the key and value will be excluded from the exception message
+ */
+ public static void validateMapValues(
+ Map<String, ?> map,
+ List<Class<?>> allowedTypes,
+ String settingName,
+ ValidationException validationException,
+ boolean censorValue
+ ) {
+ if (map == null) {
+ return;
+ }
+
+ for (var entry : map.entrySet()) {
+ var isAllowed = false;
+
+ for (Class<?> allowedType : allowedTypes) {
+ if (allowedType.isInstance(entry.getValue())) {
+ isAllowed = true;
+ break;
+ }
+ }
+
+ Function<String[], String> errorMessage = (String[] validTypesAsStrings) -> {
+ if (censorValue) {
+ return Strings.format(
+ "Map field [%s] has an entry that is not valid. Value type is not one of [%s].",
+ settingName,
+ String.join(", ", validTypesAsStrings)
+ );
+ } else {
+ return Strings.format(
+ "Map field [%s] has an entry that is not valid, [%s => %s]. Value type of [%s] is not one of [%s].",
+ settingName,
+ entry.getKey(),
+ entry.getValue(),
+ entry.getValue(),
+ String.join(", ", validTypesAsStrings)
+ );
+ }
+ };
+
+ if (isAllowed == false) {
+ var validTypesAsStrings = allowedTypes.stream().map(Class::getSimpleName).toArray(String[]::new);
+ Arrays.sort(validTypesAsStrings);
+
+ validationException.addValidationError(errorMessage.apply(validTypesAsStrings));
+ throw validationException;
+ }
+ }
+ }
+
+ public static Map<String, SerializableSecureString> convertMapStringsToSecureString(
+ Map<String, ?> map,
+ String settingName,
+ ValidationException validationException
+ ) {
+ if (map == null) {
+ return Map.of();
+ }
+
+ validateMapStringValues(map, settingName, validationException, true);
+
+ return map.entrySet()
+ .stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> new SerializableSecureString((String) e.getValue())));
+ }
+
+ /**
+ * Removes entries with null values from the map, mutating it in place.
+ */
+ public static Map<String, Object> removeNullValues(Map<String, Object> map) {
+ if (map == null) {
+ return map;
+ }
+
+ map.values().removeIf(Objects::isNull);
+
+ return map;
+ }
+
public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
Map<String, Object> map,
String settingName,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java
new file mode 100644
index 0000000000000..7c00b0a242f94
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V.
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; +import java.util.Objects; + +public class CustomModel extends Model { + private final CustomRateLimitServiceSettings rateLimitServiceSettings; + + public CustomModel(ModelConfigurations configurations, ModelSecrets secrets, CustomRateLimitServiceSettings rateLimitServiceSettings) { + super(configurations, secrets); + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public static CustomModel of(CustomModel model, Map taskSettings) { + var requestTaskSettings = CustomTaskSettings.fromMap(taskSettings); + return new CustomModel(model, CustomTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public CustomModel( + String inferenceId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + taskType, + service, + CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomTaskSettings.fromMap(taskSettings), + CustomSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + CustomModel( + String inferenceId, + TaskType taskType, + String service, + CustomServiceSettings serviceSettings, + CustomTaskSettings taskSettings, + @Nullable CustomSecretSettings secretSettings + ) { + this( + new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + + protected CustomModel(CustomModel model, TaskSettings taskSettings) { + super(model, taskSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + protected CustomModel(CustomModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + @Override + public CustomServiceSettings getServiceSettings() { + return (CustomServiceSettings) super.getServiceSettings(); + } + + @Override + public CustomTaskSettings getTaskSettings() { + return (CustomTaskSettings) super.getTaskSettings(); + } + + @Override + public CustomSecretSettings getSecretSettings() { + return (CustomSecretSettings) super.getSecretSettings(); + } + + public CustomRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java new file mode 100644 index 0000000000000..55641bad7ccaa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface CustomRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java new file mode 100644 index 0000000000000..a112e7db26fe3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class CustomRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(CustomRequestManager.class); + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(CustomModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode()); + } + } + + private static ResponseHandler createCustomHandler(CustomModel model) { + return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser()); + } + + public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) { + return new CustomRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final CustomModel model; + private final ResponseHandler handler; + + private CustomRequestManager(CustomModel model, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), 
RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + this.model = model; + this.handler = createCustomHandler(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + String query; + List input; + if (inferenceInputs instanceof QueryAndDocsInputs) { + QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs); + query = queryAndDocsInputs.getQuery(); + input = queryAndDocsInputs.getChunks(); + } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { + query = null; + input = chatInputs.getInputs(); + } else if (inferenceInputs instanceof EmbeddingsInput) { + EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs); + query = null; + input = embeddingsInput.getStringInputs(); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Strings.format("Invalid input received from custom service %s", inferenceInputs.getClass().getSimpleName()), + RestStatus.BAD_REQUEST + ) + ); + return; + } + + try { + var request = new CustomRequest(query, input, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener)); + } catch (Exception e) { + // Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction + listener.onFailure( + new ElasticsearchStatusException("Failed to construct the custom service request", RestStatus.BAD_REQUEST, e) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java new file mode 100644 index 0000000000000..14a962b112ccd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; + +/** + * Defines how to handle various response types returned from the custom integration. 
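+ * <p>Parse failures are reported back to the caller as non-retryable 400 responses, since they
+ * typically indicate a response parser that does not match the upstream response format; of the
+ * failure status codes, only 429 (rate limit) is retried.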
+ */
+public class CustomResponseHandler extends BaseResponseHandler {
+ public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
+ super(requestType, parseFunction, errorParser);
+ }
+
+ @Override
+ public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException {
+ try {
+ return parseFunction.apply(request, result);
+ } catch (Exception e) {
+ // if we get a parse failure, it's probably an incorrect configuration of the service, so report the error back to the user
+ // immediately without retrying
+ throw new RetryException(
+ false,
+ new ElasticsearchStatusException(
+ "Failed to parse custom model response, please check that the response parser path matches the response format.",
+ RestStatus.BAD_REQUEST,
+ e
+ )
+ );
+ }
+ }
+
+ /**
+ * Validates the status code and throws a {@link RetryException} if it is not in the range [200, 300).
+ *
+ * @param request The http request
+ * @param result The http response and body
+ * @throws RetryException if the status code is {@code >= 300} or {@code < 200}
+ */
+ @Override
+ protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
+ int statusCode = result.response().getStatusLine().getStatusCode();
+ if (statusCode >= 200 && statusCode < 300) {
+ return;
+ }
+
+ // handle error codes
+ if (statusCode >= 500) {
+ throw new RetryException(false, buildError(SERVER_ERROR, request, result));
+ } else if (statusCode == 429) {
+ throw new RetryException(true, buildError(RATE_LIMIT, request, result));
+ } else if (statusCode == 401) {
+ throw new RetryException(false, buildError(AUTHENTICATION, request, result));
+ } else if (statusCode >= 300 && statusCode < 400) {
+ throw new RetryException(false, buildError(REDIRECTION, request, result));
+ } else {
+ throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java
new file mode 100644
index 0000000000000..e74d56e1b4fd1
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; + +public class CustomSecretSettings implements SecretSettings { + public static final String NAME = "custom_secret_settings"; + public static final String SECRET_PARAMETERS = "secret_parameters"; + + public static CustomSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + + Map requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException); + removeNullValues(requestSecretParamsMap); + var secureStringMap = convertMapStringsToSecureString(requestSecretParamsMap, SECRET_PARAMETERS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomSecretSettings(secureStringMap); + } + + private final Map secretParameters; + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return fromMap(new HashMap<>(newSecrets)); + } + + public CustomSecretSettings(@Nullable Map secretParameters) { + this.secretParameters = Objects.requireNonNullElse(secretParameters, Map.of()); + } + + public CustomSecretSettings(StreamInput in) throws IOException { + secretParameters = in.readImmutableMap(SerializableSecureString::new); + } + + public Map getSecretParameters() { + return secretParameters; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (secretParameters.isEmpty() == false) { + builder.startObject(SECRET_PARAMETERS); + { + for (var entry : secretParameters.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + } + builder.endObject(); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(secretParameters, (streamOutput, v) -> { v.writeTo(streamOutput); }); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomSecretSettings that = (CustomSecretSettings) o; + return Objects.equals(secretParameters, that.secretParameters); + } + + @Override + public int hashCode() { + return Objects.hash(secretParameters); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java new file mode 100644 index 0000000000000..e30a0ab564026 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -0,0 +1,279 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class CustomService extends SenderService { + public static final String NAME = "custom"; + private static final String SERVICE_NAME = "Custom"; + + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION + ); + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { 
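+ // SenderService keeps the HTTP request sender and service components that doInfer below
+ // accesses via getSender() and getServiceComponents()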
+ super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + CustomModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + private static CustomModel createModelWithoutLoggingDeprecations( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + ConfigurationParseContext.PERSISTENT + ); + } + + private static CustomModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + ConfigurationParseContext context + ) { + if (supportedTaskTypes.contains(taskType) == false) { + throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + } + return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + } + + @Override + public CustomModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); + } + + @Override + public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof CustomModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + CustomModel customModel = (CustomModel) model; + + var overriddenModel = CustomModel.of(customModel, taskSettings); + + var 
failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME); + var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof CustomModel customModel && customModel.getTaskType() == TaskType.TEXT_EMBEDDING) { + var newServiceSettings = getCustomServiceSettings(customModel, embeddingSize); + + return new CustomModel(customModel, newServiceSettings); + } else { + throw new ElasticsearchStatusException( + Strings.format( + "Can't update embedding details for model of type: [%s], task type: [%s]", + model.getClass().getSimpleName(), + model.getTaskType() + ), + RestStatus.BAD_REQUEST + ); + } + } + + private static CustomServiceSettings getCustomServiceSettings(CustomModel customModel, int embeddingSize) { + var serviceSettings = customModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + return new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityToUse, + embeddingSize, + serviceSettings.getMaxInputTokens(), + serviceSettings.elementType() + ), + serviceSettings.getUrl(), + serviceSettings.getHeaders(), + serviceSettings.getQueryParameters(), + serviceSettings.getRequestContentString(), + serviceSettings.getResponseJsonParser(), + serviceSettings.rateLimitSettings(), + serviceSettings.getErrorParser() + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + // TODO revisit this + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(supportedTaskTypes) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index d0f9faf283aef..d40a265b6ef19 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -7,7 +7,52 @@ package org.elasticsearch.xpack.inference.services.custom; -public class CustomServiceSettings { +import 
org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; + +public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings { public static final String NAME = "custom_service_settings"; public static final String URL = "url"; public static final String HEADERS = "headers"; @@ -16,4 +61,366 @@ public class CustomServiceSettings { public static final String RESPONSE = "response"; public static final String JSON_PARSER = "json_parser"; public static final String ERROR_PARSER = "error_parser"; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); + private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); + + 
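/**
+ * Parses and validates the user-supplied service settings map: the url, optional headers and query
+ * parameters, the request template under request, and a response object holding the json_parser and
+ * error_parser definitions, plus task-type specific text embedding settings and rate limit settings.
+ * Each consumed key is removed from its map, and leftover keys in the nested request and response
+ * maps trigger a validation error.
+ */
+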
public static CustomServiceSettings fromMap( + Map map, + ConfigurationParseContext context, + TaskType taskType, + String inferenceId + ) { + ValidationException validationException = new ValidationException(); + + var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException); + + String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + var queryParams = QueryParameters.fromMap(map, validationException); + + Map headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException); + removeNullValues(headers); + var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false); + + Map requestBodyMap = extractRequiredMap(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException); + + String requestContentString = extractRequiredString( + Objects.requireNonNullElse(requestBodyMap, new HashMap<>()), + REQUEST_CONTENT, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + Map responseParserMap = extractRequiredMap( + map, + RESPONSE, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + Map jsonParserMap = extractRequiredMap( + Objects.requireNonNullElse(responseParserMap, new HashMap<>()), + JSON_PARSER, + RESPONSE_SCOPE, + validationException + ); + + var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException); + + Map errorParserMap = extractRequiredMap( + Objects.requireNonNullElse(responseParserMap, new HashMap<>()), + ERROR_PARSER, + RESPONSE_SCOPE, + validationException + ); + + var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CustomService.NAME, + context + ); + + if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null || errorParserMap == null) { + throw validationException; + } + + throwIfNotEmptyMap(requestBodyMap, REQUEST, NAME); + throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME); + throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME); + throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomServiceSettings( + textEmbeddingSettings, + url, + stringHeaders, + queryParams, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser + ); + } + + public record TextEmbeddingSettings( + @Nullable SimilarityMeasure similarityMeasure, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable DenseVectorFieldMapper.ElementType elementType + ) implements ToXContentFragment, Writeable { + + // This specifies float for the element type but null for all other settings + public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( + null, + null, + null, + DenseVectorFieldMapper.ElementType.FLOAT + ); + + // This refers to settings that are not related to the text embedding task type (all the settings should be null) + public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); + + public static TextEmbeddingSettings fromMap(Map map, TaskType taskType, ValidationException validationException) { + if (taskType != TaskType.TEXT_EMBEDDING) { + return NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; + } + + 
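// Similarity, dimensions, and max input tokens are only parsed for the text_embedding task type;
+ // the element type is currently always set to float.
+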
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+            Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+            Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
+            return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT);
+        }
+
+        public TextEmbeddingSettings(StreamInput in) throws IOException {
+            this(
+                in.readOptionalEnum(SimilarityMeasure.class),
+                in.readOptionalVInt(),
+                in.readOptionalVInt(),
+                in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
+            );
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeOptionalEnum(similarityMeasure);
+            out.writeOptionalVInt(dimensions);
+            out.writeOptionalVInt(maxInputTokens);
+            out.writeOptionalEnum(elementType);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            if (similarityMeasure != null) {
+                builder.field(SIMILARITY, similarityMeasure);
+            }
+            if (dimensions != null) {
+                builder.field(DIMENSIONS, dimensions);
+            }
+            if (maxInputTokens != null) {
+                builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+            }
+            return builder;
+        }
+    }
+
+    private final TextEmbeddingSettings textEmbeddingSettings;
+    private final String url;
+    private final Map<String, String> headers;
+    private final QueryParameters queryParameters;
+    private final String requestContentString;
+    private final CustomResponseParser responseJsonParser;
+    private final RateLimitSettings rateLimitSettings;
+    private final ErrorResponseParser errorParser;
+
+    public CustomServiceSettings(
+        TextEmbeddingSettings textEmbeddingSettings,
+        String url,
+        @Nullable Map<String, String> headers,
+        @Nullable QueryParameters queryParameters,
+        String requestContentString,
+        CustomResponseParser responseJsonParser,
+        @Nullable RateLimitSettings rateLimitSettings,
+        ErrorResponseParser errorParser
+    ) {
+        this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
+        this.url = Objects.requireNonNull(url);
+        this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of()));
+        this.queryParameters = Objects.requireNonNullElse(queryParameters, QueryParameters.EMPTY);
+        this.requestContentString = Objects.requireNonNull(requestContentString);
+        this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+        this.errorParser = Objects.requireNonNull(errorParser);
+    }
+
+    public CustomServiceSettings(StreamInput in) throws IOException {
+        textEmbeddingSettings = new TextEmbeddingSettings(in);
+        url = in.readString();
+        headers = in.readImmutableMap(StreamInput::readString);
+        queryParameters = new QueryParameters(in);
+        requestContentString = in.readString();
+        responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
+        rateLimitSettings = new RateLimitSettings(in);
+        errorParser = new ErrorResponseParser(in);
+    }
+
+    @Override
+    public String modelId() {
+        // returning null because the model id is embedded in the url or the request body
+        return null;
+    }
+
+    @Override
+    public SimilarityMeasure similarity() {
+        return textEmbeddingSettings.similarityMeasure;
+    }
+
+    @Override
+    public Integer dimensions() {
+        return textEmbeddingSettings.dimensions;
+    }
+
+    @Override
+    public DenseVectorFieldMapper.ElementType elementType() {
+        return textEmbeddingSettings.elementType;
+    }
+
+    public Integer getMaxInputTokens() {
+        return textEmbeddingSettings.maxInputTokens;
+    }
+
+    public String getUrl() {
+        return url;
+    }
+
+    public Map<String, String> getHeaders() {
+        return headers;
+    }
+
+    public QueryParameters getQueryParameters() {
+        return queryParameters;
+    }
+
+    public String getRequestContentString() {
+        return requestContentString;
+    }
+
+    public CustomResponseParser getResponseJsonParser() {
+        return responseJsonParser;
+    }
+
+    public ErrorResponseParser getErrorParser() {
+        return errorParser;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitSettings;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragment(builder, params);
+
+        builder.endObject();
+        return builder;
+    }
+
+    public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
+        return toXContentFragmentOfExposedFields(builder, params);
+    }
+
+    @Override
+    public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        textEmbeddingSettings.toXContent(builder, params);
+        builder.field(URL, url);
+
+        if (headers.isEmpty() == false) {
+            builder.field(HEADERS, headers);
+        }
+
+        queryParameters.toXContent(builder, params);
+
+        builder.startObject(REQUEST);
+        {
+            builder.field(REQUEST_CONTENT, requestContentString);
+        }
+        builder.endObject();
+
+        builder.startObject(RESPONSE);
+        {
+            responseJsonParser.toXContent(builder, params);
+            errorParser.toXContent(builder, params);
+        }
+        builder.endObject();
+
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    @Override
+    public ToXContentObject getFilteredXContentObject() {
+        return this;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        textEmbeddingSettings.writeTo(out);
+        out.writeString(url);
+        out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
+        queryParameters.writeTo(out);
+        out.writeString(requestContentString);
+        out.writeNamedWriteable(responseJsonParser);
+        rateLimitSettings.writeTo(out);
+        errorParser.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CustomServiceSettings that = (CustomServiceSettings) o;
+        return Objects.equals(textEmbeddingSettings, that.textEmbeddingSettings)
+            && Objects.equals(url, that.url)
+            && Objects.equals(headers, that.headers)
+            && Objects.equals(queryParameters, that.queryParameters)
+            && Objects.equals(requestContentString, that.requestContentString)
+            && Objects.equals(responseJsonParser, that.responseJsonParser)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings)
+            && Objects.equals(errorParser, that.errorParser);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(
+            textEmbeddingSettings,
+            url,
+            headers,
+            queryParameters,
+            requestContentString,
+            responseJsonParser,
+            rateLimitSettings,
+            errorParser
+        );
+    }
+
+    private static CustomResponseParser extractResponseParser(
+        TaskType taskType,
+        Map<String, Object> responseParserMap,
+        ValidationException validationException
+    ) {
+        if (responseParserMap == null) {
+            return NoopResponseParser.INSTANCE;
+        }
+
+        return switch (taskType) {
+            case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+            case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+            case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+            case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+            default -> throw new IllegalArgumentException(
+                Strings.format("Invalid task type received [%s] while constructing response parser", taskType)
+            );
+        };
+    }
 }
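// Illustrative sketch (not part of the patch): the JSON shape that
// toXContentFragmentOfExposedFields above emits for a text_embedding endpoint.
// The key literals ("url", "headers", "request"/"content", "response",
// "json_parser", "rate_limit") are assumed from the constants referenced in this
// file; the concrete values are hypothetical.
class CustomServiceSettingsXContentExample {
    static final String EXAMPLE = """
        {
          "similarity": "dot_product",
          "dimensions": 1536,
          "url": "https://example.com/v1/embeddings",
          "headers": { "Authorization": "Bearer ${api_key}" },
          "request": { "content": "{\\"input\\": ${input}}" },
          "response": { "json_parser": { "text_embeddings": "$.data[*].embedding[*]" } },
          "rate_limit": { "requests_per_minute": 10000 }
        }
        """;
}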
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java
new file mode 100644
index 0000000000000..1ca07ae0caf19
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;
+
+public class CustomTaskSettings implements TaskSettings {
+    public static final String NAME = "custom_task_settings";
+
+    public static final String PARAMETERS = "parameters";
+
+    static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(new HashMap<>());
+
+    public static CustomTaskSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        Map<String, Object> parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException);
+        removeNullValues(parameters);
+        validateMapValues(
+            parameters,
+            List.of(String.class, Integer.class, Double.class, Float.class, Boolean.class),
+            PARAMETERS,
+            validationException,
+            false
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CustomTaskSettings(Objects.requireNonNullElse(parameters, new HashMap<>()));
+    }
+
+    /**
+     * Creates a new {@link CustomTaskSettings}
+     * by preferring non-null fields from the request settings over the original settings.
+     *
+     * @param originalSettings the settings stored as part of the inference entity configuration
+     * @param requestTaskSettings the settings passed in within the task_settings field of the request
+     * @return a constructed {@link CustomTaskSettings}
+     */
+    public static CustomTaskSettings of(CustomTaskSettings originalSettings, CustomTaskSettings requestTaskSettings) {
+        var copy = new HashMap<>(originalSettings.parameters);
+        requestTaskSettings.parameters.forEach((key, value) -> copy.merge(key, value, (originalValue, requestValue) -> requestValue));
+        return new CustomTaskSettings(copy);
+    }
+
+    private final Map<String, Object> parameters;
+
+    public CustomTaskSettings(StreamInput in) throws IOException {
+        parameters = in.readGenericMap();
+    }
+
+    public CustomTaskSettings(Map<String, Object> parameters) {
+        this.parameters = Objects.requireNonNull(parameters);
+    }
+
+    public Map<String, Object> getParameters() {
+        return parameters;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (parameters.isEmpty() == false) {
+            builder.field(PARAMETERS, parameters);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeGenericMap(parameters);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CustomTaskSettings that = (CustomTaskSettings) o;
+        return Objects.equals(parameters, that.parameters);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(parameters);
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return parameters.isEmpty();
+    }
+
+    @Override
+    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
+        CustomTaskSettings updatedSettings = CustomTaskSettings.fromMap(new HashMap<>(newSettings));
+        return of(this, updatedSettings);
+    }
+}
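// Illustrative sketch (not part of the patch): the merge behavior of
// CustomTaskSettings.of above. Keys sent in the request override stored keys;
// stored keys absent from the request are preserved. All parameter names below
// are hypothetical example data.
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;

import java.util.HashMap;
import java.util.Map;

class CustomTaskSettingsMergeExample {
    public static void main(String[] args) {
        Map<String, Object> stored = new HashMap<>();
        stored.put("temperature", 0.2);
        stored.put("user", "alice");

        Map<String, Object> request = new HashMap<>();
        request.put("temperature", 0.9);

        var merged = CustomTaskSettings.of(new CustomTaskSettings(stored), new CustomTaskSettings(request));
        System.out.println(merged.getParameters()); // prints {temperature=0.9, user=alice}
    }
}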
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java
new file mode 100644
index 0000000000000..2b5bc2fe964b3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples;
+
+public record QueryParameters(List<Parameter> parameters) implements ToXContentFragment, Writeable {
+
+    public static final QueryParameters EMPTY = new QueryParameters(List.of());
+    public static final String QUERY_PARAMETERS = "query_parameters";
+
+    public static QueryParameters fromMap(Map<String, Object> map, ValidationException validationException) {
+        List<Tuple<String, String>> queryParams = extractOptionalListOfStringTuples(
+            map,
+            QUERY_PARAMETERS,
+            ModelConfigurations.SERVICE_SETTINGS,
+            validationException
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return QueryParameters.fromTuples(queryParams);
+    }
+
+    private static QueryParameters fromTuples(List<Tuple<String, String>> queryParams) {
+        if (queryParams == null) {
+            return QueryParameters.EMPTY;
+        }
+
+        return new QueryParameters(queryParams.stream().map((tuple) -> new Parameter(tuple.v1(), tuple.v2())).toList());
+    }
+
+    public record Parameter(String key, String value) implements ToXContentFragment, Writeable {
+        public Parameter {
+            Objects.requireNonNull(key);
+            Objects.requireNonNull(value);
+        }
+
+        public Parameter(StreamInput in) throws IOException {
+            this(in.readString(), in.readString());
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(key);
+            out.writeString(value);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startArray();
+            builder.value(key);
+            builder.value(value);
+            builder.endArray();
+            return builder;
+        }
+    }
+
+    public QueryParameters {
+        Objects.requireNonNull(parameters);
+    }
+
+    public QueryParameters(StreamInput in) throws IOException {
+        this(in.readCollectionAsImmutableList(Parameter::new));
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeCollection(parameters);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (parameters.isEmpty() == false) {
+            builder.startArray(QUERY_PARAMETERS);
+            for (var parameter : parameters) {
+                parameter.toXContent(builder, params);
+            }
+            builder.endArray();
+        }
+        return builder;
+    }
+}
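// Illustrative sketch (not part of the patch): query parameters are modeled as an
// ordered list of key/value pairs (duplicate keys are allowed), and the toXContent
// above serializes them as an array of two-element arrays, e.g.
// "query_parameters": [["api-version", "1"], ["format", "json"]].
import org.elasticsearch.xpack.inference.services.custom.QueryParameters;

import java.util.List;

class QueryParametersExample {
    static final QueryParameters PARAMS = new QueryParameters(
        List.of(new QueryParameters.Parameter("api-version", "1"), new QueryParameters.Parameter("format", "json"))
    );
}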
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java
new file mode 100644
index 0000000000000..0a50b08163260
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.StringEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_CONTENT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL;
+
+public class CustomRequest implements Request {
+    private static final String QUERY = "query";
+    private static final String INPUT = "input";
+
+    private final URI uri;
+    private final ValidatingSubstitutor jsonPlaceholderReplacer;
+    private final ValidatingSubstitutor stringPlaceholderReplacer;
+    private final CustomModel model;
+
+    public CustomRequest(String query, List<String> input, CustomModel model) {
+        this.model = Objects.requireNonNull(model);
+
+        var stringOnlyParams = new HashMap<String, String>();
+        addStringParams(stringOnlyParams, model.getSecretSettings().getSecretParameters());
+        addStringParams(stringOnlyParams, model.getTaskSettings().getParameters());
+
+        var jsonParams = new HashMap<String, String>();
+        addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters());
+        addJsonStringParams(jsonParams, model.getTaskSettings().getParameters());
+
+        if (query != null) {
+            jsonParams.put(QUERY, toJson(query, QUERY));
+        }
+
+        addInputJsonParam(jsonParams, input, model.getTaskType());
+
+        jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}");
+        stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}");
+        uri = buildUri();
+    }
+
+    private static void addStringParams(Map<String, String> stringParams, Map<String, ?> paramsToAdd) {
+        for (var entry : paramsToAdd.entrySet()) {
+            if (entry.getValue() instanceof String str) {
+                stringParams.put(entry.getKey(), str);
+            } else if (entry.getValue() instanceof SerializableSecureString serializableSecureString) {
+                stringParams.put(entry.getKey(), serializableSecureString.getSecureString().toString());
+            } else if (entry.getValue() instanceof SecureString secureString) {
+                stringParams.put(entry.getKey(), secureString.toString());
+            }
+        }
+    }
+
+    private static void addJsonStringParams(Map<String, String> jsonStringParams, Map<String, ?> params) {
+        for (var entry : params.entrySet()) {
+            jsonStringParams.put(entry.getKey(), toJson(entry.getValue(), entry.getKey()));
+        }
+    }
+
+    private static void addInputJsonParam(Map<String, String> jsonParams, List<String> input, TaskType taskType) {
+        if (taskType == TaskType.COMPLETION && input.isEmpty() == false) {
+            jsonParams.put(INPUT, toJson(input.get(0), INPUT));
+        } else {
+            jsonParams.put(INPUT, toJson(input, INPUT));
+        }
+    }
+
+    private URI buildUri() {
+        var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL);
+
+        try {
+            var builder = new URIBuilder(replacedUrl);
+            for (var queryParam : model.getServiceSettings().getQueryParameters().parameters()) {
+                builder.addParameter(
+                    queryParam.key(),
+                    stringPlaceholderReplacer.replace(queryParam.value(), Strings.format("query parameters: [%s]", queryParam.key()))
+                );
+            }
+            return builder.build();
+        } catch (URISyntaxException e) {
+            throw new IllegalStateException(Strings.format("Failed to build URI, error: %s", e.getMessage()), e);
+        }
+
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpRequest = new HttpPost(uri);
+
+        setHeaders(httpRequest);
+        setRequestContent(httpRequest);
+
+        return new HttpRequest(httpRequest, getInferenceEntityId());
+    }
+
+    private void setHeaders(HttpRequestBase httpRequest) {
+        // Set the default Content-Type header; if the user defines a Content-Type header, it replaces this value
+        httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+
+        for (var entry : model.getServiceSettings().getHeaders().entrySet()) {
+            String replacedHeadersValue = stringPlaceholderReplacer.replace(entry.getValue(), Strings.format("header.%s", entry.getKey()));
+            httpRequest.setHeader(entry.getKey(), replacedHeadersValue);
+        }
+    }
+
+    private void setRequestContent(HttpPost httpRequest) {
+        String replacedRequestContentString = jsonPlaceholderReplacer.replace(
+            model.getServiceSettings().getRequestContentString(),
+            REQUEST_CONTENT
+        );
+        StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8);
+        httpRequest.setEntity(stringEntity);
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public URI getURI() {
+        return uri;
+    }
+
+    public CustomServiceSettings getServiceSettings() {
+        return model.getServiceSettings();
+    }
+
+    @Override
+    public Request truncate() {
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        return null;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
index 762556fb381ed..ecd3125e228c9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java
@@ -30,8 +30,17 @@ public class CompletionResponseParser extends BaseCustomResponseParser
-    public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
-        var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException);
+    public static CompletionResponseParser fromMap(
+        Map<String, Object> responseParserMap,
+        String scope,
+        ValidationException validationException
+    ) {
+        var path = extractRequiredString(
+            responseParserMap,
+            COMPLETION_PARSER_RESULT,
+            String.join(".", scope, JSON_PARSER),
+            validationException
+        );
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
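// Illustrative sketch (not part of the patch): how CustomRequest above fills the
// ${...} placeholders. Each parameter is JSON-encoded (toJson) before substitution
// in the request body, while the URL, headers, and query parameters use the plain
// string substitutor. The template and values below are hypothetical.
class RequestTemplateExample {
    // With task settings {"model_id": "my-model"} and input ["hello"]:
    static final String TEMPLATE = """
        {"model": ${model_id}, "input": ${input}}
        """;
    // ...the substituted request body becomes:
    static final String RESULT = """
        {"model": "my-model", "input": ["hello"]}
        """;
}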
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java
new file mode 100644
index 0000000000000..b52670b99e6a9
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntity.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.response;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
+
+import java.io.IOException;
+
+public class CustomResponseEntity {
+    public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+        if (request instanceof CustomRequest customRequest) {
+            var responseJsonParser = customRequest.getServiceSettings().getResponseJsonParser();
+
+            return responseJsonParser.parse(response);
+        } else {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Original request is an invalid type [%s], expected [%s]",
+                    request.getClass().getSimpleName(),
+                    CustomRequest.class.getSimpleName()
+                )
+            );
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
index d05fa68595b3a..51fb8b1486a82 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java
@@ -7,6 +7,9 @@
 
 package org.elasticsearch.xpack.inference.services.custom.response;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -21,6 +24,7 @@
 import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
 
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.Function;
@@ -31,30 +35,40 @@
 
 public class ErrorResponseParser implements ToXContentFragment, Function<HttpResult, ErrorResponse> {
 
+    private static final Logger logger = LogManager.getLogger(ErrorResponseParser.class);
     public static final String MESSAGE_PATH = "path";
 
     private final String messagePath;
+    private final String inferenceId;
 
-    public static ErrorResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
-        var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException);
+    public static ErrorResponseParser fromMap(
+        Map<String, Object> responseParserMap,
+        String scope,
+        String inferenceId,
+        ValidationException validationException
+    ) {
+        var path = extractRequiredString(responseParserMap, MESSAGE_PATH, String.join(".", scope, ERROR_PARSER), validationException);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }
 
-        return new ErrorResponseParser(path);
+        return new ErrorResponseParser(path, inferenceId);
     }
 
-    public ErrorResponseParser(String messagePath) {
+    public ErrorResponseParser(String messagePath, String inferenceId) {
         this.messagePath = Objects.requireNonNull(messagePath);
+        this.inferenceId = Objects.requireNonNull(inferenceId);
     }
 
     public ErrorResponseParser(StreamInput in) throws IOException {
         this.messagePath = in.readString();
+        this.inferenceId = in.readString();
     }
 
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(messagePath);
+        out.writeString(inferenceId);
     }
 
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
@@ -86,7 +100,6 @@ public ErrorResponse apply(HttpResult httpResult) {
                 .createParser(XContentParserConfiguration.EMPTY, httpResult.body())
         ) {
             var map = jsonParser.map();
-
             // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic
             // if we find the top level error field we'll return a response with an empty message but indicate
             // that we found the structure of the error object. Here if we're missing the final field we will return
@@ -97,9 +110,19 @@ public ErrorResponse apply(HttpResult httpResult) {
             var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath);
             return new ErrorResponse(errorText);
         } catch (Exception e) {
-            // swallow the error
+            var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8);
+
+            logger.info(
+                Strings.format(
+                    "Failed to parse error object for custom service inference id [%s], message path: [%s], result as string: [%s]",
+                    inferenceId,
+                    messagePath,
+                    resultAsString
+                ),
+                e
+            );
+
+            return new ErrorResponse(Strings.format("Unable to parse the error, response body: [%s]", resultAsString));
         }
-
-        return ErrorResponse.UNDEFINED_ERROR;
     }
 }
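// Illustrative sketch (not part of the patch): given a parser configured with the
// hypothetical path "$.error.message", a failure body such as
//   {"error": {"message": "model not found", "type": "invalid_request_error"}}
// yields an ErrorResponse whose message is "model not found". When the body does
// not match the path, the parser above now logs and echoes the raw body rather
// than returning ErrorResponse.UNDEFINED_ERROR.
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;

class ErrorParserExample {
    static final ErrorResponseParser PARSER = new ErrorResponseParser("$.error.message", "my-inference-endpoint-id");
}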
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java
index 18d3cbbad051b..0a4c2c42b8c79 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java
@@ -37,11 +37,16 @@ public class RerankResponseParser extends BaseCustomResponseParser
-    public static RerankResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
-
-        var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException);
-        var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);
-        var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
+    public static RerankResponseParser fromMap(
+        Map<String, Object> responseParserMap,
+        String scope,
+        ValidationException validationException
+    ) {
+        var fullScope = String.join(".", scope, JSON_PARSER);
+
+        var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, fullScope, validationException);
+        var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, fullScope, validationException);
+        var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, fullScope, validationException);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java
index b6c83fd7fbfc6..7d54e90865122 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java
@@ -35,10 +35,15 @@ public class SparseEmbeddingResponseParser extends BaseCustomResponseParser
-    public static SparseEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
-        var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, JSON_PARSER, validationException);
+    public static SparseEmbeddingResponseParser fromMap(
+        Map<String, Object> responseParserMap,
+        String scope,
+        ValidationException validationException
+    ) {
+        var fullScope = String.join(".", scope, JSON_PARSER);
+        var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, fullScope, validationException);
 
-        var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, JSON_PARSER, validationException);
+        var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, fullScope, validationException);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -149,6 +154,7 @@ private static SparseEmbeddingResults.Embedding createEmbedding(
 
             // Alibaba can return a token id which is an integer and needs to be converted to a string
             var tokenIdAsString = token.toString();
+
             try {
                 var weightAsFloat = toFloat(weight, weightFieldName);
                 weightedTokens.add(new WeightedToken(tokenIdAsString, weightAsFloat));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
index fe5b4ec236282..b5b0a191f3c4e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
@@ -30,8 +30,17 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser
-    public static TextEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
-        var path = extractRequiredString(responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, JSON_PARSER, validationException);
+    public static TextEmbeddingResponseParser fromMap(
+        Map<String, Object> responseParserMap,
+        String scope,
+        ValidationException validationException
+    ) {
+        var path = extractRequiredString(
+            responseParserMap,
+            TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+            String.join(".", scope, JSON_PARSER),
+            validationException
+        );
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
index 0d8bef246b35d..fe6ebb6cfb625 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
@@ -70,17 +70,6 @@ public class ElasticInferenceServiceSettings {
         Setting.Property.NodeScope
     );
 
-    /**
-     * Total time to live (TTL) defines maximum life span of persistent connections regardless of their
-     * expiration setting. No persistent connection will be re-used past its TTL value.
-     * Using a TTL of -1 will disable the expiration of persistent connections (the idle connection evictor will still apply).
-     */
-    public static final Setting<TimeValue> CONNECTION_TTL_SETTING = Setting.timeSetting(
-        "xpack.inference.elastic.http.connection_ttl",
-        TimeValue.timeValueSeconds(60),
-        Setting.Property.NodeScope
-    );
-
     @Deprecated
     private final String eisGatewayUrl;
 
@@ -88,7 +77,6 @@ public class ElasticInferenceServiceSettings {
     private final boolean periodicAuthorizationEnabled;
     private volatile TimeValue authRequestInterval;
     private volatile TimeValue maxAuthorizationRequestJitter;
-    private final TimeValue connectionTtl;
 
     public ElasticInferenceServiceSettings(Settings settings) {
         eisGatewayUrl = EIS_GATEWAY_URL.get(settings);
@@ -96,7 +84,6 @@ public ElasticInferenceServiceSettings(Settings settings) {
         periodicAuthorizationEnabled = PERIODIC_AUTHORIZATION_ENABLED.get(settings);
         authRequestInterval = AUTHORIZATION_REQUEST_INTERVAL.get(settings);
         maxAuthorizationRequestJitter = MAX_AUTHORIZATION_REQUEST_JITTER.get(settings);
-        connectionTtl = CONNECTION_TTL_SETTING.get(settings);
     }
 
     /**
@@ -128,10 +115,6 @@ public TimeValue getMaxAuthorizationRequestJitter() {
         return maxAuthorizationRequestJitter;
     }
 
-    public TimeValue getConnectionTtl() {
-        return connectionTtl;
-    }
-
     public static List<Setting<?>> getSettingsDefinitions() {
         ArrayList<Setting<?>> settings = new ArrayList<>();
         settings.add(EIS_GATEWAY_URL);
@@ -141,7 +124,6 @@ public static List<Setting<?>> getSettingsDefinitions() {
         settings.add(PERIODIC_AUTHORIZATION_ENABLED);
         settings.add(AUTHORIZATION_REQUEST_INTERVAL);
         settings.add(MAX_AUTHORIZATION_REQUEST_JITTER);
-        settings.add(CONNECTION_TTL_SETTING);
         return settings;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java
index 5ab743c3d4cc0..10c8d8928ea65 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java
@@ -26,6 +26,7 @@
 import java.util.Deque;
 import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.LinkedBlockingDeque;
 import java.util.function.BiFunction;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -59,11 +60,21 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
     public static final String TOTAL_TOKENS_FIELD = "total_tokens";
 
     private final BiFunction<String, Exception, Exception> errorParser;
+    private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
 
     public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
         this.errorParser = errorParser;
     }
 
+    @Override
+    protected void upstreamRequest(long n) {
+        if (buffer.isEmpty()) {
+            super.upstreamRequest(n);
+        } else {
+            downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
+        }
+    }
+
     @Override
     protected void next(Deque<ServerSentEvent> item) throws Exception {
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -85,8 +96,15 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
 
         if (results.isEmpty()) {
             upstream().request(1);
-        } else {
+        } else if (results.size() == 1) {
             downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
+        } else {
+            // results > 1, but the openai spec only wants 1 chunk per SSE event
+            var firstItem = singleItem(results.poll());
+            while (results.isEmpty() == false) {
+                buffer.offer(results.poll());
+            }
+            downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
         }
     }
 
@@ -279,4 +297,12 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
             }
         }
     }
+
+    private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
+        StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
+    ) {
+        var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
+        deque.offer(result);
+        return deque;
+    }
 }
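// Illustrative sketch (not part of the patch): a stand-alone analogue of the
// buffering added above, which enforces one chunk per downstream request when a
// single SSE event parses into several chunks.
import java.util.ArrayDeque;
import java.util.Deque;

class OneChunkPerRequestBuffer<T> {
    private final Deque<T> buffer = new ArrayDeque<>();

    // Emit the first parsed chunk immediately; hold the rest for later demand.
    T emitFirstAndBufferRest(Deque<T> parsed) {
        T first = parsed.poll();
        while (parsed.isEmpty() == false) {
            buffer.offer(parsed.poll());
        }
        return first;
    }

    // On new demand, drain one buffered chunk before signalling upstream for more.
    T nextBufferedOrNull() {
        return buffer.poll();
    }
}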
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java
new file mode 100644
index 0000000000000..0ebd4cc0cad81
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/SerializableSecureString.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.settings;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class SerializableSecureString implements ToXContentFragment, Writeable {
+
+    private final SecureString secureString;
+
+    public SerializableSecureString(StreamInput in) throws IOException {
+        secureString = in.readSecureString();
+    }
+
+    public SerializableSecureString(SecureString secureString) {
+        this.secureString = Objects.requireNonNull(secureString);
+    }
+
+    public SerializableSecureString(String value) {
+        secureString = new SecureString(value.toCharArray());
+    }
+
+    public SecureString getSecureString() {
+        return secureString;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.value(secureString.toString());
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeSecureString(secureString);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        SerializableSecureString that = (SerializableSecureString) o;
+        return Objects.equals(secureString, that.secureString);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(secureString);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..e36b553d2def2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+grant {
+  // required by: com.google.api.client.json.JsonParser#parseValue
+  // also required by AWS SDK for client configuration
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+  permission java.lang.RuntimePermission "getClassLoader";
+
+  // required by: com.google.api.client.json.GenericJson#<init>
+  // also by AWS SDK for Jackson's ObjectMapper
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+
+  // required to add google certs to the gcs client truststore
+  permission java.lang.RuntimePermission "setFactory";
+
+  // the gcs client opens socket connections to access the repository
+  // the AWS Bedrock client also opens socket connections and needs resolve access to reach its resources
+  permission java.net.SocketPermission "*", "connect,resolve";
+
+  // AWS clients always try to check the http.proxyHost system property
+  permission java.util.PropertyPermission "http.proxyHost", "read";
+};
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
index bd18058277d9c..b8648936956ae 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
@@ -235,4 +235,8 @@ public static void assertJsonEquals(String actual, String expected) throws IOExc
             assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered()));
         }
     }
+
+    public static Map<String, Object> modifiableMap(Map<String, Object> aMap) {
+        return new HashMap<>(aMap);
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java
index e514867780669..f61398fcacacf 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java
@@ -7,16 +7,13 @@
 
 package org.elasticsearch.xpack.inference.action;
 
-import org.elasticsearch.TransportVersion;
-import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
-import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 
-public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
+public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
     @Override
     protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
         return PutInferenceModelAction.Request::new;
@@ -28,29 +25,38 @@ protected PutInferenceModelAction.Request createTestInstance() {
             randomFrom(TaskType.values()),
             randomAlphaOfLength(6),
             randomBytesReference(50),
-            randomFrom(XContentType.values()),
-            randomTimeValue()
+            randomFrom(XContentType.values())
         );
     }
 
     @Override
     protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
-        return randomValueOtherThan(instance, this::createTestInstance);
-    }
-
-    @Override
-    protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
-        if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
-            || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
-            return instance;
-        } else {
-            return new PutInferenceModelAction.Request(
+        return switch (randomIntBetween(0, 3)) {
+            case 0 -> new PutInferenceModelAction.Request(
+                TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
+                instance.getInferenceEntityId(),
+                instance.getContent(),
+                instance.getContentType()
+            );
+            case 1 -> new PutInferenceModelAction.Request(
+                instance.getTaskType(),
+                instance.getInferenceEntityId() + "foo",
+                instance.getContent(),
+                instance.getContentType()
+            );
+            case 2 -> new PutInferenceModelAction.Request(
+                instance.getTaskType(),
+                instance.getInferenceEntityId(),
+                randomBytesReference(instance.getContent().length() + 1),
+                instance.getContentType()
+            );
+            case 3 -> new PutInferenceModelAction.Request(
                 instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 instance.getContent(),
-                instance.getContentType(),
-                InferenceAction.Request.DEFAULT_TIMEOUT
+                XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
             );
-        }
+            default -> throw new IllegalStateException();
+        };
     }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java
index b49a819a3a698..7f1003c6723a0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/JsonUtilsTests.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
 
 import java.io.IOException;
 import java.util.List;
@@ -52,6 +53,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         assertThat(toJson(1.1f, "field"), is("1.1"));
         assertThat(toJson(true, "field"), is("true"));
         assertThat(toJson(false, "field"), is("false"));
+        assertThat(toJson(new SerializableSecureString("api_key"), "field"), is("\"api_key\""));
     }
 
     public void testToJson_ThrowsException_WhenUnableToSerialize() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java
index 047c0c8d647fb..a22bf12d29eb6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java
@@ -86,6 +86,75 @@ public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() {
         );
     }
 
+    public void testExtract_IteratesSparseEmbeddingStyleMap_ExtractsMaps() {
+        Map<String, Object> input = Map.of(
+            "result",
+            Map.of(
+                "sparse_embeddings",
+                List.of(
+                    Map.of(
+                        "index",
+                        0,
+                        "embedding",
+                        List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d))
+                    ),
+                    Map.of(
+                        "index",
+                        1,
+                        "embedding",
+                        List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d))
+                    )
+                )
+            )
+        );
+
+        assertThat(
+            MapPathExtractor.extract(input, "$.result.sparse_embeddings[*].embedding[*]"),
+            is(
+                new MapPathExtractor.Result(
+                    List.of(
+                        List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d)),
+                        List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d))
+                    ),
+                    List.of("result.sparse_embeddings", "result.sparse_embeddings.embedding")
+                )
+            )
+        );
+    }
+
+    public void testExtract_IteratesSparseEmbeddingStyleMap_ExtractsFieldFromMap() {
+        Map<String, Object> input = Map.of(
+            "result",
+            Map.of(
+                "sparse_embeddings",
+                List.of(
+                    Map.of(
+                        "index",
+                        0,
+                        "embedding",
+                        List.of(Map.of("tokenId", 6, "weight", 0.123d), Map.of("tokenId", 100, "weight", -123d))
+                    ),
+                    Map.of(
+                        "index",
+                        1,
+                        "embedding",
+                        List.of(Map.of("tokenId", 7, "weight", 0.456d), Map.of("tokenId", 200, "weight", -456d))
+                    )
+                )
+            )
+        );
+
+        assertThat(
+            MapPathExtractor.extract(input, "$.result.sparse_embeddings[*].embedding[*].tokenId"),
+            is(
+                new MapPathExtractor.Result(
+                    List.of(List.of(6, 100), List.of(7, 200)),
+                    List.of("result.sparse_embeddings", "result.sparse_embeddings.embedding", "result.sparse_embeddings.embedding.tokenId")
+                )
+            )
+        );
+    }
+
     public void testExtract_ReturnsNullForEmptyList() {
         Map<String, Object> input = Map.of();
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java
index 9ad9b9f3ca0a5..5024cf53dffa9 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java
@@ -32,7 +32,7 @@ public void testErrorResponse_ExtractsError() {
         var error = ErrorMessageResponseEntity.fromResponse(result);
 
         assertNotNull(error);
-        assertThat(error.getErrorMessage(), is("test_error_message"));
+        assertThat(error, is(new ErrorMessageResponseEntity("test_error_message")));
     }
 
     public void testFromResponse_WithOtherFieldsPresent() {
@@ -50,7 +50,7 @@ public void testFromResponse_WithOtherFieldsPresent() {
         ErrorResponse errorMessage = ErrorMessageResponseEntity.fromResponse(
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
-        assertEquals("You didn't provide an API key", errorMessage.getErrorMessage());
+        assertThat(errorMessage, is(new ErrorMessageResponseEntity("You didn't provide an API key")));
     }
 
     public void testFromResponse_noMessage() {
@@ -65,7 +65,7 @@ public void testFromResponse_noMessage() {
         var errorMessage = ErrorMessageResponseEntity.fromResponse(
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
-        assertThat(errorMessage.getErrorMessage(), is(""));
+        assertThat(errorMessage, is(new ErrorMessageResponseEntity("")));
 
         assertTrue(errorMessage.errorStructureFound());
     }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
index c3b50cdb4a670..a00f8e55a4e27 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
@@ -31,7 +31,6 @@
 import java.util.List;
 import java.util.Map;
 
-import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
 import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.test.ESTestCase.randomInt;
@@ -47,14 +46,9 @@ public static TestModel createRandomInstance(TaskType taskType) {
     }
 
     public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
-        // Use a max dimension count that has a reasonable probability of being compatible with BBQ
-        return createRandomInstance(taskType, excludedSimilarities, BBQ_MIN_DIMS * 2);
-    }
-
-    public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities, int maxDimensions) {
         var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
         var dimensions = taskType == TaskType.TEXT_EMBEDDING
-            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
+            ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
             : null;
 
         SimilarityMeasure similarity = null;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java
new file mode 100644
index 0000000000000..071c4caa90a9f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractServiceTests.java
@@ -0,0 +1,538 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+import org.junit.After;
+import org.junit.Assume;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
+import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
+import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
+import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Base class for testing inference services.
+ * <p>
+ * This class provides common unit tests for inference services, such as testing model creation and calling the infer method.
+ * <p>
+ * To use this class, extend it and pass the constructor a configuration.
+ */
+public abstract class AbstractServiceTests extends ESTestCase {
+
+    protected final MockWebServer webServer = new MockWebServer();
+    protected ThreadPool threadPool;
+    protected HttpClientManager clientManager;
+
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @Override
+    @After
+    public void tearDown() throws Exception {
+        super.tearDown();
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    private final TestConfiguration testConfiguration;
+
+    public AbstractServiceTests(TestConfiguration testConfiguration) {
+        this.testConfiguration = Objects.requireNonNull(testConfiguration);
+    }
+
+    /**
+     * Main configurations for the tests
+     */
+    public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) {
+        public static class Builder {
+            private final CommonConfig commonConfig;
+            private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS;
+
+            public Builder(CommonConfig commonConfig) {
+                this.commonConfig = commonConfig;
+            }
+
+            public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) {
+                this.updateModelConfiguration = updateModelConfiguration;
+                return this;
+            }
+
+            public TestConfiguration build() {
+                return new TestConfiguration(commonConfig, updateModelConfiguration);
+            }
+        }
+    }
+
+    /**
+     * Configurations that are useful for most tests
+     */
+    public abstract static class CommonConfig {
+
+        private final TaskType taskType;
+        private final TaskType unsupportedTaskType;
+
+        public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
+            this.taskType = Objects.requireNonNull(taskType);
+            this.unsupportedTaskType = unsupportedTaskType;
+        }
+
+        protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
+
+        protected abstract Map<String, Object> createServiceSettingsMap(TaskType taskType);
+
+        protected abstract Map<String, Object> createTaskSettingsMap();
+
+        protected abstract Map<String, Object> createSecretSettingsMap();
+
+        protected abstract void assertModel(Model model, TaskType taskType);
+
+        protected abstract EnumSet<TaskType> supportedStreamingTasks();
+    }
+
+    /**
+     * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
+     */
+    public abstract static class UpdateModelConfiguration {
+
+        public boolean isEnabled() {
+            return true;
+        }
+
+        protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
+    }
+
+    private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
+        @Override
+        public boolean isEnabled() {
+            return false;
+        }
+
+        @Override
+        protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+            throw new UnsupportedOperationException("Update model tests are disabled");
+        }
+    };
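// Illustrative sketch (not part of the patch): the minimal wiring a concrete test
// class provides. Every name below other than the AbstractServiceTests types is a
// hypothetical placeholder; a real subclass returns its service, settings maps,
// and model assertions.
public class ExampleCustomServiceTests extends AbstractServiceTests {
    public ExampleCustomServiceTests() {
        super(new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) {
            @Override
            protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
                throw new UnsupportedOperationException("construct the service under test here");
            }

            @Override
            protected Map<String, Object> createServiceSettingsMap(TaskType taskType) {
                return new HashMap<>(Map.of("url", "https://example.com"));
            }

            @Override
            protected Map<String, Object> createTaskSettingsMap() {
                return new HashMap<>();
            }

            @Override
            protected Map<String, Object> createSecretSettingsMap() {
                return new HashMap<>(Map.of("api_key", "secret"));
            }

            @Override
            protected void assertModel(Model model, TaskType taskType) {
                // assert service-specific model fields here
            }

            @Override
            protected EnumSet<TaskType> supportedStreamingTasks() {
                return EnumSet.noneOf(TaskType.class);
            }
        }).build());
    }
}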
parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); + + parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING); + } + } + + public void testParseRequestConfig_CreatesACompletionModel() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.COMPLETION, config, listener); + + parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.COMPLETION); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType)) + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + config.put("extra_key", "value"); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType); + serviceSettings.put("extra_key", "value"); + var config = getRequestConfigMap( + serviceSettings, + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", 
parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap(); + taskSettings.put("extra_key", "value"); + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + taskSettings, + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createTaskSettingsMap(), + secretSettingsMap + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings [{extra_key=value}]")); + } + } + + // parsePersistedConfigWithSecrets + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + 
+
+    // parsePersistedConfigWithSecrets
+
+    public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+                parseConfigTestConfig.createTaskSettingsMap(),
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION),
+                parseConfigTestConfig.createTaskSettingsMap(),
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.COMPLETION,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+            parseConfigTestConfig.assertModel(model, TaskType.COMPLETION);
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+                parseConfigTestConfig.createTaskSettingsMap(),
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+
+            var exception = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> service.parsePersistedConfigWithSecrets(
+                    "id",
+                    parseConfigTestConfig.unsupportedTaskType,
+                    persistedConfigMap.config(),
+                    persistedConfigMap.secrets()
+                )
+            );
+
+            assertThat(
+                exception.getMessage(),
+                containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType))
+            );
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+                parseConfigTestConfig.createTaskSettingsMap(),
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+            persistedConfigMap.config().put("extra_key", "value");
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                parseConfigTestConfig.taskType,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+
+            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType);
+            serviceSettings.put("extra_key", "value");
+            var persistedConfigMap = getPersistedConfigMap(
+                serviceSettings,
+                parseConfigTestConfig.createTaskSettingsMap(),
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                parseConfigTestConfig.taskType,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+
+            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
+            taskSettings.put("extra_key", "value");
+            var config = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+                taskSettings,
+                parseConfigTestConfig.createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
+
+            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
+        var parseConfigTestConfig = testConfiguration.commonConfig;
+        try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+            var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
+            secretSettingsMap.put("extra_key", "value");
+            var config = getPersistedConfigMap(
+                parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+                parseConfigTestConfig.createTaskSettingsMap(),
+                secretSettingsMap
+            );
+
+            var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
+
+            parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+        }
+    }
+
+    public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
+        try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+            var listener = new PlainActionFuture<InferenceServiceResults>();
+
+            service.infer(
+                getInvalidModel("id", "service"),
+                null,
+                null,
+                null,
+                List.of(""),
+                false,
+                new HashMap<>(),
+                InputType.INTERNAL_SEARCH,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                exception.getMessage(),
+                is("The internal model was invalid, please delete the service [service] with id [id] and add it again.")
+            );
+        }
+    }
+
+    public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
+        try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+            var listener = new PlainActionFuture<InferenceServiceResults>();
+
+            var exception = expectThrows(
+                ValidationException.class,
+                () -> service.infer(
+                    getInvalidModel("id", "service"),
+                    null,
+                    null,
+                    null,
+                    List.of(""),
+                    false,
+                    new HashMap<>(),
+                    InputType.INGEST,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+
+            assertThat(
+                exception.getMessage(),
+                is("Validation Failed: 1: Invalid input_type [ingest]. 
The input_type option is not supported by this service;") + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) + ); + + assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:")); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + // streaming tests + public void testSupportedStreamingTasks() throws Exception { + try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks())); + assertFalse(service.canStream(TaskType.ANY)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 6b2731bb313b5..770f85c866ba7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -11,26 +11,36 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -920,7 +930,326 @@ public void testValidateInputType_ValidationErrorsWhenInputTypeIsSpecified() { assertThat(validationException.validationErrors().size(), is(4)); } - private static Map modifiableMap(Map aMap) { - return new HashMap<>(aMap); + public void testExtractRequiredMap() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertTrue(validation.validationErrors().isEmpty()); + assertThat(extractedMap, is(Map.of("key", "value"))); + } + + public void testExtractRequiredMap_ReturnsNull_WhenTypeIsInvalid() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", 123)), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: field [setting] is not of the expected type. 
The value [123] cannot be converted to a [Map];") + ); + } + + public void testExtractRequiredMap_ReturnsNull_WhenMissingSetting() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("not_setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat(validation.getMessage(), is("Validation Failed: 1: [scope] does not contain the required setting [setting];")); + } + + public void testExtractRequiredMap_ReturnsNull_WhenMapIsEmpty() { + var validation = new ValidationException(); + var extractedMap = extractRequiredMap(modifiableMap(Map.of("setting", Map.of())), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: [scope] Invalid value empty map. [setting] must be a non-empty map;") + ); + } + + public void testExtractOptionalMap() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertTrue(validation.validationErrors().isEmpty()); + assertThat(extractedMap, is(Map.of("key", "value"))); + } + + public void testExtractOptionalMap_ReturnsNull_WhenTypeIsInvalid() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", 123)), "setting", "scope", validation); + + assertNull(extractedMap); + assertThat( + validation.getMessage(), + is("Validation Failed: 1: field [setting] is not of the expected type. The value [123] cannot be converted to a [Map];") + ); + } + + public void testExtractOptionalMap_ReturnsNull_WhenMissingSetting() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("not_setting", Map.of("key", "value"))), "setting", "scope", validation); + + assertNull(extractedMap); + assertTrue(validation.validationErrors().isEmpty()); + } + + public void testExtractOptionalMap_ReturnsEmptyMap_WhenEmpty() { + var validation = new ValidationException(); + var extractedMap = extractOptionalMap(modifiableMap(Map.of("setting", Map.of())), "setting", "scope", validation); + + assertThat(extractedMap, is(Map.of())); + } + + public void testValidateMapValues() { + var validation = new ValidationException(); + validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class, Integer.class), + "setting", + validation, + false + ); + } + + public void testValidateMapValues_IgnoresNullMap() { + var validation = new ValidationException(); + validateMapValues(null, List.of(String.class, Integer.class), "setting", validation, false); + } + + public void testValidateMapValues_ThrowsException_WhenMapContainsInvalidTypes() { + // Includes the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class), + "setting", + validation, + false + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid, " + + "[num_key => 1]. 
Value type of [1] is not one of [String].;" + ) + ); + } + + // Does not include the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapValues( + Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), + List.of(String.class, List.class), + "setting", + validation, + true + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid. " + + "Value type is not one of [List, String].;" + ) + ); + } + } + + public void testValidateMapStringValues() { + var validation = new ValidationException(); + assertThat( + validateMapStringValues(Map.of("string_key", "abc", "string_key2", new String("awesome")), "setting", validation, false), + is(Map.of("string_key", "abc", "string_key2", "awesome")) + ); + } + + public void testValidateMapStringValues_ReturnsEmptyMap_WhenMapIsNull() { + var validation = new ValidationException(); + assertThat(validateMapStringValues(null, "setting", validation, false), is(Map.of())); + } + + public void testValidateMapStringValues_ThrowsException_WhenMapContainsInvalidTypes() { + // Includes the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapStringValues(Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), "setting", validation, false) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [setting] has an entry that is not valid, " + + "[num_key => 1]. Value type of [1] is not one of [String].;" + ) + ); + } + + // Does not include the invalid key and value in the exception message + { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> validateMapStringValues(Map.of("string_key", "abc", "num_key", Integer.valueOf(1)), "setting", validation, true) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: Map field [setting] has an entry that is not valid. Value type is not one of [String].;") + ); + } + } + + public void testConvertMapStringsToSecureString() { + var validation = new ValidationException(); + assertThat( + convertMapStringsToSecureString(Map.of("key", "value", "key2", "abc"), "setting", validation), + is(Map.of("key", new SerializableSecureString("value"), "key2", new SerializableSecureString("abc"))) + ); + } + + public void testConvertMapStringsToSecureString_ReturnsAnEmptyMap_WhenMapIsNull() { + var validation = new ValidationException(); + assertThat(convertMapStringsToSecureString(null, "setting", validation), is(Map.of())); + } + + public void testConvertMapStringsToSecureString_ThrowsException_WhenMapContainsInvalidTypes() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> convertMapStringsToSecureString(Map.of("key", "value", "key2", 123), "setting", validation) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: Map field [setting] has an entry that is not valid. 
Value type is not one of [String].;"
+            )
+        );
+    }
+
+    public void testRemoveNullValues() {
+        var map = new HashMap<String, String>();
+        map.put("key1", null);
+        map.put("key2", "awesome");
+        map.put("key3", null);
+
+        assertThat(removeNullValues(map), is(Map.of("key2", "awesome")));
+    }
+
+    public void testRemoveNullValues_ReturnsNull_WhenMapIsNull() {
+        assertNull(removeNullValues(null));
+    }
+
+    public void testExtractOptionalListOfStringTuples() {
+        var validation = new ValidationException();
+        assertThat(
+            extractOptionalListOfStringTuples(
+                modifiableMap(Map.of("params", List.of(List.of("key", "value"), List.of("key2", "value2")))),
+                "params",
+                "scope",
+                validation
+            ),
+            is(List.of(new Tuple<>("key", "value"), new Tuple<>("key2", "value2")))
+        );
+    }
+
+    public void testExtractOptionalListOfStringTuples_ReturnsNull_WhenFieldIsNotAList() {
+        var validation = new ValidationException();
+        assertNull(extractOptionalListOfStringTuples(modifiableMap(Map.of("params", Map.of())), "params", "scope", validation));
+
+        assertThat(
+            validation.getMessage(),
+            is("Validation Failed: 1: field [params] is not of the expected type. The value [{}] cannot be converted to a [List];")
+        );
+    }
+
+    public void testExtractOptionalListOfStringTuples_Exception_WhenTupleIsNotAList() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> extractOptionalListOfStringTuples(modifiableMap(Map.of("params", List.of("string"))), "params", "scope", validation)
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting "
+                    + "[params], expected a list but the entry is [String];"
+            )
+        );
+    }
+
+    public void testExtractOptionalListOfStringTuples_Exception_WhenTupleIsListSize2() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> extractOptionalListOfStringTuples(
+                modifiableMap(Map.of("params", List.of(List.of("string")))),
+                "params",
+                "scope",
+                validation
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [scope] failed to parse tuple list entry "
+                    + "[0] for setting [params], the tuple list size must be two, but was [1];"
+            )
+        );
+    }
+
+    public void testExtractOptionalListOfStringTuples_Exception_WhenTupleFirstElement_IsNotAString() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> extractOptionalListOfStringTuples(
+                modifiableMap(Map.of("params", List.of(List.of(1, "value")))),
+                "params",
+                "scope",
+                validation
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting [params], "
+                    + "the first element must be a string but was [Integer];"
+            )
+        );
+    }
+
+    public void testExtractOptionalListOfStringTuples_Exception_WhenTupleSecondElement_IsNotAString() {
+        var validation = new ValidationException();
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> extractOptionalListOfStringTuples(
+                modifiableMap(Map.of("params", List.of(List.of("key", 2)))),
+                "params",
+                "scope",
+                validation
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: [scope] failed to parse tuple list entry [0] for setting [params], "
+                    + "the second element must be a string but was [Integer];"
+            )
+        );
+    }
 }
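(Aside: a minimal sketch of the null-stripping contract that testRemoveNullValues pins down above, derived from the assertions only; the shipped ServiceUtils helper may differ in its details.)

    static <K, V> Map<K, V> removeNullValues(Map<K, V> map) {
        if (map != null) {
            map.values().removeIf(java.util.Objects::isNull); // requires a mutable map, hence HashMap rather than Map.of in the test
        }
        return map; // null in, null out, per testRemoveNullValues_ReturnsNull_WhenMapIsNull
    }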
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
new file mode 100644
index 0000000000000..c3c4a44bcab07
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+import org.hamcrest.MatcherAssert;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class CustomModelTests extends ESTestCase {
+    private static final String taskSettingsKey = "test_taskSettings_key";
+    private static final String taskSettingsValue = "test_taskSettings_value";
+
+    private static final String secretSettingsKey = "test_secret_key";
+    private static final SerializableSecureString secretSettingsValue = new SerializableSecureString("test_secret_value");
+    private static final String url = "http://www.abc.com";
+
+    public void testOverride_DoesNotModifyFields_TaskSettingsIsEmpty() {
+        var model = createModel(
+            "service",
+            TaskType.TEXT_EMBEDDING,
+            CustomServiceSettingsTests.createRandom(),
+            CustomTaskSettingsTests.createRandom(),
+            CustomSecretSettingsTests.createRandom()
+        );
+
+        var overriddenModel = CustomModel.of(model, Map.of());
+        MatcherAssert.assertThat(overriddenModel, is(model));
+    }
+
+    public void testOverride() {
+        var model = createModel(
+            "service",
+            TaskType.TEXT_EMBEDDING,
+            CustomServiceSettingsTests.createRandom(),
+            new CustomTaskSettings(Map.of("key", "value")),
+            CustomSecretSettingsTests.createRandom()
+        );
+
+        var overriddenModel = CustomModel.of(
+            model,
+            new HashMap<>(Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("key", "different_value"))))
+        );
+        MatcherAssert.assertThat(
+            overriddenModel,
+            is(
+                createModel(
+                    "service",
+                    TaskType.TEXT_EMBEDDING,
+                    model.getServiceSettings(),
+                    new CustomTaskSettings(Map.of("key", "different_value")),
+                    model.getSecretSettings()
+                )
+            )
+        );
+    }
+
+    public static CustomModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        return new CustomModel(inferenceId, taskType, CustomService.NAME, serviceSettings, taskSettings, secrets, null);
+    }
+
+    public static CustomModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        CustomServiceSettings serviceSettings,
+        CustomTaskSettings taskSettings,
+        @Nullable CustomSecretSettings secretSettings
+    ) {
+        return new CustomModel(inferenceId, taskType,
CustomService.NAME, serviceSettings, taskSettings, secretSettings); + } + + public static CustomModel getTestModel() { + return getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")); + } + + public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser) { + return getTestModel(taskType, responseParser, url); + } + + public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser, String url) { + var inferenceId = "inference_id"; + Integer dims = 1536; + Integer maxInputTokens = 512; + Map headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}"); + String requestContentString = "\"input\":\"${input}\""; + + CustomServiceSettings serviceSettings = new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + headers, + QueryParameters.EMPTY, + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue)); + CustomSecretSettings secretSettings = new CustomSecretSettings(Map.of(secretSettingsKey, secretSettingsValue)); + + return CustomModelTests.createModel(inferenceId, taskType, serviceSettings, taskSettings, secretSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java new file mode 100644 index 0000000000000..16c058b3e0115 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; +import org.junit.After; +import org.junit.Before; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CustomRequestManagerTests extends ESTestCase { + + private ThreadPool threadPool; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + public void testCreateRequest_ThrowsException_ForInvalidUrl() { + var inferenceId = "inference_id"; + + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "${url}", + null, + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of("url", "^")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var listener = new PlainActionFuture(); + var manager = CustomRequestManager.of(model, threadPool); + manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30))); + + assertThat(exception.getMessage(), is("Failed to construct the custom service request")); + assertThat(exception.getCause().getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java new file mode 100644 index 0000000000000..a29992cd7f9fd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.elasticsearch.xpack.inference.Utils.modifiableMap;
+import static org.hamcrest.Matchers.is;
+
+public class CustomSecretSettingsTests extends AbstractBWCWireSerializationTestCase<CustomSecretSettings> {
+    public static CustomSecretSettings createRandom() {
+        Map<String, SerializableSecureString> secretParameters = randomMap(
+            0,
+            5,
+            () -> tuple(randomAlphaOfLength(5), new SerializableSecureString(randomAlphaOfLength(5)))
+        );
+
+        return new CustomSecretSettings(secretParameters);
+    }
+
+    public void testFromMap() {
+        Map<String, Object> secretParameters = new HashMap<>(
+            Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))
+        );
+
+        assertThat(
+            CustomSecretSettings.fromMap(secretParameters),
+            is(new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))))
+        );
+    }
+
+    public void testFromMap_PassedNull_ReturnsNull() {
+        assertNull(CustomSecretSettings.fromMap(null));
+    }
+
+    public void testFromMap_RemovesNullValues() {
+        var mapWithNulls = new HashMap<String, Object>();
+        mapWithNulls.put("value", "abc");
+        mapWithNulls.put("null", null);
+
+        assertThat(
+            CustomSecretSettings.fromMap(modifiableMap(Map.of(CustomSecretSettings.SECRET_PARAMETERS, mapWithNulls))),
+            is(new CustomSecretSettings(Map.of("value", new SerializableSecureString("abc"))))
+        );
+    }
+
+    public void testFromMap_Throws_IfValueIsInvalid() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> CustomSecretSettings.fromMap(
+                modifiableMap(Map.of(CustomSecretSettings.SECRET_PARAMETERS, modifiableMap(Map.of("key", Map.of("another_key", "value")))))
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Map field [secret_parameters] has an entry that is not valid. "
+                    + "Value type is not one of [String].;"
+            )
+        );
+    }
+
+    public void testFromMap_DefaultsToEmptyMap_WhenSecretParametersField_DoesNotExist() {
+        var map = new HashMap<String, Object>(Map.of("key", new HashMap<>(Map.of("test_key", "test_value"))));
+
+        assertThat(CustomSecretSettings.fromMap(map), is(new CustomSecretSettings(Map.of())));
+    }
+
+    public void testXContent() throws IOException {
+        var entity = new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value")));
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "secret_parameters": {
+                    "test_key": "test_value"
+                }
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testXContent_EmptyParameters() throws IOException {
+        var entity = new CustomSecretSettings(Map.of());
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    @Override
+    protected Writeable.Reader<CustomSecretSettings> instanceReader() {
+        return CustomSecretSettings::new;
+    }
+
+    @Override
+    protected CustomSecretSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected CustomSecretSettings mutateInstance(CustomSecretSettings instance) {
+        return randomValueOtherThan(instance, CustomSecretSettingsTests::createRandom);
+    }
+
+    @Override
+    protected CustomSecretSettings mutateInstanceForVersion(CustomSecretSettings instance, TransportVersion version) {
+        return instance;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
new file mode 100644
index 0000000000000..71eec73df5375
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
@@ -0,0 +1,734 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTestCase<CustomServiceSettings> {
+    public static CustomServiceSettings createRandom(String inputUrl) {
+        var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
+
+        SimilarityMeasure similarityMeasure = null;
+        Integer dims = null;
+        var isTextEmbeddingModel = taskType.equals(TaskType.TEXT_EMBEDDING);
+        if (isTextEmbeddingModel) {
+            similarityMeasure = SimilarityMeasure.DOT_PRODUCT;
+            dims = 1536;
+        }
+        var maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
+        var url = inputUrl != null ? inputUrl : randomAlphaOfLength(15);
+        Map<String, String> headers = randomBoolean() ? Map.of() : Map.of("key", "value");
+        var queryParameters = randomBoolean()
+            ? 
QueryParameters.EMPTY + : new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))); + var requestContentString = randomAlphaOfLength(10); + + var responseJsonParser = switch (taskType) { + case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].token_id", + "$.result.sparse_embeddings[*].embedding[*].weights" + ); + case RERANK -> new RerankResponseParser( + "$.result.reranked_results[*].index", + "$.result.reranked_results[*].relevance_score", + "$.result.reranked_results[*].document_text" + ); + case COMPLETION -> new CompletionResponseParser("$.result.text"); + default -> new NoopResponseParser(); + }; + + var errorParser = new ErrorResponseParser("$.error.message", randomAlphaOfLength(5)); + + RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); + + return new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityMeasure, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser + ); + } + + public static CustomServiceSettings createRandom() { + return createRandom(randomAlphaOfLength(5)); + } + + public void testFromMap() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + Map headers = Map.of("key", "value"); + var queryParameters = List.of(List.of("key", "value")); + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + headers, + QueryParameters.QUERY_PARAMETERS, + queryParameters, + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_WithOptionalsNotSpecified() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + 
CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT, + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_RemovesNullValues_FromMaps() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + + var headersWithNulls = new HashMap(); + headersWithNulls.put("value", "abc"); + headersWithNulls.put("null", null); + + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + headersWithNulls, + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + Map.of("value", "abc"), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", "inference_id") + ) + ) + ); + } + + public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", 1)), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), 
+ CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [headers] has an entry that is not valid, [key => 1]. " + + "Value type of [1] is not one of [String].;" + ) + ); + } + + public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.URL, + url, + QueryParameters.QUERY_PARAMETERS, + List.of(List.of("key", 1)), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] failed to parse tuple list entry [0] " + + "for setting [query_parameters], the second element must be a string but was [Integer];" + ) + ); + } + + public void testFromMap_ReturnsError_IfRequestMapIsMissing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + "invalid_request", + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] does not contain the required setting [request];" + + "2: [service_settings] does not contain the required setting [content];" + ) + ); + } + + public void testFromMap_ReturnsError_IfResponseMapIsMissing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new 
HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + "invalid_response", + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [service_settings] does not contain the required setting [response];" + + "2: [service_settings.response] does not contain the required setting [json_parser];" + + "3: [service_settings.response] does not contain the required setting [error_parser];" + + "4: Encountered a null input map while parsing field [path];" + ) + ); + } + + public void testFromMap_ReturnsError_IfRequestMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString, "key", "value")), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [request]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + "key", + "value" + ) + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [json_parser]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void 
testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")), + "key", + "value" + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [response]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) + ) + ) + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + ); + + assertThat( + exception.getMessage(), + is( + "Model configuration contains unknown settings [{key=value}] while parsing field [error_parser]" + + " for settings [custom_service_settings]" + ) + ); + } + + public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var mapSettings = new HashMap( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.HEADERS, + new HashMap<>(Map.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) + ) + ) + ); + + var exception = expectThrows( + IllegalArgumentException.class, + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION, "inference_id") + ); + + assertThat(exception.getMessage(), is("Invalid task type received [chat_completion] while constructing response 
parser")); + } + + public void testXContent() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + new ErrorResponseParser("$.error.message", "inference_id") + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": { + "content": "string" + }, + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + }, + "error_parser": { + "path": "$.error.message" + } + }, + "rate_limit": { + "requests_per_minute": 10000 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomServiceSettings::new; + } + + @Override + protected CustomServiceSettings createTestInstance() { + return createRandom(randomAlphaOfLength(5)); + } + + @Override + protected CustomServiceSettings mutateInstance(CustomServiceSettings instance) { + return randomValueOtherThan(instance, CustomServiceSettingsTests::createRandom); + } + + @Override + protected CustomServiceSettings mutateInstanceForVersion(CustomServiceSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java new file mode 100644 index 0000000000000..6ce181b0487ad --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -0,0 +1,550 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.services.AbstractServiceTests; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_INDEX; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; +import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; +import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CustomServiceTests extends AbstractServiceTests { + + public CustomServiceTests() { + super(createTestConfiguration()); + } + + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) { + @Override + protected 
SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return CustomServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map<String, Object> createServiceSettingsMap(TaskType taskType) { + return CustomServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map<String, Object> createTaskSettingsMap() { + return CustomServiceTests.createTaskSettingsMap(); + } + + @Override + protected Map<String, Object> createSecretSettingsMap() { + return CustomServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType) { + CustomServiceTests.assertModel(model, taskType); + } + + @Override + protected EnumSet<TaskType> supportedStreamingTasks() { + return EnumSet.noneOf(TaskType.class); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION -> assertCompletionModel(model); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model) { + var customModel = assertCommonModelFields(model); + + assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); + assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class)); + } + + private static CustomModel assertCommonModelFields(Model model) { + assertThat(model, instanceOf(CustomModel.class)); + + var customModel = (CustomModel) model; + + assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com")); + assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value"))); + assertThat( + customModel.getSecretSettings().getSecretParameters(), + is(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + + return customModel; + } + + private static void assertCompletionModel(Model model) { + var customModel = assertCommonModelFields(model); + assertThat(customModel.getTaskType(), is(TaskType.COMPLETION)); + assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); + } + + private static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + } + + private static Map<String, Object> createServiceSettingsMap(TaskType taskType) { + var settingsMap = new HashMap<>( + Map.of( + CustomServiceSettings.URL, + "http://www.abc.com", + CustomServiceSettings.HEADERS, + Map.of("key", "value"), + QueryParameters.QUERY_PARAMETERS, + List.of(List.of("key", "value")), + CustomServiceSettings.REQUEST, + new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, "request body")), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + createResponseParserMap(taskType), + CustomServiceSettings.ERROR_PARSER, + new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) + ) + ) + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 1536,
+ ServiceFields.MAX_INPUT_TOKENS, + 512)); + } + + return settingsMap; + } + + private static Map<String, Object> createResponseParserMap(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> new HashMap<>( + Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ); + case COMPLETION -> new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text")); + case SPARSE_EMBEDDING -> new HashMap<>( + Map.of( + SPARSE_EMBEDDING_TOKEN_PATH, + "$.result[*].embeddings[*].token", + SPARSE_EMBEDDING_WEIGHT_PATH, + "$.result[*].embeddings[*].weight" + ) + ); + case RERANK -> new HashMap<>( + Map.of( + RERANK_PARSER_SCORE, + "$.result.scores[*].score", + RERANK_PARSER_INDEX, + "$.result.scores[*].index", + RERANK_PARSER_DOCUMENT_TEXT, + "$.result.scores[*].document_text" + ) + ); + default -> throw new IllegalArgumentException("unexpected task type [" + taskType + "]"); + }; + } + + private static Map<String, Object> createTaskSettingsMap() { + return new HashMap<>(Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))); + } + + private static Map<String, Object> createSecretSettingsMap() { + return new HashMap<>(Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))); + } + + private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel( + similarityMeasure, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + "http://www.abc.com" + ); + } + + private static CustomModel createInternalEmbeddingModel(TextEmbeddingResponseParser parser, String url) { + return createInternalEmbeddingModel(SimilarityMeasure.DOT_PRODUCT, parser, url); + } + + private static CustomModel createInternalEmbeddingModel( + @Nullable SimilarityMeasure similarityMeasure, + TextEmbeddingResponseParser parser, + String url + ) { + var inferenceId = "inference_id"; + + return new CustomModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + CustomService.NAME, + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + similarityMeasure, + 123, + 456, + DenseVectorFieldMapper.ElementType.FLOAT + ), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "\"input\":\"${input}\"", + parser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + } + + private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { + var inferenceId = "inference_id"; + + return new CustomModel( + "model_id", + taskType, + CustomService.NAME, + new CustomServiceSettings( + getDefaultTextEmbeddingSettings(taskType), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "\"input\":\"${input}\"", + customResponseParser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) + ); + } + + private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddingSettings(TaskType taskType) { + return taskType == TaskType.TEXT_EMBEDDING + ?
CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT + : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; + } + + public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + + var embeddingResults = (TextEmbeddingFloatResults) results; + assertThat( + embeddingResults.embeddings(), + is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) + ); + } + } + + public void testInfer_HandlesRerankRequest_Cohere_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "index": "44873262-1315-4c06-8433-fdc90c9790d0", + "results": [ + { + "document": { + "text": "Washington, D.C.." + }, + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": { + "text": "Capital punishment has existed in the United States since beforethe United States was a country." + }, + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": { + "text": "Carson City is the capital city of the American state of Nevada." + }, + "index": 0, + "relevance_score": 0.10194652 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.RERANK, + new RerankResponseParser("$.results[*].relevance_score", "$.results[*].index", "$.results[*].document.text"), + getUrl(webServer) + ); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(RankedDocsResults.class)); + + var rerankResults = (RankedDocsResults) results; + assertThat( + rerankResults.getRankedDocs(), + is( + List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307f, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198f, + "Capital punishment has existed in the United States since beforethe United States was a country."
+ ), + new RankedDocsResults.RankedDoc(0, 0.10194652f, "Carson City is the capital city of the American state of Nevada.") + ) + ) + ); + } + } + + public void testInfer_HandlesCompletionRequest_OpenAI_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.COMPLETION, + new CompletionResponseParser("$.choices[*].message.content"), + getUrl(webServer) + ); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(ChatCompletionResults.class)); + + var completionResults = (ChatCompletionResults) results; + assertThat( + completionResults.getResults(), + is(List.of(new ChatCompletionResults.Result("Hello there, how may I assist you today?"))) + ); + } + } + + public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "tokenId": 6, + "weight": 0.101 + }, + { + "tokenId": 163040, + "weight": 0.28417 + } + ] + } + ] + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createCustomModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ), + getUrl(webServer) + ); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + InferenceServiceResults results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(SparseEmbeddingResults.class)); + + var sparseEmbeddingResults = (SparseEmbeddingResults) results; + assertThat( + sparseEmbeddingResults.embeddings(), + is( + List.of( + new SparseEmbeddingResults.Embedding( + List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)), + false + ) + ) + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java new file mode 100644 index 0000000000000..01d09af0b7a27 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java @@ -0,0 +1,155
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static org.hamcrest.Matchers.is; + +public class CustomTaskSettingsTests extends AbstractBWCWireSerializationTestCase<CustomTaskSettings> { + public static CustomTaskSettings createRandom() { + Map<String, Object> parameters = randomBoolean() + ? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5))) + : Map.of(); + return new CustomTaskSettings(parameters); + } + + public void testFromMap() { + var taskSettingsMap = new HashMap<String, Object>( + Map.of(CustomTaskSettings.PARAMETERS, new HashMap<>(Map.of("test_key", "test_value"))) + ); + + assertThat(CustomTaskSettings.fromMap(taskSettingsMap), is(new CustomTaskSettings(Map.of("test_key", "test_value")))); + } + + public void testFromMap_Null_EmptyMap_Returns_EmptySettings() { + assertThat(CustomTaskSettings.fromMap(Map.of()), is(CustomTaskSettings.EMPTY_SETTINGS)); + assertThat(CustomTaskSettings.fromMap(null), is(CustomTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_RemovesNullValues() { + var mapWithNulls = new HashMap<String, Object>(); + mapWithNulls.put("value", "abc"); + mapWithNulls.put("null", null); + + assertThat( + CustomTaskSettings.fromMap(modifiableMap(Map.of(CustomTaskSettings.PARAMETERS, mapWithNulls))), + is(new CustomTaskSettings(Map.of("value", "abc"))) + ); + } + + public void testFromMap_Throws_IfValueIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> CustomTaskSettings.fromMap( + modifiableMap(Map.of(CustomTaskSettings.PARAMETERS, modifiableMap(Map.of("key", Map.of("another_key", "value"))))) + ) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Map field [parameters] has an entry that is not valid, [key => {another_key=value}].
" + + "Value type of [{another_key=value}] is not one of [Boolean, Double, Float, Integer, String].;" + ) + ); + } + + public void testFromMap_DefaultsToEmptyMap_WhenParametersField_DoesNotExist() { + var taskSettingsMap = new HashMap(Map.of("key", new HashMap<>(Map.of("test_key", "test_value")))); + + assertThat(CustomTaskSettings.fromMap(taskSettingsMap), is(new CustomTaskSettings(Map.of()))); + } + + public void testOf_PrefersSettingsFromRequest() { + assertThat( + CustomTaskSettings.of( + new CustomTaskSettings(Map.of("a", "a_value", "b", "b_value")), + new CustomTaskSettings(Map.of("b", "b_value_overwritten")) + ), + is(new CustomTaskSettings(Map.of("a", "a_value", "b", "b_value_overwritten"))) + ); + } + + public void testXContent() throws IOException { + var entity = new CustomTaskSettings(Map.of("test_key", "test_value")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "parameters": { + "test_key": "test_value" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyParameters() throws IOException { + var entity = new CustomTaskSettings(Map.of()); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + } + """); + + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomTaskSettings::new; + } + + @Override + protected CustomTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CustomTaskSettings mutateInstance(CustomTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, CustomTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMap(@Nullable Map parameters) { + var map = new HashMap(); + if (parameters != null) { + map.put(CustomTaskSettings.PARAMETERS, parameters); + } + + return map; + } + + @Override + protected CustomTaskSettings mutateInstanceForVersion(CustomTaskSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java new file mode 100644 index 0000000000000..d6fac6709cc32 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/QueryParametersTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.modifiableMap; +import static org.hamcrest.Matchers.is; + +public class QueryParametersTests extends AbstractBWCWireSerializationTestCase<QueryParameters> { + public static QueryParameters createRandom() { + var parameters = randomList(5, () -> new QueryParameters.Parameter(randomAlphaOfLength(5), randomAlphaOfLength(5))); + return new QueryParameters(parameters); + } + + public void testFromMap() { + Map<String, Object> params = new HashMap<>(Map.of(QueryParameters.QUERY_PARAMETERS, List.of(List.of("test_key", "test_value")))); + + assertThat( + QueryParameters.fromMap(params, new ValidationException()), + is(new QueryParameters(List.of(new QueryParameters.Parameter("test_key", "test_value")))) + ); + } + + public void testFromMap_ReturnsEmpty_IfFieldDoesNotExist() { + assertThat(QueryParameters.fromMap(modifiableMap(Map.of()), new ValidationException()), is(QueryParameters.EMPTY)); + } + + public void testFromMap_Throws_IfFieldIsInvalid() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> QueryParameters.fromMap(modifiableMap(Map.of(QueryParameters.QUERY_PARAMETERS, "string")), validation) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: field [query_parameters] is not of the expected type.
" + + "The value [string] cannot be converted to a [List];" + ) + ); + } + + public void testXContent() throws IOException { + var entity = new QueryParameters(List.of(new QueryParameters.Parameter("test_key", "test_value"))); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "query_parameters": [ + ["test_key", "test_value"] + ] + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_EmptyParameters() throws IOException { + var entity = QueryParameters.EMPTY; + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is("")); + } + + @Override + protected Writeable.Reader instanceReader() { + return QueryParameters::new; + } + + @Override + protected QueryParameters createTestInstance() { + return createRandom(); + } + + @Override + protected QueryParameters mutateInstance(QueryParameters instance) { + return randomValueOtherThan(instance, QueryParametersTests::createRandom); + } + + @Override + protected QueryParameters mutateInstanceForVersion(QueryParameters instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java new file mode 100644 index 0000000000000..06bfc0b1f6956 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -0,0 +1,310 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; +import org.elasticsearch.xpack.inference.services.custom.QueryParameters; +import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CustomRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var inferenceId = "inference_id"; + var dims = 1536; + var maxInputTokens = 512; + Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + "${url}", + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("https://www.elastic.com?key=value&key=value2")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("my-secret-key")); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"] + } + """); + +
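// Each ${...} placeholder is resolved before the request is built: ${url} from the task settings, ${api_key} from the secret settings, and ${input} from the JSON-serialized inference inputs, as the assertions above and below verify +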
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { + var inferenceId = "inferenceId"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + // escaped characters retrieved from here: https://docs.microfocus.com/OMi/10.62/Content/OMi/ExtGuide/ExtApps/URL_encoding.htm + new QueryParameters( + List.of( + new QueryParameters.Parameter("key", " <>#%+{}|\\^~[]`;/?:@=&$"), + // unicode is a 😀 + // Note: In the current version of the apache library (4.x) being used to do the encoding, spaces are converted to + + // There's a bug fix here explaining that: https://issues.apache.org/jira/browse/HTTPCORE-628 + new QueryParameters.Parameter("key", "Σ \uD83D\uDE00") + ) + ), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat( + httpPost.getURI().toString(), + // To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/ + is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80") + ); + } + + public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException { + var inferenceId = "inference_id"; + var dims = 1536; + var maxInputTokens = 512; + Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}")); + var requestContentString = """ + { + "input": ${input}, + "secret": ${api_key} + } + """; + + var serviceSettings = new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + DenseVectorFieldMapper.ElementType.FLOAT + ), + "${url}", + headers, + new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), + requestContentString, + new TextEmbeddingResponseParser("$.result.embeddings"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + serviceSettings, + new CustomTaskSettings(Map.of("url", "https://www.elastic.com")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("https://www.elastic.com?key=value&key=value2")); +
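// Both ${api_key} occurrences (the Authorization header and the JSON body) are filled from the secret settings; the body value is serialized as a quoted JSON string, as the expectedBody below shows +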
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("my-secret-key")); + + // secret is encoded in json format (with quotes) + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "secret": "my-secret-key" + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_HandlesQuery() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input}, + "query": ${query} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + null, + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of()), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest("query string", List.of("abc", "123"), model); + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedBody = XContentHelper.stripWhitespace(""" + { + "input": ["abc", "123"], + "query": "query string" + } + """); + + assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody)); + } + + public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.elastic.co", + Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")), + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new CustomTaskSettings(Map.of("task.key", 100)), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var request = new CustomRequest(null, List.of("abc", "123"), model); + var exception = expectThrows(IllegalStateException.class, request::createHttpRequest); + assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call")); + } + + public void testCreateRequest_ThrowsException_ForInvalidUrl() { + var inferenceId = "inference_id"; + var requestContentString = """ + { + "input": ${input} + } + """; + + var serviceSettings = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "${url}", + Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")), + null, + requestContentString, + new RerankResponseParser("$.result.score"), + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId) + ); + + var model = CustomModelTests.createModel( + inferenceId, + TaskType.RERANK, + serviceSettings, + new 
CustomTaskSettings(Map.of("url", "^")), + new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key"))) + ); + + var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model)); + assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^")); + } + + private static String convertToString(InputStream inputStream) throws IOException { + return XContentHelper.stripWhitespace(Streams.copyToString(new InputStreamReader(inputStream, StandardCharsets.UTF_8))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java index 46cb23a4ceaa5..1e8dbb41e4d9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java @@ -38,7 +38,12 @@ public static CompletionResponseParser createRandom() { public void testFromMap() { var validation = new ValidationException(); - var parser = CompletionResponseParser.fromMap(new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), validation); + + var parser = CompletionResponseParser.fromMap( + new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), + "scope", + validation + ); assertThat(parser, is(new CompletionResponseParser("$.result[*].text"))); } @@ -47,12 +52,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), validation) + () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain the required setting [completion_result];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [completion_result];") ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java new file mode 100644 index 0000000000000..e7f6a47e7c9c7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CustomResponseEntityTests extends ESTestCase { + + public void testFromTextEmbeddingResponse() throws IOException { + String responseJson = """ + { + "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4", + "latency": 38, + "usage": { + "token_count": 3072 + }, + "result": { + "embeddings": [ + { + "index": 0, + "embedding": [ + -0.02868066355586052, + 0.022033605724573135 + ] + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")) + ); + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat( + ((TextEmbeddingFloatResults) results).embeddings(), + is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) + ); + } + + public void testFromSparseEmbeddingResponse() throws IOException { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "tokenId": 6, + "weight": 0.10137939453125 + }, + { + "tokenId": 163040, + "weight": 0.2841796875 + } + ] + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel( + TaskType.SPARSE_EMBEDDING, + new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ) + ) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertThat(results, instanceOf(SparseEmbeddingResults.class)); + + SparseEmbeddingResults sparseEmbeddingResults = (SparseEmbeddingResults) results; + + List<SparseEmbeddingResults.Embedding> embeddingList = new ArrayList<>(); + List<WeightedToken> weightedTokens = new ArrayList<>(); + weightedTokens.add(new WeightedToken("6", 0.10137939453125f)); + weightedTokens.add(new WeightedToken("163040", 0.2841796875f)); + embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + + for (int i = 0; i <
embeddingList.size(); i++) { + assertThat(sparseEmbeddingResults.embeddings().get(i), is(embeddingList.get(i))); + } + } + + public void testFromRerankResponse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": 1.37 + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel( + TaskType.RERANK, + new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null) + ) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(RankedDocsResults.class)); + var expected = new ArrayList<RankedDocsResults.RankedDoc>(); + expected.add(new RankedDocsResults.RankedDoc(1, 1.37F, null)); + expected.add(new RankedDocsResults.RankedDoc(0, -0.3F, null)); + + for (int i = 0; i < ((RankedDocsResults) results).getRankedDocs().size(); i++) { + assertThat(((RankedDocsResults) results).getRankedDocs().get(i).index(), is(expected.get(i).index())); + } + } + + public void testFromCompletionResponse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": { + "text":"completion results" + }, + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var request = new CustomRequest( + null, + List.of("abc"), + CustomModelTests.getTestModel(TaskType.COMPLETION, new CompletionResponseParser("$.result.text")) + ); + + InferenceServiceResults results = CustomResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(results, instanceOf(ChatCompletionResults.class)); + ChatCompletionResults chatCompletionResults = (ChatCompletionResults) results; + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("completion results")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java index 56987407e02ac..e52d7d9d0ff69 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java @@ -24,34 +24,38 @@ import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; import static org.mockito.Mockito.mock; public class ErrorResponseParserTests extends ESTestCase { public static ErrorResponseParser createRandom() { - return new ErrorResponseParser("$." + randomAlphaOfLength(5)); + return new ErrorResponseParser("$."
+ randomAlphaOfLength(5), randomAlphaOfLength(5)); } public void testFromMap() { var validation = new ValidationException(); - var parser = ErrorResponseParser.fromMap(new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), validation); + var parser = ErrorResponseParser.fromMap( + new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), + "scope", + "inference_id", + validation + ); - assertThat(parser, is(new ErrorResponseParser("$.error.message"))); + assertThat(parser, is(new ErrorResponseParser("$.error.message", "inference_id"))); } public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), validation) + () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), "scope", "inference_id", validation) ); - assertThat(exception.getMessage(), is("Validation Failed: 1: [error_parser] does not contain the required setting [path];")); + assertThat(exception.getMessage(), is("Validation Failed: 1: [scope.error_parser] does not contain the required setting [path];")); } public void testToXContent() throws IOException { - var entity = new ErrorResponseParser("$.error.message"); + var entity = new ErrorResponseParser("$.error.message", "inference_id"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); { @@ -80,7 +84,7 @@ public void testErrorResponse_ExtractsError() throws IOException { } }"""); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(result); assertThat(error, is(new ErrorResponse("test_error_message"))); } @@ -97,7 +101,7 @@ public void testFromResponse_WithOtherFieldsPresent() throws IOException { } """; - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(getMockResult(responseJson)); assertThat(error, is(new ErrorResponse("You didn't provide an API key"))); @@ -112,30 +116,29 @@ public void testFromResponse_noMessage() throws IOException { } """; - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(getMockResult(responseJson)); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); - assertThat(error.getErrorMessage(), is("")); - assertFalse(error.errorStructureFound()); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"error\":{\"type\":\"not_found_error\"}}]")); + assertTrue(error.errorStructureFound()); } public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException { var mockResult = getMockResult(""" {"noerror":true}"""); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(mockResult); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"noerror\":true}]")); } public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() { var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string")); - var parser = new ErrorResponseParser("$.error.message"); + var parser = new 
ErrorResponseParser("$.error.message", "inference_id"); var error = parser.apply(result); - assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [not a json string]")); } private static HttpResult getMockResult(String jsonString) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java index 523d15ec2a805..0c88d1f93bc73 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java @@ -53,6 +53,7 @@ public void testFromMap() { "$.result.scores[*].document_text" ) ), + "scope", validation ); @@ -64,7 +65,11 @@ public void testFromMap() { public void testFromMap_WithoutOptionalFields() { var validation = new ValidationException(); - var parser = RerankResponseParser.fromMap(new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), validation); + var parser = RerankResponseParser.fromMap( + new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), + "scope", + validation + ); assertThat(parser, is(new RerankResponseParser("$.result.scores[*].score", null, null))); } @@ -73,12 +78,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain the required setting [relevance_score];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [relevance_score];") ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java index c4b69ae8c8b19..7e54f95ef0fc1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java @@ -49,6 +49,7 @@ public void testFromMap() { "$.result[*].embeddings[*].weight" ) ), + "scope", validation ); @@ -59,14 +60,14 @@ public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), is( - "Validation Failed: 1: [json_parser] does not contain the required setting [token_path];" - + "2: [json_parser] 
does not contain the required setting [weight_path];" + "Validation Failed: 1: [scope.json_parser] does not contain the required setting [token_path];" + + "2: [scope.json_parser] does not contain the required setting [weight_path];" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java index b240e07a66336..82ddfa618d3b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -40,6 +40,7 @@ public void testFromMap() { var validation = new ValidationException(); var parser = TextEmbeddingResponseParser.fromMap( new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")), + "scope", validation ); @@ -50,12 +51,12 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), validation) + () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation) ); assertThat( exception.getMessage(), - is("Validation Failed: 1: [json_parser] does not contain " + "the required setting [text_embeddings];") + is("Validation Failed: 1: [scope.json_parser] does not contain the required setting [text_embeddings];") ); } diff --git a/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..16701ab74d8c9 --- /dev/null +++ b/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,4 @@ +grant { + // needed for multiple server implementations used in tests + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git a/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..db02e9267218a --- /dev/null +++ b/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
diff --git a/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..16701ab74d8c9
--- /dev/null
+++ b/x-pack/plugin/logstash/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,4 @@
+grant {
+  // needed for multiple server implementations used in tests
+  permission java.net.SocketPermission "*", "accept,connect";
+};
diff --git a/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..db02e9267218a
--- /dev/null
+++ b/x-pack/plugin/migrate/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,13 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+grant {
+  // needed for Painless to generate runtime classes
+  permission java.lang.RuntimePermission "createClassLoader";
+};
diff --git a/x-pack/plugin/ml-package-loader/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/ml-package-loader/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..67fc731ea29de
--- /dev/null
+++ b/x-pack/plugin/ml-package-loader/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,10 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+grant {
+  permission java.net.SocketPermission "*", "connect";
+};
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJob.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJob.java
index 4582d1a49392e..2a76b925247ff 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJob.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJob.java
@@ -377,7 +377,7 @@ private void run(long start, long end, FlushJobAction.Request flushRequest) {
                 extractedData = result.data();
                 searchInterval = result.searchInterval();
             } catch (Exception e) {
-                LOGGER.warn(() -> "[" + jobId + "] error while extracting data", e);
+                LOGGER.error(() -> "[" + jobId + "] error while extracting data", e);
                 // When extraction problems are encountered, we do not want to advance time.
                 // Instead, it is preferable to retry the given interval next time an extraction
                 // is triggered.
diff --git a/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy
new file mode 100644
index 0000000000000..9b3e5e0c72209
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy
@@ -0,0 +1,5 @@
+// Needed for painless script to run
+grant {
+  // needed to create the classloader which allows plugins to extend other plugins
+  permission java.lang.RuntimePermission "createClassLoader";
+};
diff --git a/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..1bf45f6d697a6
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,7 @@
+grant {
+  // needed for Windows named pipes in machine learning
+  permission java.io.FilePermission "\\\\.\\pipe\\*", "read,write";
+
+  // needed for ojalgo linear programming solver
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+};
diff --git a/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.codebases b/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.codebases
new file mode 100644
index 0000000000000..6bb3f6a738ff2
--- /dev/null
+++ b/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.codebases
@@ -0,0 +1 @@
+elasticsearch-rest-client: org.elasticsearch.client.RestClient
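Reviewer note (context, not part of the change): policy grants like the ones above only take effect when the plugin performs the guarded operation inside a privileged block. A sketch of the usual Elasticsearch idiom, assuming a socket grant such as the logstash/ml-package-loader ones:

```java
import java.net.Socket;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

final class PrivilegedConnect {
    // SpecialPermission.check() stops untrusted callers further up the stack
    // from borrowing the plugin's grants; the doPrivileged block then runs
    // with the permissions declared in plugin-security.policy.
    static Socket connect(String host, int port) throws PrivilegedActionException {
        org.elasticsearch.SpecialPermission.check();
        return AccessController.doPrivileged((PrivilegedExceptionAction<Socket>) () -> new Socket(host, port));
    }
}
```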
diff --git a/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..ef079a5c16e46
--- /dev/null
+++ b/x-pack/plugin/monitoring/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,23 @@
+grant {
+  // needed because of problems in unbound LDAP library
+  permission java.util.PropertyPermission "*", "read,write";
+
+  // required to configure the custom mailcap for watcher
+  permission java.lang.RuntimePermission "setFactory";
+
+  // needed when sending emails for javax.activation
+  // otherwise a classnotfound exception is thrown due to trying
+  // to load the class with the application class loader
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.lang.RuntimePermission "getClassLoader";
+  // TODO: remove use of this jar as soon as possible!!!!
+  permission java.lang.RuntimePermission "accessClassInPackage.com.sun.activation.registries";
+
+  // needed for multiple server implementations used in tests
+  permission java.net.SocketPermission "*", "accept,connect";
+};
+
+grant codeBase "${codebase.elasticsearch-rest-client}" {
+  // rest client uses system properties which gets the default proxy
+  permission java.net.NetPermission "getProxySelector";
+};
diff --git a/x-pack/plugin/security/qa/audit/src/javaRestTest/java/org/elasticsearch/xpack/security/audit/AuditIT.java b/x-pack/plugin/security/qa/audit/src/javaRestTest/java/org/elasticsearch/xpack/security/audit/AuditIT.java
index 9d6e49b63f395..2c329db5e3b50 100644
--- a/x-pack/plugin/security/qa/audit/src/javaRestTest/java/org/elasticsearch/xpack/security/audit/AuditIT.java
+++ b/x-pack/plugin/security/qa/audit/src/javaRestTest/java/org/elasticsearch/xpack/security/audit/AuditIT.java
@@ -7,8 +7,6 @@

 package org.elasticsearch.xpack.security.audit;

-import org.apache.http.entity.ContentType;
-import org.apache.http.entity.StringEntity;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -29,7 +27,6 @@
 import org.junit.ClassRule;

 import java.io.IOException;
-import java.nio.charset.StandardCharsets;
 import java.time.Instant;
 import java.time.ZonedDateTime;
 import java.time.format.DateTimeFormatter;
@@ -40,7 +37,6 @@
 import java.util.concurrent.TimeUnit;
 import java.util.function.Predicate;

-import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.hasKey;
@@ -107,25 +103,6 @@ public void testFilteringOfRequestBodies() throws Exception {
         });
     }

-    public void testAuditAuthenticationSuccessForStreamingRequest() throws Exception {
-        final Request request = new Request("POST", "/testindex/_bulk");
-        request.setEntity(new StringEntity("""
-            {"index":{}}
-            {}
-            """, ContentType.create("application/x-ndjson", StandardCharsets.UTF_8)));
-        executeAndVerifyAudit(
-            request,
-            AuditLevel.AUTHENTICATION_SUCCESS,
-            event -> assertThat(
-                event,
-                allOf(
-                    hasEntry(LoggingAuditTrail.AUTHENTICATION_TYPE_FIELD_NAME, "REALM"),
-                    hasEntry(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, "Request body had not been received at the time of the audit event")
-                )
-            )
-        );
-    }
-
     private void executeAndVerifyAudit(Request request, AuditLevel eventType, CheckedConsumer<Map<String, Object>, Exception> assertions)
         throws Exception {
         Instant start = Instant.now();
diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/security/qa/operator-privileges-tests/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..eb1558fb8e381
--- /dev/null
+++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,4 @@
+grant {
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+};
diff --git a/x-pack/plugin/security/src/main/java/module-info.java b/x-pack/plugin/security/src/main/java/module-info.java
index 7af53479b0844..4820fd4ad6096 100644
--- a/x-pack/plugin/security/src/main/java/module-info.java
+++ b/x-pack/plugin/security/src/main/java/module-info.java
@@ -74,7 +74,6 @@
     exports org.elasticsearch.xpack.security.rest.action.apikey to org.elasticsearch.internal.security;
     exports org.elasticsearch.xpack.security.support to org.elasticsearch.internal.security;
     exports org.elasticsearch.xpack.security.authz.store to org.elasticsearch.internal.security;
-    exports org.elasticsearch.xpack.security.authc.service;

     provides org.elasticsearch.index.SlowLogFieldProvider with org.elasticsearch.xpack.security.slowlog.SecuritySlowLogFieldProvider;
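Reviewer note: dropping the unqualified `exports` makes `org.elasticsearch.xpack.security.authc.service` module-internal again, while the surrounding lines keep the qualified friend-module form. Illustrative contrast, with hypothetical module names:

```java
module example.security {
    // unqualified: readable by any module (the form removed above)
    exports org.example.security.api;

    // qualified: readable only by the named friend module (the form this
    // module-info otherwise uses for internal packages)
    exports org.example.security.internal to org.example.friend;
}
```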
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
index 60b034294bcfc..004ea9a23ee70 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
@@ -208,8 +208,6 @@
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.Subject;
-import org.elasticsearch.xpack.core.security.authc.service.NodeLocalServiceAccountTokenStore;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore;
 import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
 import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine;
@@ -312,7 +310,6 @@
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.elasticsearch.xpack.security.authc.jwt.JwtRealm;
 import org.elasticsearch.xpack.security.authc.service.CachingServiceAccountTokenStore;
-import org.elasticsearch.xpack.security.authc.service.CompositeServiceAccountTokenStore;
 import org.elasticsearch.xpack.security.authc.service.FileServiceAccountTokenStore;
 import org.elasticsearch.xpack.security.authc.service.IndexServiceAccountTokenStore;
 import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;
@@ -918,34 +915,12 @@ Collection<Object> createComponents(
         this.realms.set(realms);
         systemIndices.getMainIndexManager().addStateListener(nativeRoleMappingStore::onSecurityIndexStateChange);

+        final CacheInvalidatorRegistry cacheInvalidatorRegistry = new CacheInvalidatorRegistry();
+        cacheInvalidatorRegistry.registerAlias("service", Set.of("file_service_account_token", "index_service_account_token"));
         components.add(cacheInvalidatorRegistry);
-
-        ServiceAccountService serviceAccountService = createServiceAccountService(
-            components,
-            cacheInvalidatorRegistry,
-            extensionComponents,
-            () -> new IndexServiceAccountTokenStore(
-                settings,
-                threadPool,
-                getClock(),
-                client,
-                systemIndices.getMainIndexManager(),
-                clusterService,
-                cacheInvalidatorRegistry
-            ),
-            () -> new FileServiceAccountTokenStore(
-                environment,
-                resourceWatcherService,
-                threadPool,
-                clusterService,
-                cacheInvalidatorRegistry
-            )
-        );
-
-        components.add(serviceAccountService);
-
         systemIndices.getMainIndexManager().addStateListener(cacheInvalidatorRegistry::onSecurityIndexStateChange);
+
         final NativePrivilegeStore privilegeStore = new NativePrivilegeStore(
             settings,
             client,
@@ -1029,6 +1004,33 @@ Collection<Object> createComponents(
         );
         components.add(apiKeyService);

+        final IndexServiceAccountTokenStore indexServiceAccountTokenStore = new IndexServiceAccountTokenStore(
+            settings,
+            threadPool,
+            getClock(),
+            client,
+            systemIndices.getMainIndexManager(),
+            clusterService,
+            cacheInvalidatorRegistry
+        );
+        components.add(indexServiceAccountTokenStore);
+
+        final FileServiceAccountTokenStore fileServiceAccountTokenStore = new FileServiceAccountTokenStore(
+            environment,
+            resourceWatcherService,
+            threadPool,
+            clusterService,
+            cacheInvalidatorRegistry
+        );
+        components.add(fileServiceAccountTokenStore);
+
+        final ServiceAccountService serviceAccountService = new ServiceAccountService(
+            client,
+            fileServiceAccountTokenStore,
+            indexServiceAccountTokenStore
+        );
+        components.add(serviceAccountService);
+
         final RoleProviders roleProviders = new RoleProviders(
             reservedRolesStore,
             fileRolesStore.get(),
@@ -1248,74 +1250,6 @@ Collection<Object> createComponents(
         return components;
     }

-    private ServiceAccountService createServiceAccountService(
-        List<Object> components,
-        CacheInvalidatorRegistry cacheInvalidatorRegistry,
-        SecurityExtension.SecurityComponents extensionComponents,
-        Supplier<IndexServiceAccountTokenStore> indexServiceAccountTokenStoreSupplier,
-        Supplier<FileServiceAccountTokenStore> fileServiceAccountTokenStoreSupplier
-    ) {
-        Map<String, ServiceAccountTokenStore> accountTokenStoreByExtension = new HashMap<>();
-
-        for (var extension : securityExtensions) {
-            var serviceAccountTokenStore = extension.getServiceAccountTokenStore(extensionComponents);
-            if (serviceAccountTokenStore != null) {
-                if (isInternalExtension(extension) == false) {
-                    throw new IllegalStateException(
-                        "The ["
-                            + extension.getClass().getName()
-                            + "] extension tried to install a custom ServiceAccountTokenStore. This functionality is not available to "
-                            + "external extensions."
-                    );
-                }
-                accountTokenStoreByExtension.put(extension.extensionName(), serviceAccountTokenStore);
-            }
-        }
-
-        if (accountTokenStoreByExtension.size() > 1) {
-            throw new IllegalStateException(
-                "More than one extension provided a ServiceAccountTokenStore override: " + accountTokenStoreByExtension.keySet()
-            );
-        }
-
-        if (accountTokenStoreByExtension.isEmpty()) {
-            var fileServiceAccountTokenStore = fileServiceAccountTokenStoreSupplier.get();
-            var indexServiceAccountTokenStore = indexServiceAccountTokenStoreSupplier.get();
-
-            components.add(new PluginComponentBinding<>(NodeLocalServiceAccountTokenStore.class, fileServiceAccountTokenStore));
-            components.add(fileServiceAccountTokenStore);
-            components.add(indexServiceAccountTokenStore);
-            cacheInvalidatorRegistry.registerAlias("service", Set.of("file_service_account_token", "index_service_account_token"));
-
-            return new ServiceAccountService(
-                client.get(),
-                new CompositeServiceAccountTokenStore(
-                    List.of(fileServiceAccountTokenStore, indexServiceAccountTokenStore),
-                    client.get().threadPool().getThreadContext()
-                ),
-                indexServiceAccountTokenStore
-            );
-        }
-        // Completely handover service account token management to the extension if provided,
-        // this will disable the index managed
-        // service account tokens managed through the service account token API
-        var extensionStore = accountTokenStoreByExtension.values().stream().findFirst();
-        components.add(new PluginComponentBinding<>(NodeLocalServiceAccountTokenStore.class, (token, listener) -> {
-            throw new IllegalStateException("Node local config not supported by [" + extensionStore.get().getClass() + "]");
-        }));
-        components.add(extensionStore);
-        logger.debug("Service account authentication handled by extension, disabling file and index token stores");
-        return new ServiceAccountService(client.get(), extensionStore.get());
-    }
-
-    private static boolean isInternalExtension(SecurityExtension extension) {
-        final String canonicalName = extension.getClass().getCanonicalName();
-        if (canonicalName == null) {
-            return false;
-        }
-        return canonicalName.startsWith("org.elasticsearch.xpack.") || canonicalName.startsWith("co.elastic.elasticsearch.");
-    }
-
     @FixForMultiProject
     // TODO : The migration task needs to be project aware
     private void applyPendingSecurityMigrations(ProjectId projectId, SecurityIndexManager.IndexState newState) {
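Reviewer note: the net effect in `createComponents` is that both token stores are always constructed and handed straight to `ServiceAccountService`, with no `SecurityExtension` override path. A condensed view of the restored wiring (constructor arguments elided, see the hunk above; not a drop-in snippet):

```java
// Both stores always exist; the index store also backs the token CRUD APIs.
var indexStore = new IndexServiceAccountTokenStore(/* settings, threadPool, clock, client, ... */);
var fileStore = new FileServiceAccountTokenStore(/* environment, resourceWatcherService, ... */);
components.add(indexStore);
components.add(fileStore);
components.add(new ServiceAccountService(client, fileStore, indexStore));
```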
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountAction.java
index be0a93bbda207..2e882546cb5a6 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountAction.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountAction.java
@@ -19,7 +19,7 @@
 import org.elasticsearch.xpack.core.security.action.service.GetServiceAccountRequest;
 import org.elasticsearch.xpack.core.security.action.service.GetServiceAccountResponse;
 import org.elasticsearch.xpack.core.security.action.service.ServiceAccountInfo;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount;
 import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;

 import java.util.function.Predicate;
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java
index 55352086a58a6..82ec407189595 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java
@@ -21,8 +21,8 @@
 import org.elasticsearch.xpack.core.security.action.service.GetServiceAccountCredentialsNodesResponse;
 import org.elasticsearch.xpack.core.security.action.service.GetServiceAccountNodesCredentialsAction;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo;
-import org.elasticsearch.xpack.core.security.authc.service.NodeLocalServiceAccountTokenStore;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.FileServiceAccountTokenStore;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;

 import java.io.IOException;
 import java.util.List;
@@ -38,7 +38,7 @@ public class TransportGetServiceAccountNodesCredentialsAction extends TransportN
     GetServiceAccountCredentialsNodesResponse.Node,
     Void> {

-    private final NodeLocalServiceAccountTokenStore readOnlyServiceAccountTokenStore;
+    private final FileServiceAccountTokenStore fileServiceAccountTokenStore;

     @Inject
     public TransportGetServiceAccountNodesCredentialsAction(
@@ -46,7 +46,7 @@ public TransportGetServiceAccountNodesCredentialsAction(
         ClusterService clusterService,
         TransportService transportService,
         ActionFilters actionFilters,
-        NodeLocalServiceAccountTokenStore readOnlyServiceAccountTokenStore
+        FileServiceAccountTokenStore fileServiceAccountTokenStore
     ) {
         super(
             GetServiceAccountNodesCredentialsAction.NAME,
@@ -56,7 +56,7 @@ public TransportGetServiceAccountNodesCredentialsAction(
             GetServiceAccountCredentialsNodesRequest.Node::new,
             threadPool.executor(ThreadPool.Names.GENERIC)
         );
-        this.readOnlyServiceAccountTokenStore = readOnlyServiceAccountTokenStore;
+        this.fileServiceAccountTokenStore = fileServiceAccountTokenStore;
     }

     @Override
@@ -84,7 +84,7 @@ protected GetServiceAccountCredentialsNodesResponse.Node nodeOperation(
         Task task
     ) {
         final ServiceAccountId accountId = new ServiceAccountId(request.getNamespace(), request.getServiceName());
-        final List<TokenInfo> tokenInfos = readOnlyServiceAccountTokenStore.findNodeLocalTokensFor(accountId);
+        final List<TokenInfo> tokenInfos = fileServiceAccountTokenStore.findTokensFor(accountId);
         return new GetServiceAccountCredentialsNodesResponse.Node(
             clusterService.localNode(),
             tokenInfos.stream().map(TokenInfo::getName).toArray(String[]::new)
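Reviewer note: with the `NodeLocalServiceAccountTokenStore` abstraction gone, each node answers the credentials request directly from its file-backed store. The per-node step, in isolation (variables as in the hunk above):

```java
// Resolve the account, list the tokens defined in this node's service_tokens
// file, and report just their names back to the coordinating node.
ServiceAccountId accountId = new ServiceAccountId(request.getNamespace(), request.getServiceName());
List<TokenInfo> tokenInfos = fileServiceAccountTokenStore.findTokensFor(accountId);
String[] tokenNames = tokenInfos.stream().map(TokenInfo::getName).toArray(String[]::new);
```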
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditUtil.java
index fde2c2457d952..c584945bc3bd2 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditUtil.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditUtil.java
@@ -27,9 +27,6 @@ public class AuditUtil {

     public static String restRequestContent(RestRequest request) {
         if (request.hasContent()) {
-            if (request.isStreamedContent()) {
-                return "Request body had not been received at the time of the audit event";
-            }
             var content = request.content();
             try {
                 return XContentHelper.convertToJson(content, false, false, request.getXContentType());
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java
index 4799caa360a0e..8f9d754b0c0a5 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java
@@ -97,7 +97,6 @@
 import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
 import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountSettings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo;
 import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges;
@@ -111,6 +110,7 @@
 import org.elasticsearch.xpack.security.audit.AuditTrail;
 import org.elasticsearch.xpack.security.audit.AuditUtil;
 import org.elasticsearch.xpack.security.authc.ApiKeyService;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.rest.RemoteHostHeader;
 import org.elasticsearch.xpack.security.transport.filter.IPFilter;
 import org.elasticsearch.xpack.security.transport.filter.SecurityIpFilterRule;
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticator.java
index 2fbc792d1c62d..eb907b3bb52cd 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticator.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticator.java
@@ -15,8 +15,8 @@
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.metric.InstrumentedSecurityActionListener;
 import org.elasticsearch.xpack.security.metric.SecurityMetricType;
 import org.elasticsearch.xpack.security.metric.SecurityMetrics;
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStore.java
index 23ffe98fcdf95..fff47c2c22e2e 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStore.java
@@ -19,8 +19,6 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;

@@ -99,10 +97,10 @@ private void authenticateWithCache(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
                 if (result.success) {
-                    l.onResponse(StoreAuthenticationResult.fromBooleanResult(getTokenSource(), result.verify(token)));
+                    l.onResponse(new StoreAuthenticationResult(result.verify(token), getTokenSource()));
                 } else if (result.verify(token)) {
                     // same wrong token
-                    l.onResponse(StoreAuthenticationResult.failed(getTokenSource()));
+                    l.onResponse(new StoreAuthenticationResult(false, getTokenSource()));
                 } else {
                     cache.invalidate(token.getQualifiedName(), listenableCacheEntry);
                     authenticateWithCache(token, l);
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CompositeServiceAccountTokenStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CompositeServiceAccountTokenStore.java
index 48a8a89cda300..e5227e3b9f593 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CompositeServiceAccountTokenStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/CompositeServiceAccountTokenStore.java
@@ -12,8 +12,6 @@
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.xpack.core.common.IteratingActionListener;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore;

 import java.util.List;
 import java.util.function.Function;
@@ -40,7 +38,7 @@ public void authenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
-            storeAuthenticationResult -> storeAuthenticationResult.isSuccess() == false
+            storeAuthenticationResult -> false == storeAuthenticationResult.isSuccess()
         );
         try {
             authenticatingListener.run();
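Reviewer note: only the result construction changes here; the composite store keeps its try-each-delegate-until-success semantics, implemented asynchronously via `IteratingActionListener`. A simplified, synchronous model of that fallthrough (illustrative types, not the real interfaces):

```java
import java.util.List;
import java.util.Optional;

// Simplified stand-in for the composite store: delegates are consulted in
// order and the first successful authentication wins.
interface SimpleTokenStore {
    boolean authenticate(String qualifiedTokenName, char[] secret);
}

record SimpleCompositeStore(List<SimpleTokenStore> delegates) {
    Optional<SimpleTokenStore> firstSuccessful(String name, char[] secret) {
        return delegates.stream().filter(store -> store.authenticate(name, secret)).findFirst();
    }
}
```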
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java
index 7caab9d5a6e87..dd671ebef824e 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccounts.java
@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.security.authc.service;

 import org.elasticsearch.common.Strings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount;
 import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.core.security.authz.store.ReservedRolesStore;
 import org.elasticsearch.xpack.core.security.user.User;
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStore.java
index f7921d1ae3526..dff79c56b32dc 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStore.java
@@ -22,12 +22,10 @@
 import org.elasticsearch.xpack.core.XPackPlugin;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
-import org.elasticsearch.xpack.core.security.authc.service.NodeLocalServiceAccountTokenStore;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.support.NoOpLogger;
 import org.elasticsearch.xpack.security.PrivilegedFileWatcher;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.elasticsearch.xpack.security.support.FileLineParser;
 import org.elasticsearch.xpack.security.support.FileReloadListener;
@@ -43,7 +41,7 @@
 import java.util.Optional;
 import java.util.concurrent.CopyOnWriteArrayList;

-public class FileServiceAccountTokenStore extends CachingServiceAccountTokenStore implements NodeLocalServiceAccountTokenStore {
+public class FileServiceAccountTokenStore extends CachingServiceAccountTokenStore {

     private static final Logger logger = LogManager.getLogger(FileServiceAccountTokenStore.class);

@@ -52,7 +50,6 @@ public class FileServiceAccountTokenStor
     private final CopyOnWriteArrayList<Runnable> refreshListeners;
     private volatile Map<String, char[]> tokenHashes;

-    @SuppressWarnings("this-escape")
     public FileServiceAccountTokenStore(
         Environment env,
         ResourceWatcherService resourceWatcherService,
@@ -85,8 +82,8 @@ public void doAuthenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
-                .map(hash -> StoreAuthenticationResult.fromBooleanResult(getTokenSource(), Hasher.verifyHash(token.getSecret(), hash)))
-                .orElse(StoreAuthenticationResult.failed(getTokenSource()))
+                .map(hash -> new StoreAuthenticationResult(Hasher.verifyHash(token.getSecret(), hash), getTokenSource()))
+                .orElse(new StoreAuthenticationResult(false, getTokenSource()))
         );
     }

@@ -95,8 +92,7 @@ public TokenSource getTokenSource() {
         return TokenSource.FILE;
     }

-    @Override
-    public List<TokenInfo> findNodeLocalTokensFor(ServiceAccountId accountId) {
+    public List<TokenInfo> findTokensFor(ServiceAccountId accountId) {
         final String principal = accountId.asPrincipal();
         return tokenHashes.keySet()
             .stream()
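Reviewer note: the file store keeps its lookup-then-verify shape; only the `StoreAuthenticationResult` construction changes. The verification step in isolation, assuming a `tokenHashes` map like the field above:

```java
import java.util.Map;
import java.util.Optional;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.xpack.core.security.authc.support.Hasher;

final class FileTokenVerify {
    // Find the stored hash for the token's qualified name and check the
    // presented secret against it; an absent entry means failure.
    static boolean verify(Map<String, char[]> tokenHashes, String qualifiedName, SecureString secret) {
        return Optional.ofNullable(tokenHashes.get(qualifiedName)).map(hash -> Hasher.verifyHash(secret, hash)).orElse(false);
    }
}
```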
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileTokensTool.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileTokensTool.java
index 9549e91d8a49d..14ca1663e16a5 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileTokensTool.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/FileTokensTool.java
@@ -21,11 +21,10 @@
 import org.elasticsearch.core.Predicates;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.xpack.core.XPackSettings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.support.Validation;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.elasticsearch.xpack.security.support.FileAttributesChecker;

 import java.nio.file.Path;
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStore.java
index 85c9b555b5352..91e51340e7d8f 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStore.java
@@ -47,10 +47,9 @@
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.Subject;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.elasticsearch.xpack.security.support.SecurityIndexManager;
 import org.elasticsearch.xpack.security.support.SecurityIndexManager.IndexState;
@@ -81,7 +80,6 @@ public class IndexServiceAccountTokenStore extends CachingServiceAccountTokenSto
     private final ClusterService clusterService;
     private final Hasher hasher;

-    @SuppressWarnings("this-escape")
     public IndexServiceAccountTokenStore(
         Settings settings,
         ThreadPool threadPool,
@@ -118,14 +116,14 @@ void doAuthenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
+        compositeServiceAccountTokenStore.authenticate(serviceAccountToken, ActionListener.wrap(storeAuthenticationResult -> {
             if (storeAuthenticationResult.isSuccess()) {
                 listener.onResponse(
                     createAuthentication(account, serviceAccountToken, storeAuthenticationResult.getTokenSource(), nodeName)
                 );
@@ -153,23 +149,14 @@ public void createIndexToken(
         CreateServiceAccountTokenRequest request,
         ActionListener<CreateServiceAccountTokenResponse> listener
     ) {
-        if (indexServiceAccountTokenStore == null) {
-            throw new IllegalStateException("Can't create token because index service account token store not configured");
-        }
         indexServiceAccountTokenStore.createToken(authentication, request, listener);
     }

     public void deleteIndexToken(DeleteServiceAccountTokenRequest request, ActionListener<Boolean> listener) {
-        if (indexServiceAccountTokenStore == null) {
-            throw new IllegalStateException("Can't delete token because index service account token store not configured");
-        }
         indexServiceAccountTokenStore.deleteToken(request, listener);
     }

     public void findTokensFor(GetServiceAccountCredentialsRequest request, ActionListener<GetServiceAccountCredentialsResponse> listener) {
-        if (indexServiceAccountTokenStore == null) {
-            throw new IllegalStateException("Can't find tokens because index service account token store not configured");
-        }
         final ServiceAccountId accountId = new ServiceAccountId(request.getNamespace(), request.getServiceName());
         findIndexTokens(accountId, listener);
     }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountToken.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountToken.java
similarity index 97%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountToken.java
rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountToken.java
index 03f8fe3554383..ce509a0122419 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountToken.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountToken.java
@@ -5,7 +5,7 @@
  * 2.0.
  */

-package org.elasticsearch.xpack.core.security.authc.service;
+package org.elasticsearch.xpack.security.authc.service;

 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -14,9 +14,9 @@
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.CharArrays;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
 import org.elasticsearch.xpack.core.security.support.Validation;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;

 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
@@ -51,6 +51,7 @@ public class ServiceAccountToken implements AuthenticationToken, Closeable {
     private final ServiceAccountTokenId tokenId;
     private final SecureString secret;

+    // pkg private for testing
     ServiceAccountToken(ServiceAccountId accountId, String tokenName, SecureString secret) {
         tokenId = new ServiceAccountTokenId(accountId, tokenName);
         this.secret = Objects.requireNonNull(secret, "service account token secret cannot be null");
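Reviewer note: after the move, `ServiceAccountToken` is internal to the security plugin, and the now pkg-private constructor is what the updated tests below call directly. Structurally a token is a qualified id plus a `SecureString` secret, and it is `Closeable` so the secret can be wiped; e.g.:

```java
import org.elasticsearch.common.settings.SecureString;

// The qualified name is derived from the account id and token name,
// e.g. "elastic/fleet-server/token1"; closing the token clears the secret.
ServiceAccountId accountId = new ServiceAccountId("elastic", "fleet-server");
try (var token = new ServiceAccountToken(accountId, "token1", new SecureString("super-secret-value".toCharArray()))) {
    String qualifiedName = token.getQualifiedName();
}
```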
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenStore.java
similarity index 62%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenStore.java
rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenStore.java
index 3226ee1381d8c..37eec6a092e3f 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenStore.java
@@ -5,7 +5,7 @@
  * 2.0.
  */

-package org.elasticsearch.xpack.core.security.authc.service;
+package org.elasticsearch.xpack.security.authc.service;

 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
@@ -24,23 +24,11 @@ class StoreAuthenticationResult {
         private final boolean success;
         private final TokenSource tokenSource;

-        private StoreAuthenticationResult(TokenSource tokenSource, boolean success) {
+        public StoreAuthenticationResult(boolean success, TokenSource tokenSource) {
             this.success = success;
             this.tokenSource = tokenSource;
         }

-        public static StoreAuthenticationResult successful(TokenSource tokenSource) {
-            return new StoreAuthenticationResult(tokenSource, true);
-        }
-
-        public static StoreAuthenticationResult failed(TokenSource tokenSource) {
-            return new StoreAuthenticationResult(tokenSource, false);
-        }
-
-        public static StoreAuthenticationResult fromBooleanResult(TokenSource tokenSource, boolean result) {
-            return result ? successful(tokenSource) : failed(tokenSource);
-        }
-
         public boolean isSuccess() {
             return success;
         }
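Reviewer note: every call site of the removed factories maps mechanically onto the restored public constructor, with the argument order flipped. Assuming free variables `verified` and `tokenSource`:

```java
StoreAuthenticationResult ok = new StoreAuthenticationResult(true, TokenSource.FILE);      // was: successful(TokenSource.FILE)
StoreAuthenticationResult bad = new StoreAuthenticationResult(false, TokenSource.INDEX);   // was: failed(TokenSource.INDEX)
StoreAuthenticationResult fromBool = new StoreAuthenticationResult(verified, tokenSource); // was: fromBooleanResult(tokenSource, verified)
```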
diff --git a/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.codebases b/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.codebases
new file mode 100644
index 0000000000000..94cfaec2d519c
--- /dev/null
+++ b/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.codebases
@@ -0,0 +1,2 @@
+netty-common: io.netty.util.NettyRuntime
+netty-transport: io.netty.channel.Channel
diff --git a/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..b4791207a15bf
--- /dev/null
+++ b/x-pack/plugin/security/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,64 @@
+grant {
+  permission java.lang.RuntimePermission "setFactory";
+
+  // secure the users file from other things (current and legacy locations)
+  permission org.elasticsearch.SecuredConfigFileAccessPermission "users";
+  permission org.elasticsearch.SecuredConfigFileAccessPermission "x-pack/users";
+  // other security files specified by settings
+  permission org.elasticsearch.SecuredConfigFileSettingAccessPermission "xpack.security.authc.realms.ldap.*.files.role_mapping";
+  permission org.elasticsearch.SecuredConfigFileSettingAccessPermission "xpack.security.authc.realms.pki.*.files.role_mapping";
+  permission org.elasticsearch.SecuredConfigFileSettingAccessPermission "xpack.security.authc.realms.jwt.*.pkc_jwkset_path";
+  permission org.elasticsearch.SecuredConfigFileSettingAccessPermission "xpack.security.authc.realms.saml.*.idp.metadata.path";
+  permission org.elasticsearch.SecuredConfigFileSettingAccessPermission "xpack.security.authc.realms.kerberos.*.keytab.path";
+
+  // needed for SAML
+  permission java.util.PropertyPermission "org.apache.xml.security.ignoreLineBreaks", "read,write";
+
+  // needed during initialization of OpenSAML library where xml security algorithms are registered
+  // see https://github.com/apache/santuario-java/blob/e79f1fe4192de73a975bc7246aee58ed0703343d/src/main/java/org/apache/xml/security/utils/JavaUtils.java#L205-L220
+  // and https://git.shibboleth.net/view/?p=java-opensaml.git;a=blob;f=opensaml-xmlsec-impl/src/main/java/org/opensaml/xmlsec/signature/impl/SignatureMarshaller.java;hb=db0eaa64210f0e32d359cd6c57bedd57902bf811#l52
+  // which uses it in the opensaml-xmlsec-impl
+  permission java.security.SecurityPermission "org.apache.xml.security.register";
+
+  // needed for multiple server implementations used in tests
+  permission java.net.SocketPermission "*", "accept,connect";
+
+  // needed for Kerberos login
+  permission javax.security.auth.AuthPermission "modifyPrincipals";
+  permission javax.security.auth.AuthPermission "modifyPrivateCredentials";
+  permission javax.security.auth.PrivateCredentialPermission "javax.security.auth.kerberos.KerberosKey * \"*\"", "read";
+  permission javax.security.auth.PrivateCredentialPermission "javax.security.auth.kerberos.KeyTab * \"*\"", "read";
+  permission javax.security.auth.PrivateCredentialPermission "javax.security.auth.kerberos.KerberosTicket * \"*\"", "read";
+  permission javax.security.auth.AuthPermission "doAs";
+  permission javax.security.auth.kerberos.ServicePermission "*","initiate,accept";
+
+  permission java.util.PropertyPermission "javax.security.auth.useSubjectCredsOnly","write";
+  permission java.util.PropertyPermission "java.security.krb5.conf","write";
+  permission java.util.PropertyPermission "sun.security.krb5.debug","write";
+  permission java.util.PropertyPermission "java.security.debug","write";
+  permission java.util.PropertyPermission "sun.security.spnego.debug","write";
+
+  // needed for kerberos file permission tests to access user information
+  permission java.lang.RuntimePermission "accessUserInformation";
+  permission java.lang.RuntimePermission "getFileStoreAttributes";
+};
+
+grant codeBase "${codebase.netty-common}" {
+  // for reading the system-wide configuration for the backlog of established sockets
+  permission java.io.FilePermission "/proc/sys/net/core/somaxconn", "read";
+  // Netty gets and sets classloaders for some of its internal threads
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.lang.RuntimePermission "getClassLoader";
+};
+
+grant codeBase "${codebase.netty-transport}" {
+  // Netty NioEventLoop wants to change this, because of https://bugs.openjdk.java.net/browse/JDK-6427854
+  // the bug says it only happened rarely, and that its fixed, but apparently it still happens rarely!
+  permission java.util.PropertyPermission "sun.nio.ch.bugLevel", "write";
+};
+
+grant codeBase "${codebase.nimbus-jose-jwt-modified}" {
+  // for JSON serialization based on a shaded GSON dependency
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
+  permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
+};
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java
index 4eac3ddf85f1b..40acb9a32f1bc 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java
@@ -15,7 +15,6 @@
 import org.elasticsearch.action.ActionModule;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.bulk.IncrementalBulkService;
-import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
@@ -33,7 +32,6 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsModule;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
@@ -79,15 +77,12 @@
 import org.elasticsearch.xpack.core.security.SecurityExtension;
 import org.elasticsearch.xpack.core.security.SecurityField;
 import org.elasticsearch.xpack.core.security.action.ActionTypes;
-import org.elasticsearch.xpack.core.security.action.service.TokenInfo;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
 import org.elasticsearch.xpack.core.security.authc.Realm;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.file.FileRealmSettings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore;
 import org.elasticsearch.xpack.core.security.authc.support.CachingUsernamePasswordRealmSettings;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.authz.accesscontrol.IndicesAccessControl;
@@ -104,9 +99,6 @@
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.elasticsearch.xpack.security.authc.jwt.JwtRealm;
 import org.elasticsearch.xpack.security.authc.service.CachingServiceAccountTokenStore;
-import org.elasticsearch.xpack.security.authc.service.FileServiceAccountTokenStore;
-import org.elasticsearch.xpack.security.authc.service.IndexServiceAccountTokenStore;
-import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;
 import org.elasticsearch.xpack.security.operator.DefaultOperatorOnlyRegistry;
 import org.elasticsearch.xpack.security.operator.OperatorOnlyRegistry;
 import org.elasticsearch.xpack.security.operator.OperatorPrivileges;
@@ -165,34 +157,16 @@ public class SecurityTests extends ESTestCase {
     private TestUtils.UpdatableLicenseState licenseState;

     public static class DummyExtension implements SecurityExtension {
-        private final String realmType;
-        private final ServiceAccountTokenStore serviceAccountTokenStore;
-        private final String extensionName;
+        private String realmType;

         DummyExtension(String realmType) {
-            this(realmType, "DummyExtension", null);
-        }
-
-        DummyExtension(String realmType, String extensionName, @Nullable ServiceAccountTokenStore serviceAccountTokenStore) {
             this.realmType = realmType;
-            this.extensionName = extensionName;
-            this.serviceAccountTokenStore = serviceAccountTokenStore;
-        }
-
-        @Override
-        public String extensionName() {
-            return extensionName;
         }

         @Override
         public Map<String, Realm.Factory> getRealms(SecurityComponents components) {
             return Collections.singletonMap(realmType, config -> null);
         }
-
-        @Override
-        public ServiceAccountTokenStore getServiceAccountTokenStore(SecurityComponents components) {
-            return serviceAccountTokenStore;
-        }
     }

     public static class DummyOperatorOnlyRegistry implements OperatorOnlyRegistry {
@@ -292,7 +266,7 @@ public void testCustomRealmExtension() throws Exception {
         assertNotNull(realms.realmFactory("myrealm"));
     }

-    public void testCustomRealmExtensionConflict() {
+    public void testCustomRealmExtensionConflict() throws Exception {
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
             () -> createComponents(Settings.EMPTY, new DummyExtension(FileRealmSettings.TYPE))
@@ -300,64 +274,6 @@ public void testCustomRealmExtensionConflict() {
         assertEquals("Realm type [" + FileRealmSettings.TYPE + "] is already registered", e.getMessage());
     }

-    public void testServiceAccountTokenStoreExtensionSuccess() throws Exception {
-        Collection<Object> components = createComponents(
-            Settings.EMPTY,
-            new DummyExtension(
-                "test_realm",
-                "DummyExtension",
-                (token, listener) -> listener.onResponse(
-                    ServiceAccountTokenStore.StoreAuthenticationResult.successful(TokenInfo.TokenSource.FILE)
-                )
-            )
-        );
-        ServiceAccountService serviceAccountService = findComponent(ServiceAccountService.class, components);
-        assertNotNull(serviceAccountService);
-        FileServiceAccountTokenStore fileServiceAccountTokenStore = findComponent(FileServiceAccountTokenStore.class, components);
-        assertNull(fileServiceAccountTokenStore);
-        IndexServiceAccountTokenStore indexServiceAccountTokenStore = findComponent(IndexServiceAccountTokenStore.class, components);
-        assertNull(indexServiceAccountTokenStore);
-        var account = randomFrom(ServiceAccountService.getServiceAccounts().values());
-        assertThrows(IllegalStateException.class, () -> serviceAccountService.createIndexToken(null, null, null));
-        var future = new PlainActionFuture<Authentication>();
-        serviceAccountService.authenticateToken(ServiceAccountToken.newToken(account.id(), "test"), "test", future);
-        assertTrue(future.get().isServiceAccount());
-    }
-
-    public void testSeveralServiceAccountTokenStoreExtensionFail() {
-        IllegalStateException exception = assertThrows(
-            IllegalStateException.class,
-            () -> createComponents(
-                Settings.EMPTY,
-                new DummyExtension(
-                    "test_realm_1",
-                    "DummyExtension1",
-                    (token, listener) -> listener.onResponse(
-                        ServiceAccountTokenStore.StoreAuthenticationResult.successful(TokenInfo.TokenSource.FILE)
-                    )
-                ),
-                new DummyExtension(
-                    "test_realm_2",
-                    "DummyExtension2",
-                    (token, listener) -> listener.onResponse(
-                        ServiceAccountTokenStore.StoreAuthenticationResult.successful(TokenInfo.TokenSource.FILE)
-                    )
-                )
-            )
-        );
-        assertThat(exception.getMessage(), containsString("More than one extension provided a ServiceAccountTokenStore override: "));
-    }
-
-    public void testNoServiceAccountTokenStoreExtension() throws Exception {
-        Collection<Object> components = createComponents(Settings.EMPTY);
-        ServiceAccountService serviceAccountService = findComponent(ServiceAccountService.class, components);
-        assertNotNull(serviceAccountService);
-        FileServiceAccountTokenStore fileServiceAccountTokenStore = findComponent(FileServiceAccountTokenStore.class, components);
-        assertNotNull(fileServiceAccountTokenStore);
-        IndexServiceAccountTokenStore indexServiceAccountTokenStore = findComponent(IndexServiceAccountTokenStore.class, components);
-        assertNotNull(indexServiceAccountTokenStore);
-    }
-
     public void testAuditEnabled() throws Exception {
         Settings settings = Settings.builder().put(XPackSettings.AUDIT_ENABLED.getKey(), true).build();
         Collection<Object> components = createComponents(settings);
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java
index 191b4840d733a..638b1be718040 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java
@@ -102,9 +102,7 @@
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
 import org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountSettings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.authc.support.mapper.TemplateRoleName;
 import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.ExpressionModel;
 import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.RoleMapperExpression;
@@ -126,6 +124,8 @@
 import org.elasticsearch.xpack.security.audit.AuditTrail;
 import org.elasticsearch.xpack.security.audit.AuditUtil;
 import org.elasticsearch.xpack.security.authc.ApiKeyService;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.rest.RemoteHostHeader;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.elasticsearch.xpack.security.support.SecurityIndexManager;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
index 1b5bbd4de9a44..4bf840d281a8c 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
@@ -81,7 +81,6 @@
 import org.elasticsearch.xpack.core.security.authc.RealmDomain;
 import org.elasticsearch.xpack.core.security.authc.esnative.NativeRealmSettings;
 import org.elasticsearch.xpack.core.security.authc.file.FileRealmSettings;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
@@ -99,6 +98,7 @@
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.elasticsearch.xpack.security.authc.file.FileRealm;
 import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.operator.OperatorPrivileges;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.elasticsearch.xpack.security.support.SecurityIndexManager;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticatorChainTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticatorChainTests.java
index 363593b83c9c4..4517b639b7604 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticatorChainTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticatorChainTests.java
@@ -28,12 +28,12 @@
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
 import org.elasticsearch.xpack.core.security.authc.Realm;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
 import org.elasticsearch.xpack.core.security.authc.support.BearerToken;
 import org.elasticsearch.xpack.core.security.user.AnonymousUser;
 import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.security.authc.ApiKeyService.ApiKeyCredentials;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.operator.OperatorPrivileges.OperatorPrivilegesService;
 import org.junit.Before;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticatorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticatorTests.java
index 5015c17380ff4..a16e8aecc5c67 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticatorTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ServiceAccountAuthenticatorTests.java
@@ -16,10 +16,10 @@
 import org.elasticsearch.telemetry.TestTelemetryPlugin;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.core.security.user.User;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.security.authc.service.ServiceAccountService;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
 import org.elasticsearch.xpack.security.metric.SecurityMetricType;

 import java.util.Map;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStoreTests.java
index c6f9a66d876db..c09c7a1479131 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStoreTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/CachingServiceAccountTokenStoreTests.java
@@ -17,10 +17,9 @@
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore.StoreAuthenticationResult;
 import org.elasticsearch.xpack.core.security.support.ValidationTests;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountTokenStore.StoreAuthenticationResult;
 import org.junit.After;
 import org.junit.Before;

@@ -35,7 +34,6 @@
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;

 public class CachingServiceAccountTokenStoreTests extends ESTestCase {

@@ -55,22 +53,14 @@ public void stop() {
         }
     }

-    private ServiceAccountToken newMockServiceAccountToken(ServiceAccountId accountId, String tokenName, SecureString secret) {
-        ServiceAccountToken serviceAccountToken = mock(ServiceAccountToken.class);
-        var serviceAccountTokenId = new ServiceAccountToken.ServiceAccountTokenId(accountId, tokenName);
-        when(serviceAccountToken.getQualifiedName()).thenReturn(serviceAccountTokenId.getQualifiedName());
-        when(serviceAccountToken.getSecret()).thenReturn(secret);
-        return serviceAccountToken;
-    }
-
     public void testCache() throws ExecutionException, InterruptedException {
         final ServiceAccountId accountId = new ServiceAccountId(randomAlphaOfLengthBetween(3, 8), randomAlphaOfLengthBetween(3, 8));
         final SecureString validSecret = new SecureString("super-secret-value".toCharArray());
         final SecureString invalidSecret = new SecureString("some-fishy-value".toCharArray());
-        final ServiceAccountToken token1Valid = newMockServiceAccountToken(accountId, "token1", validSecret);
-        final ServiceAccountToken token1Invalid = newMockServiceAccountToken(accountId, "token1", invalidSecret);
-        final ServiceAccountToken token2Valid = newMockServiceAccountToken(accountId, "token2", validSecret);
-        final ServiceAccountToken token2Invalid = newMockServiceAccountToken(accountId, "token2", invalidSecret);
+        final ServiceAccountToken token1Valid = new ServiceAccountToken(accountId, "token1", validSecret);
+        final ServiceAccountToken token1Invalid = new ServiceAccountToken(accountId, "token1", invalidSecret);
+        final ServiceAccountToken token2Valid = new ServiceAccountToken(accountId, "token2", validSecret);
+        final ServiceAccountToken token2Invalid = new ServiceAccountToken(accountId, "token2", invalidSecret);
         final AtomicBoolean doAuthenticateInvoked = new AtomicBoolean(false);
         final TokenSource tokenSource = randomFrom(TokenSource.values());

@@ -78,7 +68,7 @@ public void testCache() throws ExecutionException, InterruptedException {
             @Override
             void doAuthenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
                 doAuthenticateInvoked.set(true);
-                listener.onResponse(StoreAuthenticationResult.fromBooleanResult(getTokenSource(), validSecret.equals(token.getSecret())));
+                listener.onResponse(new StoreAuthenticationResult(validSecret.equals(token.getSecret()), getTokenSource()));
             }

             @Override
@@ -170,7 +160,7 @@ public void testCacheCanBeDisabled() throws ExecutionException, InterruptedExcep
         final CachingServiceAccountTokenStore store = new CachingServiceAccountTokenStore(settings, threadPool) {
             @Override
             void doAuthenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
-                listener.onResponse(StoreAuthenticationResult.fromBooleanResult(getTokenSource(), success));
+                listener.onResponse(new StoreAuthenticationResult(success, getTokenSource()));
             }

             @Override
@@ -191,7 +181,7 @@ public void testCacheInvalidateByKeys() {
         final CachingServiceAccountTokenStore store = new CachingServiceAccountTokenStore(globalSettings, threadPool) {
             @Override
             void doAuthenticate(ServiceAccountToken token, ActionListener<StoreAuthenticationResult> listener) {
-                listener.onResponse(StoreAuthenticationResult.successful(getTokenSource()));
+                listener.onResponse(new StoreAuthenticationResult(true, getTokenSource()));
             }

             @Override
listener.onResponse(StoreAuthenticationResult.fromBooleanResult(tokenSource, store3Success)); + listener.onResponse(new StoreAuthenticationResult(store3Success, tokenSource)); return null; }).when(store3).authenticate(eq(token), any()); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java index 5ccefb8fbe134..21e29469bb02b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java @@ -56,7 +56,6 @@ import org.elasticsearch.xpack.core.security.action.user.PutUserAction; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper; -import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissionsCache; import org.elasticsearch.xpack.core.security.authz.permission.Role; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java index e95a3e66ce04f..0f2a720660afd 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java @@ -21,8 +21,8 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.security.action.service.TokenInfo; import org.elasticsearch.xpack.core.security.audit.logfile.CapturingLogger; -import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId; import org.elasticsearch.xpack.core.security.authc.support.Hasher; +import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId; import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry; import org.junit.After; import org.junit.Before; @@ -238,7 +238,7 @@ public void testFindTokensFor() throws IOException { ); final ServiceAccountId accountId = new ServiceAccountId("elastic", "fleet-server"); - final List tokenInfos = store.findNodeLocalTokensFor(accountId); + final List tokenInfos = store.findTokensFor(accountId); assertThat(tokenInfos, hasSize(5)); assertThat( tokenInfos.stream().map(TokenInfo::getName).collect(Collectors.toUnmodifiableSet()), diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java index 17249c9900ac9..ec4b0059ed8a1 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java @@ -55,11 +55,10 @@ import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource; import 
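A note for readers unfamiliar with the Mockito pattern used throughout these store tests: the stores report results through an ActionListener callback rather than a return value, so the tests stub authenticate(...) with doAnswer and complete the captured listener by hand. A minimal, self-contained sketch of the same idiom, using hypothetical Listener and TokenStore types in place of the real Elasticsearch interfaces:

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

public class ListenerStubExample {
    // Hypothetical stand-ins for ActionListener and a token store, for illustration only.
    interface Listener<T> { void onResponse(T result); }
    interface TokenStore { void authenticate(String token, Listener<Boolean> listener); }

    public static void main(String[] args) {
        TokenStore store = mock(TokenStore.class);
        // Complete the callback manually, exactly as the doAnswer(...) blocks above do:
        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            Listener<Boolean> listener = (Listener<Boolean>) invocation.getArguments()[1];
            listener.onResponse(true); // simulate a successful authentication
            return null;
        }).when(store).authenticate(any(), any());

        store.authenticate("token", result -> System.out.println("authenticated: " + result));
    }
}

The cast through the raw getArguments() array is what forces the @SuppressWarnings("unchecked") annotations seen in the diff.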
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java
index 5ccefb8fbe134..21e29469bb02b 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ElasticServiceAccountsTests.java
@@ -56,7 +56,6 @@
 import org.elasticsearch.xpack.core.security.action.user.PutUserAction;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount;
 import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissionsCache;
 import org.elasticsearch.xpack.core.security.authz.permission.Role;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java
index e95a3e66ce04f..0f2a720660afd 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/FileServiceAccountTokenStoreTests.java
@@ -21,8 +21,8 @@
 import org.elasticsearch.watcher.ResourceWatcherService;
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo;
 import org.elasticsearch.xpack.core.security.audit.logfile.CapturingLogger;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.junit.After;
 import org.junit.Before;
@@ -238,7 +238,7 @@ public void testFindTokensFor() throws IOException {
         );
         final ServiceAccountId accountId = new ServiceAccountId("elastic", "fleet-server");
-        final List<TokenInfo> tokenInfos = store.findNodeLocalTokensFor(accountId);
+        final List<TokenInfo> tokenInfos = store.findTokensFor(accountId);
         assertThat(tokenInfos, hasSize(5));
         assertThat(
             tokenInfos.stream().map(TokenInfo::getName).collect(Collectors.toUnmodifiableSet()),
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java
index 17249c9900ac9..ec4b0059ed8a1 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/IndexServiceAccountTokenStoreTests.java
@@ -55,11 +55,10 @@
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo.TokenSource;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore.StoreAuthenticationResult;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.support.ValidationTests;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountTokenStore.StoreAuthenticationResult;
 import org.elasticsearch.xpack.security.support.CacheInvalidatorRegistry;
 import org.elasticsearch.xpack.security.support.SecurityIndexManager;
 import org.junit.Before;
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIdTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIdTests.java
index aa5baefa894a3..ed6ac0f6de435 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIdTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountIdTests.java
@@ -10,7 +10,6 @@
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount;
 
 import java.io.IOException;
 
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountServiceTests.java
index 44eb42d08ef59..43fe57dd8b313 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountServiceTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountServiceTests.java
@@ -33,12 +33,10 @@
 import org.elasticsearch.xpack.core.security.action.service.TokenInfo;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountTokenStore;
 import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.core.security.support.ValidationTests;
 import org.elasticsearch.xpack.core.security.user.User;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
 
 import org.junit.After;
 import org.junit.Before;
@@ -85,14 +83,7 @@ public void init() throws UnknownHostException {
         when(indexServiceAccountTokenStore.getTokenSource()).thenReturn(TokenInfo.TokenSource.INDEX);
         client = mock(Client.class);
         when(client.threadPool()).thenReturn(threadPool);
-        serviceAccountService = new ServiceAccountService(
-            client,
-            new CompositeServiceAccountTokenStore(
-                List.of(fileServiceAccountTokenStore, indexServiceAccountTokenStore),
-                threadPool.getThreadContext()
-            ),
-            indexServiceAccountTokenStore
-        );
+        serviceAccountService = new ServiceAccountService(client, fileServiceAccountTokenStore, indexServiceAccountTokenStore);
     }
 
     @After
@@ -237,15 +228,16 @@ public void testTryParseToken() throws IOException {
             List.of(magicBytes, (namespace + "/" + serviceName + "/" + tokenName + ":" + secret).getBytes(StandardCharsets.UTF_8))
         );
         final ServiceAccountToken serviceAccountToken1 = ServiceAccountService.tryParseToken(bearerString5);
-
-        assertNotNull(serviceAccountToken1);
-        assertThat(serviceAccountToken1.getAccountId(), equalTo(accountId));
-        assertThat(serviceAccountToken1.getTokenName(), equalTo(tokenName));
-        assertThat(serviceAccountToken1.getSecret(), equalTo(new SecureString(secret.toCharArray())));
+        final ServiceAccountToken serviceAccountToken2 = new ServiceAccountToken(
+            accountId,
+            tokenName,
+            new SecureString(secret.toCharArray())
+        );
+        assertThat(serviceAccountToken1, equalTo(serviceAccountToken2));
 
         // Serialise and de-serialise service account token
-        final ServiceAccountToken parsedToken = ServiceAccountService.tryParseToken(serviceAccountToken1.asBearerString());
-        assertThat(parsedToken, equalTo(serviceAccountToken1));
+        final ServiceAccountToken parsedToken = ServiceAccountService.tryParseToken(serviceAccountToken2.asBearerString());
+        assertThat(parsedToken, equalTo(serviceAccountToken2));
 
         // Invalid magic byte
         satMockLog.addExpectation(
@@ -303,31 +295,25 @@ public void testTryParseToken() throws IOException {
             );
             sasMockLog.assertAllExpectationsMatched();
 
-            ServiceAccountToken parsedServiceAccountToken = ServiceAccountService.tryParseToken(
-                new SecureString("AAEAAWVsYXN0aWMvZmxlZXQtc2VydmVyL3Rva2VuMTpzdXBlcnNlY3JldA".toCharArray())
-            );
-            // everything is fine
-            assertNotNull(parsedServiceAccountToken);
-            assertThat(parsedServiceAccountToken.getAccountId(), equalTo(new ServiceAccountId("elastic", "fleet-server")));
-            assertThat(parsedServiceAccountToken.getTokenName(), equalTo("token1"));
-            assertThat(parsedServiceAccountToken.getSecret(), equalTo(new SecureString("supersecret".toCharArray())));
+            assertThat(
+                ServiceAccountService.tryParseToken(
+                    new SecureString("AAEAAWVsYXN0aWMvZmxlZXQtc2VydmVyL3Rva2VuMTpzdXBlcnNlY3JldA".toCharArray())
+                ),
+                equalTo(
+                    new ServiceAccountToken(
+                        new ServiceAccountId("elastic", "fleet-server"),
+                        "token1",
+                        new SecureString("supersecret".toCharArray())
+                    )
+                )
+            );
         } finally {
             Loggers.setLevel(satLogger, Level.INFO);
             Loggers.setLevel(sasLogger, Level.INFO);
         }
     }
 
-    private ServiceAccountToken newMockServiceAccountToken(ServiceAccountId accountId, String tokenName, SecureString secret) {
-        ServiceAccountToken serviceAccountToken = mock(ServiceAccountToken.class);
-        var serviceAccountTokenId = new ServiceAccountToken.ServiceAccountTokenId(accountId, tokenName);
-        when(serviceAccountToken.getQualifiedName()).thenReturn(serviceAccountTokenId.getQualifiedName());
-        when(serviceAccountToken.getSecret()).thenReturn(secret);
-        when(serviceAccountToken.getAccountId()).thenReturn(accountId);
-        when(serviceAccountToken.getTokenName()).thenReturn(tokenName);
-        return serviceAccountToken;
-    }
-
     public void testTryAuthenticateBearerToken() throws ExecutionException, InterruptedException {
         // Valid token
         final PlainActionFuture<Authentication> future5 = new PlainActionFuture<>();
@@ -339,10 +325,7 @@ public void testTryAuthenticateBearerToken() throws ExecutionException, Interrup
             final ActionListener<ServiceAccountTokenStore.StoreAuthenticationResult> listener = (ActionListener<
                 ServiceAccountTokenStore.StoreAuthenticationResult>) invocationOnMock.getArguments()[1];
             listener.onResponse(
-                ServiceAccountTokenStore.StoreAuthenticationResult.fromBooleanResult(
-                    store.getTokenSource(),
-                    store == authenticatingStore
-                )
+                new ServiceAccountTokenStore.StoreAuthenticationResult(store == authenticatingStore, store.getTokenSource())
             );
             return null;
         }).when(store).authenticate(any(), any());
@@ -350,7 +333,7 @@ public void testTryAuthenticateBearerToken() throws ExecutionException, Interrup
 
         final String nodeName = randomAlphaOfLengthBetween(3, 8);
         serviceAccountService.authenticateToken(
-            newMockServiceAccountToken(
+            new ServiceAccountToken(
                 new ServiceAccountId("elastic", "fleet-server"),
                 "token1",
                 new SecureString("super-secret-value".toCharArray())
@@ -396,7 +379,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
             )
         );
         final SecureString secret = new SecureString(randomAlphaOfLength(20).toCharArray());
-        final ServiceAccountToken token1 = newMockServiceAccountToken(accountId1, randomAlphaOfLengthBetween(3, 8), secret);
+        final ServiceAccountToken token1 = new ServiceAccountToken(accountId1, randomAlphaOfLengthBetween(3, 8), secret);
         final PlainActionFuture<Authentication> future1 = new PlainActionFuture<>();
         serviceAccountService.authenticateToken(token1, randomAlphaOfLengthBetween(3, 8), future1);
         final ExecutionException e1 = expectThrows(ExecutionException.class, future1::get);
@@ -426,7 +409,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
                 "the [" + accountId2.asPrincipal() + "] service account does not exist"
             )
         );
-        final ServiceAccountToken token2 = newMockServiceAccountToken(accountId2, randomAlphaOfLengthBetween(3, 8), secret);
+        final ServiceAccountToken token2 = new ServiceAccountToken(accountId2, randomAlphaOfLengthBetween(3, 8), secret);
         final PlainActionFuture<Authentication> future2 = new PlainActionFuture<>();
         serviceAccountService.authenticateToken(token2, randomAlphaOfLengthBetween(3, 8), future2);
         final ExecutionException e2 = expectThrows(ExecutionException.class, future2::get);
@@ -446,7 +429,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
         // Length of secret value is too short
         final ServiceAccountId accountId3 = new ServiceAccountId(ElasticServiceAccounts.NAMESPACE, "fleet-server");
         final SecureString secret3 = new SecureString(randomAlphaOfLengthBetween(1, 9).toCharArray());
-        final ServiceAccountToken token3 = newMockServiceAccountToken(accountId3, randomAlphaOfLengthBetween(3, 8), secret3);
+        final ServiceAccountToken token3 = new ServiceAccountToken(accountId3, randomAlphaOfLengthBetween(3, 8), secret3);
         mockLog.addExpectation(
             new MockLog.SeenEventExpectation(
                 "secret value too short",
@@ -473,7 +456,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
         );
         mockLog.assertAllExpectationsMatched();
 
-        final TokenInfo.TokenSource tokenSource = randomFrom(TokenInfo.TokenSource.FILE, TokenInfo.TokenSource.INDEX);
+        final TokenInfo.TokenSource tokenSource = randomFrom(TokenInfo.TokenSource.values());
         final CachingServiceAccountTokenStore store;
         final CachingServiceAccountTokenStore otherStore;
         if (tokenSource == TokenInfo.TokenSource.FILE) {
@@ -486,8 +469,8 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
 
         // Success based on credential store
         final ServiceAccountId accountId4 = new ServiceAccountId(ElasticServiceAccounts.NAMESPACE, "fleet-server");
-        final ServiceAccountToken token4 = newMockServiceAccountToken(accountId4, randomAlphaOfLengthBetween(3, 8), secret);
-        final ServiceAccountToken token5 = newMockServiceAccountToken(
+        final ServiceAccountToken token4 = new ServiceAccountToken(accountId4, randomAlphaOfLengthBetween(3, 8), secret);
+        final ServiceAccountToken token5 = new ServiceAccountToken(
             accountId4,
             randomAlphaOfLengthBetween(3, 8),
             new SecureString(randomAlphaOfLength(20).toCharArray())
@@ -497,7 +480,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
             @SuppressWarnings("unchecked")
             final ActionListener<ServiceAccountTokenStore.StoreAuthenticationResult> listener = (ActionListener<
                 ServiceAccountTokenStore.StoreAuthenticationResult>) invocationOnMock.getArguments()[1];
-            listener.onResponse(ServiceAccountTokenStore.StoreAuthenticationResult.successful(store.getTokenSource()));
+            listener.onResponse(new ServiceAccountTokenStore.StoreAuthenticationResult(true, store.getTokenSource()));
             return null;
         }).when(store).authenticate(eq(token4), any());
@@ -505,7 +488,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
             @SuppressWarnings("unchecked")
             final ActionListener<ServiceAccountTokenStore.StoreAuthenticationResult> listener = (ActionListener<
                 ServiceAccountTokenStore.StoreAuthenticationResult>) invocationOnMock.getArguments()[1];
-            listener.onResponse(ServiceAccountTokenStore.StoreAuthenticationResult.failed(store.getTokenSource()));
+            listener.onResponse(new ServiceAccountTokenStore.StoreAuthenticationResult(false, store.getTokenSource()));
             return null;
         }).when(store).authenticate(eq(token5), any());
@@ -513,7 +496,7 @@ public void testAuthenticateWithToken() throws ExecutionException, InterruptedEx
             @SuppressWarnings("unchecked")
             final ActionListener<ServiceAccountTokenStore.StoreAuthenticationResult> listener = (ActionListener<
                 ServiceAccountTokenStore.StoreAuthenticationResult>) invocationOnMock.getArguments()[1];
-            listener.onResponse(ServiceAccountTokenStore.StoreAuthenticationResult.failed(otherStore.getTokenSource()));
+            listener.onResponse(new ServiceAccountTokenStore.StoreAuthenticationResult(false, otherStore.getTokenSource()));
             return null;
         }).when(otherStore).authenticate(any(), any());
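The fixture string asserted in testTryParseToken above decodes to the token layout the tests exercise: a short magic-byte prefix followed by "namespace/service-name/token-name:secret", base64-encoded without padding. A sketch that reproduces the fixture; note the magic-byte values here are simply read off the decoded fixture, not taken from the production code, and the production encoder may differ in flavor (for this fixture, url-safe and standard base64 produce identical output):

import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class BearerStringExample {
    public static void main(String[] args) {
        // Prefix bytes as observed by decoding the test fixture.
        byte[] magic = new byte[] { 0x00, 0x01, 0x00, 0x01 };
        byte[] body = "elastic/fleet-server/token1:supersecret".getBytes(StandardCharsets.UTF_8);

        byte[] raw = new byte[magic.length + body.length];
        System.arraycopy(magic, 0, raw, 0, magic.length);
        System.arraycopy(body, 0, raw, magic.length, body.length);

        // Unpadded base64, matching the fixture in testTryParseToken.
        String bearer = Base64.getUrlEncoder().withoutPadding().encodeToString(raw);
        System.out.println(bearer); // AAEAAWVsYXN0aWMvZmxlZXQtc2VydmVyL3Rva2VuMTpzdXBlcnNlY3JldA
    }
}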
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenTests.java
similarity index 96%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenTests.java
rename to x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenTests.java
index dbece9ca5b4a4..6c8c625c0ceea 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/service/ServiceAccountTokenTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/service/ServiceAccountTokenTests.java
@@ -5,13 +5,13 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.core.security.authc.service;
+package org.elasticsearch.xpack.security.authc.service;
 
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccount.ServiceAccountId;
 import org.elasticsearch.xpack.core.security.support.Validation;
 import org.elasticsearch.xpack.core.security.support.ValidationTests;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccount.ServiceAccountId;
 
 import java.io.IOException;
 
diff --git a/x-pack/plugin/sql/jdbc/src/test/resources/plugin-security.policy b/x-pack/plugin/sql/jdbc/src/test/resources/plugin-security.policy
new file mode 100644
index 0000000000000..577795ffb7842
--- /dev/null
+++ b/x-pack/plugin/sql/jdbc/src/test/resources/plugin-security.policy
@@ -0,0 +1,6 @@
+grant {
+  // Required for testing the Driver registration
+  permission java.sql.SQLPermission "deregisterDriver";
+  // Required for debug logging purposes
+  permission java.sql.SQLPermission "setLog";
+};
diff --git a/x-pack/plugin/sql/qa/jdbc/security/src/test/resources/plugin-security.policy b/x-pack/plugin/sql/qa/jdbc/security/src/test/resources/plugin-security.policy
new file mode 100644
index 0000000000000..434fdee0a8d20
--- /dev/null
+++ b/x-pack/plugin/sql/qa/jdbc/security/src/test/resources/plugin-security.policy
@@ -0,0 +1,9 @@
+grant {
+  // Needed to read the audit log file
+  permission java.io.FilePermission "${tests.audit.logfile}", "read";
+  permission java.io.FilePermission "${tests.audit.yesterday.logfile}", "read";
+
+  //// Required by ssl subproject:
+  // Required for the net client to setup ssl rather than use global ssl.
+  permission java.lang.RuntimePermission "setFactory";
+};
diff --git a/x-pack/plugin/sql/qa/jdbc/src/main/resources/plugin-security.policy b/x-pack/plugin/sql/qa/jdbc/src/main/resources/plugin-security.policy
new file mode 100644
index 0000000000000..bb58eb4270ddf
--- /dev/null
+++ b/x-pack/plugin/sql/qa/jdbc/src/main/resources/plugin-security.policy
@@ -0,0 +1,4 @@
+grant {
+  // Policy is required for tests to connect to testing Elasticsearch instances.
+  permission java.net.SocketPermission "*", "connect,resolve";
+};
diff --git a/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/SqlSecurityTestCase.java b/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/SqlSecurityTestCase.java
index 75cfe5cca64bd..e30934050bfb9 100644
--- a/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/SqlSecurityTestCase.java
+++ b/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/SqlSecurityTestCase.java
@@ -98,7 +98,10 @@ protected interface Actions {
     protected static final String SQL_ACTION_NAME = "indices:data/read/sql";
 
     /**
-     * Location of the audit log file.
+     * Location of the audit log file. We could technically figure this out by reading the admin
+     * APIs but it isn't worth doing because we also have to give ourselves permission to read
+     * the file and that must be done by setting a system property and reading it in
+     * {@code plugin-security.policy}. So we may as well have gradle set the property.
      */
     private static final Path AUDIT_LOG_FILE = lookupAuditLog();
     private static final Path ROLLED_OVER_AUDIT_LOG_FILE = lookupRolledOverAuditLog();
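The javadoc above describes a handshake: gradle sets a system property holding the audit log's path, and plugin-security.policy (see the files added in this patch) grants read access to the expanded "${tests.audit.logfile}" location. A minimal sketch of the consuming side, with hypothetical naming, assuming the build sets the same property name the policy expands:

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class AuditLogLookupExample {
    public static void main(String[] args) throws Exception {
        // The build is expected to set this property; the name mirrors the
        // placeholder expanded in the plugin-security.policy files above.
        String location = System.getProperty("tests.audit.logfile");
        if (location == null) {
            throw new IllegalStateException("tests.audit.logfile must be set by the build");
        }
        Path auditLog = Path.of(location);
        // Under a security manager this read succeeds only because the policy
        // grants java.io.FilePermission on exactly this path.
        List<String> events = Files.readAllLines(auditLog);
        events.forEach(System.out::println);
    }
}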
diff --git a/x-pack/plugin/sql/qa/server/security/src/test/resources/plugin-security.policy b/x-pack/plugin/sql/qa/server/security/src/test/resources/plugin-security.policy
new file mode 100644
index 0000000000000..434fdee0a8d20
--- /dev/null
+++ b/x-pack/plugin/sql/qa/server/security/src/test/resources/plugin-security.policy
@@ -0,0 +1,9 @@
+grant {
+  // Needed to read the audit log file
+  permission java.io.FilePermission "${tests.audit.logfile}", "read";
+  permission java.io.FilePermission "${tests.audit.yesterday.logfile}", "read";
+
+  //// Required by ssl subproject:
+  // Required for the net client to setup ssl rather than use global ssl.
+  permission java.lang.RuntimePermission "setFactory";
+};
diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/plugin-security.policy b/x-pack/plugin/sql/qa/server/src/main/resources/plugin-security.policy
new file mode 100644
index 0000000000000..bb58eb4270ddf
--- /dev/null
+++ b/x-pack/plugin/sql/qa/server/src/main/resources/plugin-security.policy
@@ -0,0 +1,4 @@
+grant {
+  // Policy is required for tests to connect to testing Elasticsearch instances.
+  permission java.net.SocketPermission "*", "connect,resolve";
+};
diff --git a/x-pack/plugin/sql/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/sql/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/x-pack/plugin/watcher/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/watcher/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..d27ded771b86f
--- /dev/null
+++ b/x-pack/plugin/watcher/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,15 @@
+grant {
+  // required to configure the custom mailcap for watcher
+  permission java.lang.RuntimePermission "setFactory";
+
+  // needed when sending emails for javax.activation
+  // otherwise a classnotfound exception is thrown due to trying
+  // to load the class with the application class loader
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.lang.RuntimePermission "getClassLoader";
+  // TODO: remove use of this jar as soon as possible!!!!
+  permission java.lang.RuntimePermission "accessClassInPackage.com.sun.activation.registries";
+
+  // needed for multiple server implementations used in tests
+  permission java.net.SocketPermission "*", "accept,connect";
+};
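The setContextClassLoader grant above exists because javax.activation locates its handlers through the thread context class loader; under a plugin class loader that lookup can fail with a ClassNotFoundException unless the thread is temporarily switched over, as the policy comments note. A hedged sketch of the swap-and-restore idiom this implies (helper name and structure are hypothetical, not the watcher implementation):

import java.util.concurrent.Callable;

public class ContextClassLoaderExample {
    // Run an action with the plugin's class loader installed as the thread
    // context class loader, restoring the original loader afterwards.
    static <T> T withPluginClassLoader(ClassLoader pluginLoader, Callable<T> action) throws Exception {
        Thread current = Thread.currentThread();
        ClassLoader original = current.getContextClassLoader();
        current.setContextClassLoader(pluginLoader); // requires RuntimePermission "setContextClassLoader"
        try {
            return action.call();
        } finally {
            current.setContextClassLoader(original);
        }
    }
}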
diff --git a/x-pack/qa/kerberos-tests/src/javaRestTest/resources/plugin-security.policy b/x-pack/qa/kerberos-tests/src/javaRestTest/resources/plugin-security.policy
new file mode 100644
index 0000000000000..84219494bf2ce
--- /dev/null
+++ b/x-pack/qa/kerberos-tests/src/javaRestTest/resources/plugin-security.policy
@@ -0,0 +1,7 @@
+grant {
+  permission javax.security.auth.AuthPermission "doAsPrivileged";
+  permission javax.security.auth.kerberos.DelegationPermission "\"HTTP/localhost@BUILD.ELASTIC.CO\" \"krbtgt/BUILD.ELASTIC.CO@BUILD.ELASTIC.CO\"";
+  permission javax.security.auth.kerberos.DelegationPermission "\"HTTP/localhost.localdomain@BUILD.ELASTIC.CO\" \"krbtgt/BUILD.ELASTIC.CO@BUILD.ELASTIC.CO\"";
+  permission javax.security.auth.kerberos.DelegationPermission "\"HTTP/localhost4@BUILD.ELASTIC.CO\" \"krbtgt/BUILD.ELASTIC.CO@BUILD.ELASTIC.CO\"";
+  permission javax.security.auth.kerberos.DelegationPermission "\"HTTP/localhost4.localdomain4@BUILD.ELASTIC.CO\" \"krbtgt/BUILD.ELASTIC.CO@BUILD.ELASTIC.CO\"";
+};
\ No newline at end of file
diff --git a/x-pack/qa/multi-project/xpack-rest-tests-with-multiple-projects/src/yamlRestTest/java/org/elasticsearch/multiproject/test/XpackWithMultipleProjectsClientYamlTestSuiteIT.java b/x-pack/qa/multi-project/xpack-rest-tests-with-multiple-projects/src/yamlRestTest/java/org/elasticsearch/multiproject/test/XpackWithMultipleProjectsClientYamlTestSuiteIT.java
index 0977eaddd7e06..860e2ffc0690f 100644
--- a/x-pack/qa/multi-project/xpack-rest-tests-with-multiple-projects/src/yamlRestTest/java/org/elasticsearch/multiproject/test/XpackWithMultipleProjectsClientYamlTestSuiteIT.java
+++ b/x-pack/qa/multi-project/xpack-rest-tests-with-multiple-projects/src/yamlRestTest/java/org/elasticsearch/multiproject/test/XpackWithMultipleProjectsClientYamlTestSuiteIT.java
@@ -20,7 +20,7 @@
 import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
 import org.junit.ClassRule;
 
-@TimeoutSuite(millis = 60 * TimeUnits.MINUTE)
+@TimeoutSuite(millis = 30 * TimeUnits.MINUTE)
 public class XpackWithMultipleProjectsClientYamlTestSuiteIT extends MultipleProjectsClientYamlSuiteTestCase {
     @ClassRule
     public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
diff --git a/x-pack/qa/rolling-upgrade/build.gradle b/x-pack/qa/rolling-upgrade/build.gradle
index d9ab1723469f2..c184558d2353f 100644
--- a/x-pack/qa/rolling-upgrade/build.gradle
+++ b/x-pack/qa/rolling-upgrade/build.gradle
@@ -14,11 +14,9 @@ apply plugin: 'elasticsearch.bwc-test'
 apply plugin: 'elasticsearch.rest-resources'
 
 dependencies {
-  testImplementation testArtifact(project(':server'))
   testImplementation testArtifact(project(xpackModule('core')))
   testImplementation project(':x-pack:qa')
   testImplementation project(':modules:reindex')
-  testImplementation testArtifact(project(xpackModule('inference')))
 }
 
 restResources {
diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
deleted file mode 100644
index 81eaa9507b843..0000000000000
--- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java
+++ /dev/null
@@ -1,247 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.upgrades;
-
-import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
-
-import org.apache.lucene.search.join.ScoreMode;
-import org.elasticsearch.action.admin.indices.create.CreateIndexResponse;
-import org.elasticsearch.client.Request;
-import org.elasticsearch.client.RequestOptions;
-import org.elasticsearch.client.Response;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
-import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
-import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
-import org.elasticsearch.index.query.NestedQueryBuilder;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.inference.Model;
-import org.elasticsearch.inference.SimilarityMeasure;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
-import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
-import org.elasticsearch.test.rest.ObjectPath;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentFactory;
-import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
-import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
-import org.elasticsearch.xpack.inference.model.TestModel;
-import org.junit.BeforeClass;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
-import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.notNullValue;
-
-public class SemanticTextUpgradeIT extends AbstractUpgradeTestCase {
-    private static final String INDEX_BASE_NAME = "semantic_text_test_index";
-    private static final String SPARSE_FIELD = "sparse_field";
-    private static final String DENSE_FIELD = "dense_field";
-
-    private static final String DOC_1_ID = "doc_1";
-    private static final String DOC_2_ID = "doc_2";
-    private static final Map<String, List<String>> DOC_VALUES = Map.of(
-        DOC_1_ID,
-        List.of("a test value", "with multiple test values"),
-        DOC_2_ID,
-        List.of("another test value")
-    );
-
-    private static Model SPARSE_MODEL;
-    private static Model DENSE_MODEL;
-
-    private final boolean useLegacyFormat;
-
-    @BeforeClass
-    public static void beforeClass() {
-        SPARSE_MODEL = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
-        // Exclude dot product because we are not producing unit length vectors
-        DENSE_MODEL = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
-    }
-
-    public SemanticTextUpgradeIT(boolean useLegacyFormat) {
-        this.useLegacyFormat = useLegacyFormat;
-    }
-
-    @ParametersFactory
-    public static Iterable<Object[]> parameters() {
-        return List.of(new Object[] { true }, new Object[] { false });
-    }
-
-    public void testSemanticTextOperations() throws Exception {
-        switch (CLUSTER_TYPE) {
-            case OLD -> createAndPopulateIndex();
-            case MIXED, UPGRADED -> performIndexQueryHighlightOps();
-            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
-        }
-    }
-
-    private void createAndPopulateIndex() throws IOException {
-        final String indexName = getIndexName();
-        final String mapping = Strings.format("""
-            {
-              "properties": {
-                "%s": {
-                  "type": "semantic_text",
-                  "inference_id": "%s"
-                },
-                "%s": {
-                  "type": "semantic_text",
-                  "inference_id": "%s"
-                }
-              }
-            }
-            """, SPARSE_FIELD, SPARSE_MODEL.getInferenceEntityId(), DENSE_FIELD, DENSE_MODEL.getInferenceEntityId());
-
-        CreateIndexResponse response = createIndex(
-            indexName,
-            Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat).build(),
-            mapping
-        );
-        assertThat(response.isAcknowledged(), equalTo(true));
-
-        indexDoc(DOC_1_ID, DOC_VALUES.get(DOC_1_ID));
-    }
-
-    private void performIndexQueryHighlightOps() throws IOException {
-        indexDoc(DOC_2_ID, DOC_VALUES.get(DOC_2_ID));
-
-        ObjectPath sparseQueryObjectPath = semanticQuery(SPARSE_FIELD, SPARSE_MODEL, "test value", 3);
-        assertQueryResponseWithHighlights(sparseQueryObjectPath, SPARSE_FIELD);
-
-        ObjectPath denseQueryObjectPath = semanticQuery(DENSE_FIELD, DENSE_MODEL, "test value", 3);
-        assertQueryResponseWithHighlights(denseQueryObjectPath, DENSE_FIELD);
-    }
-
-    private String getIndexName() {
-        return INDEX_BASE_NAME + (useLegacyFormat ? "_legacy" : "_new");
-    }
-
-    private void indexDoc(String id, List<String> semanticTextFieldValue) throws IOException {
-        final String indexName = getIndexName();
-        final SemanticTextField sparseFieldValue = randomSemanticText(
-            useLegacyFormat,
-            SPARSE_FIELD,
-            SPARSE_MODEL,
-            null,
-            semanticTextFieldValue,
-            XContentType.JSON
-        );
-        final SemanticTextField denseFieldValue = randomSemanticText(
-            useLegacyFormat,
-            DENSE_FIELD,
-            DENSE_MODEL,
-            null,
-            semanticTextFieldValue,
-            XContentType.JSON
-        );
-
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        if (useLegacyFormat == false) {
-            builder.field(sparseFieldValue.fieldName(), semanticTextFieldValue);
-            builder.field(denseFieldValue.fieldName(), semanticTextFieldValue);
-        }
-        addSemanticTextInferenceResults(useLegacyFormat, builder, List.of(sparseFieldValue, denseFieldValue));
-        builder.endObject();
-
-        RequestOptions requestOptions = RequestOptions.DEFAULT.toBuilder().addParameter("refresh", "true").build();
-        Request request = new Request("POST", indexName + "/_doc/" + id);
-        request.setJsonEntity(Strings.toString(builder));
-        request.setOptions(requestOptions);
-
-        Response response = client().performRequest(request);
-        assertOK(response);
-    }
-
-    private ObjectPath semanticQuery(String field, Model fieldModel, String query, Integer numOfHighlightFragments) throws IOException {
-        // We can't perform a real semantic query because that requires performing inference, so instead we perform an equivalent nested
-        // query
-        final String embeddingsFieldName = SemanticTextField.getEmbeddingsFieldName(field);
-        final QueryBuilder innerQueryBuilder = switch (fieldModel.getTaskType()) {
-            case SPARSE_EMBEDDING -> {
-                List<WeightedToken> weightedTokens = Arrays.stream(query.split("\\s")).map(t -> new WeightedToken(t, 1.0f)).toList();
-                yield new SparseVectorQueryBuilder(embeddingsFieldName, weightedTokens, null, null, null, null);
-            }
-            case TEXT_EMBEDDING -> {
-                DenseVectorFieldMapper.ElementType elementType = fieldModel.getServiceSettings().elementType();
-                int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
-                    elementType,
-                    fieldModel.getServiceSettings().dimensions()
-                );
-
-                // Create a query vector with a value of 1 for each dimension, which will effectively act as a pass-through for the document
-                // vector
-                float[] queryVector = new float[embeddingLength];
-                if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
-                    Arrays.fill(queryVector, -128.0f);
-                } else {
-                    Arrays.fill(queryVector, 1.0f);
-                }
-
-                yield new KnnVectorQueryBuilder(embeddingsFieldName, queryVector, DOC_VALUES.size(), null, null, null);
-            }
-            default -> throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]");
-        };
-
-        NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(
-            SemanticTextField.getChunksFieldName(field),
-            innerQueryBuilder,
-            ScoreMode.Max
-        );
-
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        builder.field("query", nestedQueryBuilder);
-        if (numOfHighlightFragments != null) {
-            HighlightBuilder.Field highlightField = new HighlightBuilder.Field(field);
-            highlightField.numOfFragments(numOfHighlightFragments);
-
-            HighlightBuilder highlightBuilder = new HighlightBuilder();
-            highlightBuilder.field(highlightField);
-
-            builder.field("highlight", highlightBuilder);
-        }
-        builder.endObject();
-
-        Request request = new Request("GET", getIndexName() + "/_search");
-        request.setJsonEntity(Strings.toString(builder));
-
-        Response response = client().performRequest(request);
-        return assertOKAndCreateObjectPath(response);
-    }
-
-    private static void assertQueryResponseWithHighlights(ObjectPath queryObjectPath, String field) throws IOException {
-        assertThat(queryObjectPath.evaluate("hits.total.value"), equalTo(2));
-        assertThat(queryObjectPath.evaluateArraySize("hits.hits"), equalTo(2));
-
-        Set<String> docIds = new HashSet<>();
-        List<Map<String, Object>> hits = queryObjectPath.evaluate("hits.hits");
-        for (Map<String, Object> hit : hits) {
-            String id = ObjectPath.evaluate(hit, "_id");
-            assertThat(id, notNullValue());
-            docIds.add(id);
-
-            List<String> expectedHighlight = DOC_VALUES.get(id);
-            assertThat(expectedHighlight, notNullValue());
-            assertThat(ObjectPath.evaluate(hit, "highlight." + field), equalTo(expectedHighlight));
-        }
-
-        assertThat(docIds, equalTo(Set.of(DOC_1_ID, DOC_2_ID)));
-    }
-}
diff --git a/x-pack/qa/security-example-spi-extension/src/main/plugin-metadata/plugin-security.policy b/x-pack/qa/security-example-spi-extension/src/main/plugin-metadata/plugin-security.policy
new file mode 100644
index 0000000000000..4b663a0bf8d8c
--- /dev/null
+++ b/x-pack/qa/security-example-spi-extension/src/main/plugin-metadata/plugin-security.policy
@@ -0,0 +1,4 @@
+grant {
+  // example security manager permission
+  permission java.util.PropertyPermission "foobar", "read";
+};
diff --git a/x-pack/qa/security-tools-tests/src/test/java/org/elasticsearch/xpack/security/authc/service/FileTokensToolTests.java b/x-pack/qa/security-tools-tests/src/test/java/org/elasticsearch/xpack/security/authc/service/FileTokensToolTests.java
index 4084026b4cdbc..24373e5061206 100644
--- a/x-pack/qa/security-tools-tests/src/test/java/org/elasticsearch/xpack/security/authc/service/FileTokensToolTests.java
+++ b/x-pack/qa/security-tools-tests/src/test/java/org/elasticsearch/xpack/security/authc/service/FileTokensToolTests.java
@@ -22,10 +22,10 @@
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.PathUtilsForTesting;
 import org.elasticsearch.env.Environment;
-import org.elasticsearch.xpack.core.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.elasticsearch.xpack.core.security.authc.support.Hasher;
 import org.elasticsearch.xpack.core.security.support.Validation;
 import org.elasticsearch.xpack.core.security.support.ValidationTests;
+import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken.ServiceAccountTokenId;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;