Skip to content

Commit 77bd6f6

Browse files
committed
Supply timestamp to aggregate functions (elastic#122174)
(cherry picked from commit a36b327)
1 parent 631725c commit 77bd6f6

File tree

4 files changed

+169
-72
lines changed

4 files changed

+169
-72
lines changed

x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,8 @@
5858
*/
5959
Class<? extends Exception>[] warnExceptions() default {};
6060

61+
/**
62+
* If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
63+
*/
64+
boolean includeTimestamps() default false;
6165
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
5555
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
5656
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
57+
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
58+
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
5759
import static org.elasticsearch.compute.gen.Types.PAGE;
5860
import static org.elasticsearch.compute.gen.Types.WARNINGS;
5961
import static org.elasticsearch.compute.gen.Types.blockType;
@@ -73,9 +75,10 @@ public class AggregatorImplementer {
7375
private final List<TypeMirror> warnExceptions;
7476
private final ExecutableElement init;
7577
private final ExecutableElement combine;
78+
private final List<Parameter> createParameters;
7679
private final ClassName implementation;
7780
private final List<IntermediateStateDesc> intermediateState;
78-
private final List<Parameter> createParameters;
81+
private final boolean includeTimestampVector;
7982

8083
private final AggregationState aggState;
8184
private final AggregationParameter aggParam;
@@ -84,7 +87,8 @@ public AggregatorImplementer(
8487
Elements elements,
8588
TypeElement declarationType,
8689
IntermediateState[] interStateAnno,
87-
List<TypeMirror> warnExceptions
90+
List<TypeMirror> warnExceptions,
91+
boolean includeTimestampVector
8892
) {
8993
this.declarationType = declarationType;
9094
this.warnExceptions = warnExceptions;
@@ -102,10 +106,10 @@ public AggregatorImplementer(
102106
declarationType,
103107
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
104108
requireName("combine"),
105-
requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
109+
combineArgs(aggState, includeTimestampVector)
106110
);
107111
// TODO support multiple parameters
108-
this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType());
112+
this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());
109113

110114
this.createParameters = init.getParameters()
111115
.stream()
@@ -117,7 +121,20 @@ public AggregatorImplementer(
117121
elements.getPackageOf(declarationType).toString(),
118122
(declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator")
119123
);
120-
intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
124+
this.intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
125+
this.includeTimestampVector = includeTimestampVector;
126+
}
127+
128+
private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
129+
if (includeTimestampVector) {
130+
return requireArgs(
131+
requireType(aggState.declaredType()),
132+
requireType(TypeName.LONG), // @timestamp
133+
requireAnyType("<aggregation input column type>")
134+
);
135+
} else {
136+
return requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"));
137+
}
121138
}
122139

123140
ClassName implementation() {
@@ -295,10 +312,18 @@ private MethodSpec addRawInput() {
295312
builder.addComment("No masking");
296313
builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
297314
builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
315+
if (includeTimestampVector) {
316+
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
317+
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);
318+
319+
builder.beginControlFlow("if (timestampsVector == null) ");
320+
builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
321+
builder.endControlFlow();
322+
}
298323
builder.beginControlFlow("if (vector != null)");
299-
builder.addStatement("addRawVector(vector)");
324+
builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector)" : "addRawVector(vector)");
300325
builder.nextControlFlow("else");
301-
builder.addStatement("addRawBlock(block)");
326+
builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector)" : "addRawBlock(block)");
302327
builder.endControlFlow();
303328
builder.addStatement("return");
304329
}
@@ -307,17 +332,28 @@ private MethodSpec addRawInput() {
307332
builder.addComment("Some positions masked away, others kept");
308333
builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
309334
builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
335+
if (includeTimestampVector) {
336+
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
337+
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);
338+
339+
builder.beginControlFlow("if (timestampsVector == null) ");
340+
builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
341+
builder.endControlFlow();
342+
}
310343
builder.beginControlFlow("if (vector != null)");
311-
builder.addStatement("addRawVector(vector, mask)");
344+
builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector, mask)" : "addRawVector(vector, mask)");
312345
builder.nextControlFlow("else");
313-
builder.addStatement("addRawBlock(block, mask)");
346+
builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector, mask)" : "addRawBlock(block, mask)");
314347
builder.endControlFlow();
315348
return builder.build();
316349
}
317350

