Skip to content

Commit 3aeb7d5

Browse files
authored
Replace aggstate in aggs whose internal state is one of the primitives (ESQL-1375)
This commit moves aggs whose internal state is one of the primitives, over to the new intermediate agg state mechanism. We can then remove quite a bit of the internal serialization logic.
1 parent cbd4992 commit 3aeb7d5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+481
-802
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import static java.util.stream.Collectors.joining;
3333
import static org.elasticsearch.compute.gen.Methods.findMethod;
3434
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
35+
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
3536
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
3637
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR;
3738
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER;
@@ -398,7 +399,7 @@ private void combineRawInputForBytesRef(MethodSpec.Builder builder, String block
398399
private MethodSpec addIntermediateInput() {
399400
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateInput");
400401
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).addParameter(PAGE, "page");
401-
if (combineIntermediate != null) {
402+
if (isAggState() == false) {
402403
builder.addStatement("assert channels.size() == intermediateBlockCount()");
403404
builder.addStatement("assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size()");
404405
int count = 0;
@@ -420,10 +421,25 @@ private MethodSpec addIntermediateInput() {
420421
.map(s -> first + ".getPositionCount() == " + s + ".getPositionCount()")
421422
.collect(joining(" && "))
422423
);
423-
builder.addStatement(
424-
"$T.combineIntermediate(state, " + intermediateState.stream().map(IntermediateStateDesc::name).collect(joining(", ")) + ")",
425-
declarationType
426-
);
424+
if (hasPrimitiveState()) {
425+
assert intermediateState.size() == 2;
426+
assert intermediateState.get(1).name().equals("seen");
427+
builder.beginControlFlow("if (seen.getBoolean(0))");
428+
{
429+
var state = intermediateState.get(0);
430+
var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))";
431+
builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod());
432+
builder.addStatement("state.seen(true)");
433+
builder.endControlFlow();
434+
}
435+
} else {
436+
builder.addStatement(
437+
"$T.combineIntermediate(state, "
438+
+ intermediateState.stream().map(IntermediateStateDesc::name).collect(joining(", "))
439+
+ ")",
440+
declarationType
441+
);
442+
}
427443
} else {
428444
builder.addStatement("Block block = page.getBlock(channels.get(0))");
429445
builder.addStatement("$T vector = block.asVector()", VECTOR);
@@ -480,8 +496,9 @@ private MethodSpec evaluateIntermediate() {
480496
.addModifiers(Modifier.PUBLIC)
481497
.addParameter(BLOCK_ARRAY, "blocks")
482498
.addParameter(TypeName.INT, "offset");
483-
if (combineIntermediate != null) {
484-
builder.addStatement("$T.evaluateIntermediate(state, blocks, offset)", declarationType);
499+
if (isAggState() == false) {
500+
assert hasPrimitiveState();
501+
builder.addStatement("state.toIntermediate(blocks, offset)");
485502
} else {
486503
ParameterizedTypeName stateBlockBuilderType = ParameterizedTypeName.get(
487504
AGGREGATOR_STATE_VECTOR_BUILDER,
@@ -557,4 +574,16 @@ private MethodSpec close() {
557574
private ParameterizedTypeName stateBlockType() {
558575
return ParameterizedTypeName.get(AGGREGATOR_STATE_VECTOR, stateType);
559576
}
577+
578+
private boolean isAggState() {
579+
return intermediateState.get(0).name().equals("aggstate");
580+
}
581+
582+
private boolean hasPrimitiveState() {
583+
return switch (stateType.toString()) {
584+
case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.LongState",
585+
"org.elasticsearch.compute.aggregation.DoubleState" -> true;
586+
default -> false;
587+
};
588+
}
560589
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType;
3737
import static org.elasticsearch.compute.gen.Methods.findMethod;
3838
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
39+
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
3940
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR;
4041
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_STATE_VECTOR_BUILDER;
4142
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
@@ -394,7 +395,7 @@ private MethodSpec addIntermediateInput() {
394395
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
395396
builder.addParameter(LONG_VECTOR, "groupIdVector").addParameter(PAGE, "page");
396397

397-
if (combineIntermediate != null) {
398+
if (isAggState() == false) {
398399
builder.addStatement("assert channels.size() == intermediateBlockCount()");
399400
builder.addStatement("assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size()");
400401
int count = 0;
@@ -415,12 +416,34 @@ private MethodSpec addIntermediateInput() {
415416
.map(s -> first + ".getPositionCount() == " + s + ".getPositionCount()")
416417
.collect(joining(" && "))
417418
);
418-
builder.addStatement(
419-
"$T.combineIntermediate(groupIdVector, state, "
420-
+ intermediateState.stream().map(IntermediateStateDesc::name).collect(joining(", "))
421-
+ ")",
422-
declarationType
423-
);
419+
if (hasPrimitiveState()) {
420+
assert intermediateState.size() == 2;
421+
assert intermediateState.get(1).name().equals("seen");
422+
builder.beginControlFlow("for (int position = 0; position < groupIdVector.getPositionCount(); position++)");
423+
{
424+
builder.addStatement("int groupId = Math.toIntExact(groupIdVector.getLong(position))");
425+
builder.beginControlFlow("if (seen.getBoolean(position))");
426+
{
427+
var name = intermediateState.get(0).name();
428+
var m = vectorAccessorName(intermediateState.get(0).elementType());
429+
builder.addStatement(
430+
"state.set($T.combine(state.getOrDefault(groupId), " + name + "." + m + "(position)), groupId)",
431+
declarationType
432+
);
433+
builder.nextControlFlow("else");
434+
builder.addStatement("state.putNull(groupId)");
435+
builder.endControlFlow();
436+
}
437+
builder.endControlFlow();
438+
}
439+
} else {
440+
builder.addStatement(
441+
"$T.combineIntermediate(groupIdVector, state, "
442+
+ intermediateState.stream().map(IntermediateStateDesc::name).collect(joining(", "))
443+
+ ")",
444+
declarationType
445+
);
446+
}
424447
} else {
425448
builder.addStatement("Block block = page.getBlock(channels.get(0))");
426449
builder.addStatement("$T vector = block.asVector()", VECTOR);
@@ -478,8 +501,9 @@ private MethodSpec evaluateIntermediate() {
478501
.addParameter(BLOCK_ARRAY, "blocks")
479502
.addParameter(TypeName.INT, "offset")
480503
.addParameter(INT_VECTOR, "selected");
481-
if (combineIntermediate != null) {
482-
builder.addStatement("$T.evaluateIntermediate(state, blocks, offset, selected)", declarationType);
504+
if (isAggState() == false) {
505+
assert hasPrimitiveState();
506+
builder.addStatement("state.toIntermediate(blocks, offset, selected)");
483507
} else {
484508
ParameterizedTypeName stateBlockBuilderType = ParameterizedTypeName.get(
485509
AGGREGATOR_STATE_VECTOR_BUILDER,
@@ -534,4 +558,16 @@ private MethodSpec close() {
534558
private ParameterizedTypeName stateBlockType() {
535559
return ParameterizedTypeName.get(AGGREGATOR_STATE_VECTOR, stateType);
536560
}
561+
562+
private boolean isAggState() {
563+
return intermediateState.get(0).name().equals("aggstate");
564+
}
565+
566+
private boolean hasPrimitiveState() {
567+
return switch (stateType.toString()) {
568+
case "org.elasticsearch.compute.aggregation.IntArrayState", "org.elasticsearch.compute.aggregation.LongArrayState",
569+
"org.elasticsearch.compute.aggregation.DoubleArrayState" -> true;
570+
default -> false;
571+
};
572+
}
537573
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,19 @@ static String getMethod(TypeName elementType) {
117117
}
118118
throw new IllegalArgumentException("unknown get method for [" + elementType + "]");
119119
}
120+
121+
/**
122+
* Returns the name of the method used to get {@code valueType} instances
123+
* from vectors or blocks.
124+
*/
125+
static String vectorAccessorName(String elementTypeName) {
126+
return switch (elementTypeName) {
127+
case "INT" -> "getInt";
128+
case "LONG" -> "getLong";
129+
case "DOUBLE" -> "getDouble";
130+
default -> throw new IllegalArgumentException(
131+
"don't know how to fetch primitive values from " + elementTypeName + ". define combineStates."
132+
);
133+
};
134+
}
120135
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,12 @@
1212
import org.elasticsearch.common.util.DoubleArray;
1313
import org.elasticsearch.compute.ann.Experimental;
1414
import org.elasticsearch.compute.data.Block;
15+
import org.elasticsearch.compute.data.BooleanBlock;
1516
import org.elasticsearch.compute.data.DoubleBlock;
1617
import org.elasticsearch.compute.data.DoubleVector;
18+
import org.elasticsearch.compute.data.IntVector;
1719
import org.elasticsearch.core.Releasables;
1820

19-
import java.lang.invoke.MethodHandles;
20-
import java.lang.invoke.VarHandle;
21-
import java.nio.ByteOrder;
22-
import java.util.Objects;
23-
2421
/**
2522
* Aggregator state for an array of doubles.
2623
* This class is generated. Do not edit it.
@@ -110,9 +107,23 @@ private void ensureCapacity(int position) {
110107
}
111108
}
112109

110+
/** Extracts an intermediate view of the contents of this state. */
111+
void toIntermediate(Block[] blocks, int offset, IntVector selected) {
112+
assert blocks.length >= offset + 2;
113+
var valuesBuilder = DoubleBlock.newBlockBuilder(selected.getPositionCount());
114+
var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
115+
for (int i = 0; i < selected.getPositionCount(); i++) {
116+
int group = selected.getInt(i);
117+
valuesBuilder.appendDouble(values.get(group));
118+
nullsBuilder.appendBoolean(hasValue(group));
119+
}
120+
blocks[offset + 0] = valuesBuilder.build();
121+
blocks[offset + 1] = nullsBuilder.build();
122+
}
123+
113124
@Override
114125
public long getEstimatedSize() {
115-
return Long.BYTES + (largestIndex + 1L) * Double.BYTES + LongArrayState.estimateSerializeSize(nonNulls);
126+
throw new UnsupportedOperationException();
116127
}
117128

118129
@Override
@@ -122,41 +133,6 @@ public void close() {
122133

123134
@Override
124135
public AggregatorStateSerializer<DoubleArrayState> serializer() {
125-
return new DoubleArrayStateSerializer();
126-
}
127-
128-
private static class DoubleArrayStateSerializer implements AggregatorStateSerializer<DoubleArrayState> {
129-
private static final VarHandle lengthHandle = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN);
130-
private static final VarHandle valueHandle = MethodHandles.byteArrayViewVarHandle(double[].class, ByteOrder.BIG_ENDIAN);
131-
132-
@Override
133-
public int size() {
134-
return Double.BYTES;
135-
}
136-
137-
@Override
138-
public int serialize(DoubleArrayState state, byte[] ba, int offset, org.elasticsearch.compute.data.IntVector selected) {
139-
lengthHandle.set(ba, offset, selected.getPositionCount());
140-
offset += Long.BYTES;
141-
for (int i = 0; i < selected.getPositionCount(); i++) {
142-
valueHandle.set(ba, offset, state.values.get(selected.getInt(i)));
143-
offset += Double.BYTES;
144-
}
145-
final int valuesBytes = Long.BYTES + (Double.BYTES * selected.getPositionCount());
146-
return valuesBytes + LongArrayState.serializeBitArray(state.nonNulls, ba, offset);
147-
}
148-
149-
@Override
150-
public void deserialize(DoubleArrayState state, byte[] ba, int offset) {
151-
Objects.requireNonNull(state);
152-
int positions = (int) (long) lengthHandle.get(ba, offset);
153-
offset += Long.BYTES;
154-
for (int i = 0; i < positions; i++) {
155-
state.set((double) valueHandle.get(ba, offset), i);
156-
offset += Double.BYTES;
157-
}
158-
state.largestIndex = positions - 1;
159-
state.nonNulls = LongArrayState.deseralizeBitArray(state.bigArrays, ba, offset);
160-
}
136+
throw new UnsupportedOperationException();
161137
}
162138
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleState.java

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@
88
package org.elasticsearch.compute.aggregation;
99

1010
import org.elasticsearch.compute.ann.Experimental;
11-
import org.elasticsearch.compute.data.IntVector;
12-
13-
import java.lang.invoke.MethodHandles;
14-
import java.lang.invoke.VarHandle;
15-
import java.nio.ByteOrder;
16-
import java.util.Objects;
11+
import org.elasticsearch.compute.data.Block;
12+
import org.elasticsearch.compute.data.ConstantBooleanVector;
13+
import org.elasticsearch.compute.data.ConstantDoubleVector;
1714

1815
/**
1916
* Aggregator state for a single double.
@@ -48,41 +45,23 @@ void seen(boolean seen) {
4845
this.seen = seen;
4946
}
5047

48+
/** Extracts an intermediate view of the contents of this state. */
49+
void toIntermediate(Block[] blocks, int offset) {
50+
assert blocks.length >= offset + 2;
51+
blocks[offset + 0] = new ConstantDoubleVector(value, 1).asBlock();
52+
blocks[offset + 1] = new ConstantBooleanVector(seen, 1).asBlock();
53+
}
54+
5155
@Override
5256
public long getEstimatedSize() {
53-
return Double.BYTES + 1;
57+
throw new UnsupportedOperationException();
5458
}
5559

5660
@Override
5761
public void close() {}
5862

5963
@Override
6064
public AggregatorStateSerializer<DoubleState> serializer() {
61-
return new DoubleStateSerializer();
62-
}
63-
64-
private static class DoubleStateSerializer implements AggregatorStateSerializer<DoubleState> {
65-
private static final VarHandle handle = MethodHandles.byteArrayViewVarHandle(double[].class, ByteOrder.BIG_ENDIAN);
66-
67-
@Override
68-
public int size() {
69-
return Double.BYTES + 1;
70-
}
71-
72-
@Override
73-
public int serialize(DoubleState state, byte[] ba, int offset, IntVector selected) {
74-
assert selected.getPositionCount() == 1;
75-
assert selected.getInt(0) == 0;
76-
handle.set(ba, offset, state.value);
77-
ba[offset + Double.BYTES] = (byte) (state.seen ? 1 : 0);
78-
return size(); // number of bytes written
79-
}
80-
81-
@Override
82-
public void deserialize(DoubleState state, byte[] ba, int offset) {
83-
Objects.requireNonNull(state);
84-
state.value = (double) handle.get(ba, offset);
85-
state.seen = ba[offset + Double.BYTES] == (byte) 1;
86-
}
65+
throw new UnsupportedOperationException();
8766
}
8867
}

0 commit comments

Comments
 (0)