Skip to content

Commit 0cc1249

Browse files
committed
using std_dev as base function internally - non-generated files
1 parent 9e652a0 commit 0cc1249

File tree

9 files changed

+226
-89
lines changed

9 files changed

+226
-89
lines changed

x-pack/plugin/esql/build.gradle

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,22 +475,22 @@ tasks.named('stringTemplates').configure {
475475

476476
File stdDev = file("src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/X-StdDev.java.st")
477477
template {
478-
it.properties = ["Variation": "StdDev", "StdDev": "true", "lower_case": "population standard deviation", "CAPS": "STD_DEV"]
478+
it.properties = ["Variation": "StdDevPopulation", "StdDev": "true", "lower_case": "population standard deviation", "CAPS": "STD_DEV", "enum": "POPULATION"]
479479
it.inputFile = stdDev
480-
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java"
480+
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevPopulation.java"
481481
}
482482
template {
483-
it.properties = ["Variation": "StdDevSample", "StdDevSample": "true", "lower_case": "sample standard deviation", "CAPS": "STD_DEV_SAMPLE"]
483+
it.properties = ["Variation": "StdDevSample", "StdDevSample": "true", "lower_case": "sample standard deviation", "CAPS": "STD_DEV_SAMPLE", "enum": "SAMPLE"]
484484
it.inputFile = stdDev
485485
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevSample.java"
486486
}
487487
template {
488-
it.properties = ["Variation": "VariancePopulation", "VariancePopulation": "true", "lower_case": "population variance", "CAPS": "VARIANCE_POPULATION"]
488+
it.properties = ["Variation": "VariancePopulation", "VariancePopulation": "true", "lower_case": "population variance", "CAPS": "VARIANCE_POPULATION", "enum": "POPULATION_VARIANCE"]
489489
it.inputFile = stdDev
490490
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/VariancePopulation.java"
491491
}
492492
template {
493-
it.properties = ["Variation": "VarianceSample", "VarianceSample": "true", "lower_case": "sample variance", "CAPS": "VARIANCE_SAMPLE"]
493+
it.properties = ["Variation": "VarianceSample", "VarianceSample": "true", "lower_case": "sample variance", "CAPS": "VARIANCE_SAMPLE", "enum": "SAMPLE_VARIANCE"]
494494
it.inputFile = stdDev
495495
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/VarianceSample.java"
496496
}

x-pack/plugin/esql/compute/build.gradle

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,6 @@ def addOccurrence(props, Occurrence) {
8888
return newProps
8989
}
9090

91-
def addStdDevType(props, Variation) {
92-
def newProps = props.collectEntries { [(it.key): it.value] }
93-
newProps["Variation"] = Variation
94-
def enumName = Variation == "StdDev" ? "POPULATION" :
95-
Variation == "StdDevSample" ? "SAMPLE" :
96-
Variation == "VariancePopulation" ? "POPULATION_VARIANCE" :
97-
"SAMPLE_VARIANCE"
98-
newProps["EnumName"] = enumName
99-
return newProps
100-
}
101-
10291
tasks.named('stringTemplates').configure {
10392
var intProperties = prop("Int", "Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash")
10493
var floatProperties = prop("Float", "Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
@@ -639,29 +628,25 @@ tasks.named('stringTemplates').configure {
639628
}
640629

641630
File stdDevAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st")
642-
["StdDev", "StdDevSample", "VariancePopulation", "VarianceSample"].forEach { Variation ->
643-
{
644-
template {
645-
it.properties = addStdDevType(intProperties, Variation)
646-
it.inputFile = stdDevAggregatorInputFile
647-
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}IntAggregator.java"
648-
}
649-
template {
650-
it.properties = addStdDevType(longProperties, Variation)
651-
it.inputFile = stdDevAggregatorInputFile
652-
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}LongAggregator.java"
653-
}
654-
template {
655-
it.properties = addStdDevType(floatProperties, Variation)
656-
it.inputFile = stdDevAggregatorInputFile
657-
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}FloatAggregator.java"
658-
}
659-
template {
660-
it.properties = addStdDevType(doubleProperties, Variation)
661-
it.inputFile = stdDevAggregatorInputFile
662-
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}DoubleAggregator.java"
663-
}
664-
}
631+
template {
632+
it.properties = intProperties
633+
it.inputFile = stdDevAggregatorInputFile
634+
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevIntAggregator.java"
635+
}
636+
template {
637+
it.properties = longProperties
638+
it.inputFile = stdDevAggregatorInputFile
639+
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevLongAggregator.java"
640+
}
641+
template {
642+
it.properties = floatProperties
643+
it.inputFile = stdDevAggregatorInputFile
644+
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java"
645+
}
646+
template {
647+
it.properties = doubleProperties
648+
it.inputFile = stdDevAggregatorInputFile
649+
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java"
665650
}
666651