318351
private MethodSpec addRawVector(boolean masked) {
319352
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector");
320353
builder.addModifiers(Modifier.PRIVATE).addParameter(vectorType(aggParam.type()), "vector");
354+
if (includeTimestampVector) {
355+
builder.addParameter(LONG_VECTOR, "timestamps");
356+
}
321357
if (masked) {
322358
builder.addParameter(BOOLEAN_VECTOR, "mask");
323359
}
@@ -348,6 +384,9 @@ private MethodSpec addRawVector(boolean masked) {
348384
private MethodSpec addRawBlock(boolean masked) {
349385
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock");
350386
builder.addModifiers(Modifier.PRIVATE).addParameter(blockType(aggParam.type()), "block");
387+
if (includeTimestampVector) {
388+
builder.addParameter(LONG_VECTOR, "timestamps");
389+
}
351390
if (masked) {
352391
builder.addParameter(BOOLEAN_VECTOR, "mask");
353392
}
@@ -401,33 +440,57 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) {
401440
});
402441
}
403442

404-
private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
405-
builder.addStatement(
406-
"state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
407-
returnType,
408-
declarationType,
409-
returnType,
410-
blockVariable,
411-
capitalize(combine.getParameters().get(1).asType().toString())
412-
);
443+
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
444+
// scratch is a BytesRef var that must have been defined before the iteration starts
445+
if (includeTimestampVector) {
446+
builder.addStatement("$T.combine(state, timestamps.getLong(i), $L.getBytesRef(i, scratch))", declarationType, blockVariable);
447+
} else {
448+
builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
449+
}
413450
}
414451

415-
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
416-
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
452+
private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
453+
if (includeTimestampVector) {
454+
builder.addStatement(
455+
"state.$TValue($T.combine(state.$TValue(), timestamps.getLong(i), $L.get$L(i)))",
456+
returnType,
457+
declarationType,
458+
returnType,
459+
blockVariable,
460+
capitalize(combine.getParameters().get(1).asType().toString())
461+
);
462+
} else {
463+
builder.addStatement(
464+
"state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
465+
returnType,
466+
declarationType,
467+
returnType,
468+
blockVariable,
469+
capitalize(combine.getParameters().get(1).asType().toString())
470+
);
471+
}
417472
}
418473

419474
private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable) {
420-
builder.addStatement(
421-
"$T.combine(state, $L.get$L(i))",
422-
declarationType,
423-
blockVariable,
424-
capitalize(combine.getParameters().get(1).asType().toString())
425-
);
475+
if (includeTimestampVector) {
476+
builder.addStatement(
477+
"$T.combine(state, timestamps.getLong(i), $L.get$L(i))",
478+
declarationType,
479+
blockVariable,
480+
capitalize(combine.getParameters().get(1).asType().toString())
481+
);
482+
} else {
483+
builder.addStatement(
484+
"$T.combine(state, $L.get$L(i))",
485+
declarationType,
486+
blockVariable,
487+
capitalize(combine.getParameters().get(1).asType().toString())
488+
);
489+
}
426490
}
427491

428-
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
429-
// scratch is a BytesRef var that must have been defined before the iteration starts
430-
builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
492+
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
493+
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
431494
}
432495

433496
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,13 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
8787
);
8888
if (aggClass.getAnnotation(Aggregator.class) != null) {
8989
IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value();
90-
implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes);
90+
implementer = new AggregatorImplementer(
91+
env.getElementUtils(),
92+
aggClass,
93+
intermediateState,
94+
warnExceptionsTypes,
95+
aggClass.getAnnotation(Aggregator.class).includeTimestamps()
96+
);
9197
write(aggClass, "aggregator", implementer.sourceFile(), env);
9298
}
9399
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
@@ -96,13 +102,12 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
96102
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
97103
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
98104
}
99-
boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
100105
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
101106
env.getElementUtils(),
102107
aggClass,
103108
intermediateState,
104109
warnExceptionsTypes,
105-
includeTimestamps
110+
aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps()
106111
);
107112
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
108113
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ public GroupingAggregatorImplementer(
112112
requireName("combine"),
113113
combineArgs(aggState, includeTimestampVector)
114114
);
115-
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());
115+
// TODO support multiple parameters
116+
this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());
116117

117118
this.createParameters = init.getParameters()
118119
.stream()
@@ -125,7 +126,7 @@ public GroupingAggregatorImplementer(
125126
(declarationType.getSimpleName() + "GroupingAggregatorFunction").replace("AggregatorGroupingAggregator", "GroupingAggregator")
126127
);
127128

