diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 6b5708b3e6528..238540bf2c799 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -21,10 +21,13 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.OrdinalBytesRefVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.DriverContext; @@ -282,11 +285,18 @@ private static Block dataBlock(int groups, String dataType) { int blockLength = blockLength(groups); return switch (dataType) { case BYTES_REF -> { - try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) { + try ( + BytesRefVector.Builder dict = blockFactory.newBytesRefVectorBuilder(blockLength); + IntVector.Builder ords = blockFactory.newIntVectorBuilder(blockLength) + ) { + final int dictLength = Math.min(blockLength, KEYWORDS.length); + for (int i = 0; i < dictLength; i++) { + dict.appendBytesRef(KEYWORDS[i]); + } for (int i = 0; i < blockLength; i++) { - builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]); + ords.appendInt(i % dictLength); } - yield builder.build(); + yield new OrdinalBytesRefVector(ords.build(), dict.build()).asBlock(); } } case INT -> { diff --git a/docs/changelog/127849.yaml b/docs/changelog/127849.yaml new file mode 100644 index 0000000000000..4d5b747b35011 --- /dev/null +++ b/docs/changelog/127849.yaml @@ -0,0 +1,5 @@ +pr: 127849 +summary: Optimize ordinal inputs in Values aggregation +area: "ES|QL" +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index b02abd9b1fa58..1704f4cbeb1fe 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -35,6 +35,7 @@ import static java.util.stream.Collectors.joining; import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize; +import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod; import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; import static org.elasticsearch.compute.gen.Methods.requireAnyType; import static org.elasticsearch.compute.gen.Methods.requireArgs; @@ -336,10 +337,32 @@ private MethodSpec prepareProcessPage() { builder.beginControlFlow("if (valuesBlock.mayHaveNulls())"); builder.addStatement("state.enableGroupIdTracking(seenGroupIds)"); builder.endControlFlow(); - builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra))); + if (shouldWrapAddInput(blockType(aggParam.type()))) { + builder.addStatement( + "var addInput = $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)) + ); + builder.addStatement("return $T.wrapAddInput(addInput, state, valuesBlock)", declarationType); + } else { + builder.addStatement( + "return $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)) + ); + } } builder.endControlFlow(); - builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))); + if (shouldWrapAddInput(vectorType(aggParam.type()))) { + builder.addStatement( + "var addInput = $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)) + ); + builder.addStatement("return $T.wrapAddInput(addInput, state, valuesVector)", declarationType); + } else { + builder.addStatement( + "return $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)) + ); + } return builder.build(); } @@ -526,6 +549,15 @@ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVar warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); } + private boolean shouldWrapAddInput(TypeName valuesType) { + return optionalStaticMethod( + declarationType, + requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), + requireName("wrapAddInput"), + requireArgs(requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), requireType(aggState.declaredType()), requireType(valuesType)) + ) != null; + } + private void warningsBlock(MethodSpec.Builder builder, Runnable block) { if (warnExceptions.isEmpty() == false) { builder.beginControlFlow("try"); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java index f2fa7b8084448..b94eb13433a15 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java @@ -59,6 +59,23 @@ static ExecutableElement requireStaticMethod( TypeMatcher returnTypeMatcher, NameMatcher nameMatcher, ArgumentMatcher argumentMatcher + ) { + ExecutableElement method = optionalStaticMethod(declarationType, returnTypeMatcher, nameMatcher, argumentMatcher); + if (method == null) { + var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: "; + var signatures = nameMatcher.names.stream() + .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") + .collect(joining(" or ")); + throw new IllegalArgumentException(message + signatures); + } + return method; + } + + static ExecutableElement optionalStaticMethod( + TypeElement declarationType, + TypeMatcher returnTypeMatcher, + NameMatcher nameMatcher, + ArgumentMatcher argumentMatcher ) { return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream()) .filter(method -> method.getModifiers().contains(Modifier.STATIC)) @@ -66,13 +83,7 @@ static ExecutableElement requireStaticMethod( .filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType()))) .filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList())) .findFirst() - .orElseThrow(() -> { - var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: "; - var signatures = nameMatcher.names.stream() - .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") - .collect(joining(" or ")); - return new IllegalArgumentException(message + signatures); - }); + .orElse(null); } static NameMatcher requireName(String... names) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 9018a1b7b73fb..51195578ac363 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -59,6 +59,22 @@ public static GroupingState initGrouping(BigArrays bigArrays) { return new GroupingState(bigArrays); } + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefBlock values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefVector values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + public static void combine(GroupingState state, int groupId, BytesRef v) { state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); } @@ -130,8 +146,8 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { - private final LongLongHash values; - private BytesRefHash bytes; + final LongLongHash values; + BytesRefHash bytes; private GroupingState(BigArrays bigArrays) { LongLongHash _values = null; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 662c4b445496e..28843942b73cb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -64,7 +64,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG if (valuesBlock.mayHaveNulls()) { state.enableGroupIdTracking(seenGroupIds); } - return new GroupingAggregatorFunction.AddInput() { + var addInput = new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); @@ -84,8 +84,9 @@ public void add(int positionOffset, IntVector groupIds) { public void close() { } }; + return ValuesBytesRefAggregator.wrapAddInput(addInput, state, valuesBlock); } - return new GroupingAggregatorFunction.AddInput() { + var addInput = new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); @@ -105,6 +106,7 @@ public void add(int positionOffset, IntVector groupIds) { public void close() { } }; + return ValuesBytesRefAggregator.wrapAddInput(addInput, state, valuesVector); } private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java new file mode 100644 index 0000000000000..78a083b8daac7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; +import org.elasticsearch.core.Releasables; + +final class ValuesBytesRefAggregators { + static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + ValuesBytesRefAggregator.GroupingState state, + BytesRefBlock values + ) { + OrdinalBytesRefBlock valuesOrdinal = values.asOrdinals(); + if (valuesOrdinal == null) { + return delegate; + } + BytesRefVector dict = valuesOrdinal.getDictionaryVector(); + final IntVector hashIds; + BytesRef spare = new BytesRef(); + try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); + } + hashIds = hashIdsBuilder.build(); + } + IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock(); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + if (ordinalIds.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + } + } + } + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + if (ordinalIds.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + } + } + } + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int groupId = groupIds.getInt(groupPosition); + if (ordinalIds.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + } + } + } + + @Override + public void close() { + Releasables.close(hashIds, delegate); + } + }; + } + + static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + ValuesBytesRefAggregator.GroupingState state, + BytesRefVector values + ) { + var valuesOrdinal = values.asOrdinals(); + if (valuesOrdinal == null) { + return delegate; + } + BytesRefVector dict = valuesOrdinal.getDictionaryVector(); + final IntVector hashIds; + BytesRef spare = new BytesRef(); + try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); + } + hashIds = hashIdsBuilder.build(); + } + var ordinalIds = valuesOrdinal.getOrdinalsVector(); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + } + } + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + } + } + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int groupId = groupIds.getInt(groupPosition); + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + } + } + + @Override + public void close() { + Releasables.close(hashIds, delegate); + } + }; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index f0397de497426..67f32fc4a4d4e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -94,6 +94,24 @@ $endif$ return new GroupingState(bigArrays); } +$if(BytesRef)$ + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefBlock values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefVector values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } +$endif$ + public static void combine(GroupingState state, int groupId, $type$ v) { $if(long)$ state.values.add(groupId, v); @@ -241,8 +259,8 @@ $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ - private final LongLongHash values; - private BytesRefHash bytes; + final LongLongHash values; + BytesRefHash bytes; $elseif(int||float)$ private final LongHash values;