667652
File sampleAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st")

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,45 @@ public final class StdDevStates {
2020

2121
private StdDevStates() {}
2222

23-
enum Variation {
24-
SAMPLE,
25-
POPULATION,
26-
SAMPLE_VARIANCE,
27-
POPULATION_VARIANCE
23+
public enum Variation {
24+
SAMPLE(0),
25+
POPULATION(1),
26+
SAMPLE_VARIANCE(2),
27+
POPULATION_VARIANCE(3);
28+
29+
private final int index;
30+
31+
Variation(int index) {
32+
this.index = index;
33+
}
34+
35+
public int getIndex() {
36+
return index;
37+
}
38+
39+
private static Variation getVariation(int index) {
40+
return switch (index) {
41+
case 0 -> SAMPLE;
42+
case 1 -> POPULATION;
43+
case 2 -> SAMPLE_VARIANCE;
44+
case 3 -> POPULATION_VARIANCE;
45+
default -> POPULATION;
46+
};
47+
}
2848
}
2949

3050
static final class SingleState implements AggregatorState {
3151

3252
private final WelfordAlgorithm welfordAlgorithm;
53+
private final Variation variation;
3354

34-
SingleState() {
35-
this(0, 0, 0);
55+
SingleState(int variation) {
56+
this(0, 0, 0, variation);
3657
}
3758

38-
SingleState(double mean, double m2, long count) {
59+
SingleState(double mean, double m2, long count, int variation) {
3960
this.welfordAlgorithm = new WelfordAlgorithm(mean, m2, count);
61+
this.variation = Variation.getVariation(variation);
4062
}
4163

4264
public void add(long value) {
@@ -79,7 +101,7 @@ public long count() {
79101
return welfordAlgorithm.count();
80102
}
81103

82-
public double evaluateFinal(Variation variation) {
104+
public double evaluateFinal() {
83105
return switch (variation) {
84106
case SAMPLE -> welfordAlgorithm.evaluateSample();
85107
case POPULATION -> welfordAlgorithm.evaluatePopulation();
@@ -88,24 +110,26 @@ public double evaluateFinal(Variation variation) {
88110
};
89111
}
90112

91-
public Block evaluateFinal(DriverContext driverContext, Variation variation) {
113+
public Block evaluateFinal(DriverContext driverContext) {
92114
final long count = count();
93115
final double m2 = m2();
94116
if (count == 0 || Double.isFinite(m2) == false) {
95117
return driverContext.blockFactory().newConstantNullBlock(1);
96118
}
97-
return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(variation), 1);
119+
return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(), 1);
98120
}
99121
}
100122

101123
static final class GroupingState implements GroupingAggregatorState {
102124

103125
private ObjectArray<WelfordAlgorithm> states;
104126
private final BigArrays bigArrays;
127+
private final Variation variation;
105128

106-
GroupingState(BigArrays bigArrays) {
129+
GroupingState(BigArrays bigArrays, int variation) {
107130
this.states = bigArrays.newObjectArray(1);
108131
this.bigArrays = bigArrays;
132+
this.variation = Variation.getVariation(variation);
109133
}
110134

111135
WelfordAlgorithm getOrNull(int position) {
@@ -190,7 +214,7 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive
190214
}
191215
}
192216

193-
public Block evaluateFinal(IntVector selected, DriverContext driverContext, Variation variation) {
217+
public Block evaluateFinal(IntVector selected, DriverContext driverContext) {
194218
try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) {
195219
for (int i = 0; i < selected.getPositionCount(); i++) {
196220
final var groupId = selected.getInt(i);

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ import org.elasticsearch.compute.operator.DriverContext;
2626
@IntermediateState(name = "count", type = "LONG") }
2727
)
2828
@GroupingAggregator
29-
public class $Variation$$Type$Aggregator {
29+
public class StdDev$Type$Aggregator {
3030

31-
public static StdDevStates.SingleState initSingle() {
32-
return new StdDevStates.SingleState();
31+
public static StdDevStates.SingleState initSingle(int variation) {
32+
return new StdDevStates.SingleState(variation);
3333
}
3434

3535
public static void combine(StdDevStates.SingleState state, $type$ value) {
@@ -41,11 +41,11 @@ public class $Variation$$Type$Aggregator {
4141
}
4242

4343
public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) {
44-
return state.evaluateFinal(driverContext, StdDevStates.Variation.$EnumName$);
44+
return state.evaluateFinal(driverContext);
4545
}
4646

47-
public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) {
48-
return new StdDevStates.GroupingState(bigArrays);
47+
public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays, int variation) {
48+
return new StdDevStates.GroupingState(bigArrays, variation);
4949
}
5050

5151
public static void combine(StdDevStates.GroupingState current, int groupId, $type$ value) {
@@ -61,6 +61,6 @@ public class $Variation$$Type$Aggregator {
6161
}
6262

6363
public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) {
64-
return state.evaluateFinal(selected, driverContext, StdDevStates.Variation.$EnumName$);
64+
return state.evaluateFinal(selected, driverContext);
6565
}
6666
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sample;
3939
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
4040
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
41-
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
41+
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDevPopulation;
4242
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDevSample;
4343
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
4444
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
@@ -319,7 +319,7 @@ private static FunctionDefinition[][] functions() {
319319
def(Min.class, uni(Min::new), "min"),
320320
def(Percentile.class, bi(Percentile::new), "percentile"),
321321
def(Sample.class, bi(Sample::new), "sample"),
322-
def(StdDev.class, uni(StdDev::new), "std_dev"),
322+
def(StdDevPopulation.class, uni(StdDevPopulation::new), "std_dev_population", "std_dev"),
323323
def(StdDevSample.class, uni(StdDevSample::new), "std_dev_sample"),
324324
def(Sum.class, uni(Sum::new), "sum"),
325325
def(Top.class, tri(Top::new), "top"),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
2828
SpatialCentroid.ENTRY,
2929
SpatialExtent.ENTRY,
3030
StdDev.ENTRY,
31+
StdDevPopulation.ENTRY,
3132
StdDevSample.ENTRY,
3233
Sum.ENTRY,
3334
Top.ENTRY,
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import org.elasticsearch.TransportVersions;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
14+
import org.elasticsearch.compute.aggregation.StdDevDoubleAggregatorFunctionSupplier;
15+
import org.elasticsearch.compute.aggregation.StdDevIntAggregatorFunctionSupplier;
16+
import org.elasticsearch.compute.aggregation.StdDevLongAggregatorFunctionSupplier;
17+
import org.elasticsearch.compute.aggregation.StdDevStates;
18+
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
19+
import org.elasticsearch.xpack.esql.core.expression.Expression;
20+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
21+
import org.elasticsearch.xpack.esql.core.expression.Literal;
22+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
23+
import org.elasticsearch.xpack.esql.core.tree.Source;
24+
import org.elasticsearch.xpack.esql.core.type.DataType;
25+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
26+
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
27+
import org.elasticsearch.xpack.esql.expression.function.Param;
28+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
29+
import org.elasticsearch.xpack.esql.planner.ToAggregator;
30+
31+
import java.io.IOException;
32+
import java.util.List;
33+
34+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
35+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
36+
import static org.elasticsearch.xpack.esql.core.util.CollectionUtils.nullSafeList;
37+
38+
public class StdDev extends AggregateFunction implements ToAggregator {
39+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StdDev", StdDev::new);
40+
41+
private final Expression variation;
42+
43+
@FunctionInfo(
44+
returnType = "double",
45+
description = "TODO: description.", // TODO
46+
type = FunctionType.AGGREGATE
47+
)
48+
public StdDev(
49+
Source source,
50+
@Param(name = "number", type = { "double", "integer", "long" }) Expression field,
51+
@Param(name = "variation", type = "int", description = "index of stddev variation") Expression variation
52+
) {
53+
this(source, field, Literal.TRUE, variation);
54+
}
55+
56+
public StdDev(Source source, Expression field, Expression filter, Expression variation) {
57+
this(
58+
source,
59+
field,
60+
filter,
61+
variation != null
62+
? List.of(variation)
63+
: List.of()
64+
);
65+
}
66+
67+
private StdDev(Source source, Expression field, Expression filter, List<Expression> params) {
68+
super(source, field, filter, params);
69+
this.variation = params.size() > 0 ? params.get(0) : null;
70+
}
71+
72+
private StdDev(StreamInput in) throws IOException {
73+
this(
74+
Source.readFrom((PlanStreamInput) in),
75+
in.readNamedWriteable(Expression.class),
76+
in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE,
77+
in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)
78+
? in.readNamedWriteableCollectionAsList(Expression.class)
79+
: nullSafeList(in.readOptionalNamedWriteable(Expression.class))
80+
);
81+
}
82+
83+
@Override
84+
public String getWriteableName() {
85+
return ENTRY.name;
86+
}
87+
88+
@Override
89+
public DataType dataType() {
90+
return DataType.DOUBLE;
91+
}
92+
93+
@Override
94+
protected Expression.TypeResolution resolveType() {
95+
return isType(
96+
field(),
97+
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
98+
sourceText(),
99+
DEFAULT,
100+
"numeric except unsigned_long or counter types"
101+
);
102+
}
103+
104+
@Override
105+
protected NodeInfo<StdDev> info() {
106+
return NodeInfo.create(this, StdDev::new, field(), filter(), variation);
107+
}
108+
109+
@Override
110+
public StdDev replaceChildren(List<Expression> newChildren) {
111+
return new StdDev(source(), newChildren.get(0), newChildren.get(1), newChildren.size() > 2 ? newChildren.get(2) : null);
112+
}
113+
114+
public StdDev withFilter(Expression filter) {
115+
return new StdDev(source(), field(), filter, variation);
116+
}
117+
118+
@Override
119+
public final AggregatorFunctionSupplier supplier() {
120+
DataType type = field().dataType();
121+
int variation = this.variation == null
122+
? StdDevStates.Variation.POPULATION.getIndex()
123+
: ((Number) (this.variation.fold(FoldContext.small() /* TODO remove me */))).intValue();
124+
if (type == DataType.LONG) {
125+
return new StdDevLongAggregatorFunctionSupplier(variation);
126+
}
127+
if (type == DataType.INTEGER) {
128+
return new StdDevIntAggregatorFunctionSupplier(variation);
129+
}
130+
if (type == DataType.DOUBLE) {
131+
return new StdDevDoubleAggregatorFunctionSupplier(variation);
132+
}
133+
throw EsqlIllegalArgumentException.illegalDataType(type);
134+
}
135+
}

0 commit comments

Comments
 (0)