128-
intermediateState = Arrays.stream(interStateAnno)
129+
this.intermediateState = Arrays.stream(interStateAnno)
129130
.map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc)
130131
.toList();
131132
this.includeTimestampVector = includeTimestampVector;
@@ -370,8 +371,7 @@ private TypeSpec addInput(Consumer<MethodSpec.Builder> addBlock) {
370371
private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
371372
boolean groupsIsBlock = groupsType.toString().endsWith("Block");
372373
boolean valuesIsBlock = valuesType.toString().endsWith("Block");
373-
String methodName = "addRawInput";
374-
MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
374+
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput");
375375
builder.addModifiers(Modifier.PRIVATE);
376376
builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
377377
if (includeTimestampVector) {
@@ -443,8 +443,6 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
443443
warningsBlock(builder, () -> {
444444
if (aggParam.isBytesRef()) {
445445
combineRawInputForBytesRef(builder, blockVariable, offsetVariable);
446-
} else if (includeTimestampVector) {
447-
combineRawInputWithTimestamp(builder, offsetVariable);
448446
} else if (valueType.isPrimitive() == false) {
449447
throw new IllegalArgumentException("second parameter to combine must be a primitive, array or BytesRef: " + valueType);
450448
} else if (returnType.isPrimitive()) {
@@ -457,48 +455,75 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
457455
});
458456
}
459457

460-
private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
461-
builder.addStatement(
462-
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
463-
declarationType,
464-
blockVariable,
465-
capitalize(aggParam.type().toString()),
466-
offsetVariable
467-
);
458+
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
459+
// scratch is a BytesRef var that must have been defined before the iteration starts
460+
if (includeTimestampVector) {
461+
if (offsetVariable.contains(" + ")) {
462+
builder.addStatement("var valuePosition = $L", offsetVariable);
463+
offsetVariable = "valuePosition";
464+
}
465+
builder.addStatement(
466+
"$T.combine(state, groupId, timestamps.getLong($L), $L.getBytesRef($L, scratch))",
467+
declarationType,
468+
offsetVariable,
469+
blockVariable,
470+
offsetVariable
471+
);
472+
} else {
473+
builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
474+
}
468475
}
469476

470-
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
471-
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
477+
private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
478+
if (includeTimestampVector) {
479+
if (offsetVariable.contains(" + ")) {
480+
builder.addStatement("var valuePosition = $L", offsetVariable);
481+
offsetVariable = "valuePosition";
482+
}
483+
builder.addStatement(
484+
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
485+
declarationType,
486+
offsetVariable,
487+
capitalize(aggParam.type().toString()),
488+
offsetVariable
489+
);
490+
} else {
491+
builder.addStatement(
492+
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
493+
declarationType,
494+
blockVariable,
495+
capitalize(aggParam.type().toString()),
496+
offsetVariable
497+
);
498+
}
472499
}
473500

474501
private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
475-
builder.addStatement(
476-
"$T.combine(state, groupId, $L.get$L($L))",
477-
declarationType,
478-
blockVariable,
479-
capitalize(aggParam.type().toString()),
480-
offsetVariable
481-
);
482-
}
483-
484-
private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) {
485-
String blockType = capitalize(aggParam.type().toString());
486-
if (offsetVariable.contains(" + ")) {
487-
builder.addStatement("var valuePosition = $L", offsetVariable);
488-
offsetVariable = "valuePosition";
502+
if (includeTimestampVector) {
503+
if (offsetVariable.contains(" + ")) {
504+
builder.addStatement("var valuePosition = $L", offsetVariable);
505+
offsetVariable = "valuePosition";
506+
}
507+
builder.addStatement(
508+
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
509+
declarationType,
510+
offsetVariable,
511+
capitalize(aggParam.type().toString()),
512+
offsetVariable
513+
);
514+
} else {
515+
builder.addStatement(
516+
"$T.combine(state, groupId, $L.get$L($L))",
517+
declarationType,
518+
blockVariable,
519+
capitalize(aggParam.type().toString()),
520+
offsetVariable
521+
);
489522
}
490-
builder.addStatement(
491-
"$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
492-
declarationType,
493-
offsetVariable,
494-
blockType,
495-
offsetVariable
496-
);
497523
}
498524

499-
private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
500-
// scratch is a BytesRef var that must have been defined before the iteration starts
501-
builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
525+
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
526+
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
502527
}
503528

504529
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {

0 commit comments

Comments
 (0)