Skip to content

Commit acfafbd

Browse files
committed
manual changes
1 parent ada0adb commit acfafbd

File tree

9 files changed

+234
-31
lines changed

9 files changed

+234
-31
lines changed

x-pack/plugin/esql/build.gradle

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,4 +472,26 @@ tasks.named('stringTemplates').configure {
472472
it.inputFile = roundToInput
473473
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/scalar/math/RoundToDouble.java"
474474
}
475+
476+
File stdDev = file("src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/X-StdDev.java.st")
477+
template {
478+
it.properties = ["Variation": "StdDev", "StdDev": "true", "lower_case": "population standard deviation", "CAPS": "STD_DEV"]
479+
it.inputFile = stdDev
480+
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/StdDev.java"
481+
}
482+
template {
483+
it.properties = ["Variation": "StdDevSample", "StdDevSample": "true", "lower_case": "sample standard deviation", "CAPS": "STD_DEV_SAMPLE"]
484+
it.inputFile = stdDev
485+
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/StdDevSample.java"
486+
}
487+
template {
488+
it.properties = ["Variation": "VariancePopulation", "VariancePopulation": "true", "lower_case": "population variance", "CAPS": "VARIANCE_POPULATION"]
489+
it.inputFile = stdDev
490+
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/VariancePopulation.java"
491+
}
492+
template {
493+
it.properties = ["Variation": "VarianceSample", "VarianceSample": "true", "lower_case": "sample variance", "CAPS": "VARIANCE_SAMPLE"]
494+
it.inputFile = stdDev
495+
it.outputFile = "org/elasticsearch/xpack/esql/expression/function/aggregate/VarianceSample.java"
496+
}
475497
}

x-pack/plugin/esql/compute/build.gradle

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ def addOccurrence(props, Occurrence) {
8888
return newProps
8989
}
9090

91+
def addStdDevType(props, Variation) {
92+
def newProps = props.collectEntries { [(it.key): it.value] }
93+
newProps["Variation"] = Variation
94+
def enumName = Variation == "StdDev" ? "POPULATION" :
95+
Variation == "StdDevSample" ? "SAMPLE" :
96+
Variation == "VariancePopulation" ? "POPULATION_VARIANCE" :
97+
"SAMPLE_VARIANCE"
98+
newProps["EnumName"] = enumName
99+
return newProps
100+
}
101+
91102
tasks.named('stringTemplates').configure {
92103
var intProperties = prop("Int", "Int", "int", "Integer", "INT", "Integer.BYTES", "IntArray", "LongHash")
93104
var floatProperties = prop("Float", "Float", "float", "Float", "FLOAT", "Float.BYTES", "FloatArray", "LongHash")
@@ -628,25 +639,29 @@ tasks.named('stringTemplates').configure {
628639
}
629640

630641
File stdDevAggregatorInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st")
631-
template {
632-
it.properties = intProperties
633-
it.inputFile = stdDevAggregatorInputFile
634-
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevIntAggregator.java"
635-
}
636-
template {
637-
it.properties = longProperties
638-
it.inputFile = stdDevAggregatorInputFile
639-
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevLongAggregator.java"
640-
}
641-
template {
642-
it.properties = floatProperties
643-
it.inputFile = stdDevAggregatorInputFile
644-
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevFloatAggregator.java"
645-
}
646-
template {
647-
it.properties = doubleProperties
648-
it.inputFile = stdDevAggregatorInputFile
649-
it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java"
642+
["StdDev", "StdDevSample", "VariancePopulation", "VarianceSample"].forEach { Variation ->
643+
{
644+
template {
645+
it.properties = addStdDevType(intProperties, Variation)
646+
it.inputFile = stdDevAggregatorInputFile
647+
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}IntAggregator.java"
648+
}
649+
template {
650+
it.properties = addStdDevType(longProperties, Variation)
651+
it.inputFile = stdDevAggregatorInputFile
652+
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}LongAggregator.java"
653+
}
654+
template {
655+
it.properties = addStdDevType(floatProperties, Variation)
656+
it.inputFile = stdDevAggregatorInputFile
657+
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}FloatAggregator.java"
658+
}
659+
template {
660+
it.properties = addStdDevType(doubleProperties, Variation)
661+
it.inputFile = stdDevAggregatorInputFile
662+
it.outputFile = "org/elasticsearch/compute/aggregation/${Variation}DoubleAggregator.java"
663+
}
664+
}
650665
}
651666

652667
File sampleAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st")

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ public final class StdDevStates {
2020

2121
private StdDevStates() {}
2222

23+
enum Variation {
24+
SAMPLE,
25+
POPULATION,
26+
SAMPLE_VARIANCE,
27+
POPULATION_VARIANCE
28+
}
29+
2330
static final class SingleState implements AggregatorState {
2431

2532
private final WelfordAlgorithm welfordAlgorithm;
@@ -72,17 +79,22 @@ public long count() {
7279
return welfordAlgorithm.count();
7380
}
7481

75-
public double evaluateFinal() {
76-
return welfordAlgorithm.evaluate();
82+
public double evaluateFinal(Variation variation) {
83+
return switch (variation) {
84+
case SAMPLE -> welfordAlgorithm.evaluateSample();
85+
case POPULATION -> welfordAlgorithm.evaluatePopulation();
86+
case SAMPLE_VARIANCE -> welfordAlgorithm.evaluateSampleVariance();
87+
case POPULATION_VARIANCE -> welfordAlgorithm.evaluatePopulationVariance();
88+
};
7789
}
7890

79-
public Block evaluateFinal(DriverContext driverContext) {
91+
public Block evaluateFinal(DriverContext driverContext, Variation variation) {
8092
final long count = count();
8193
final double m2 = m2();
8294
if (count == 0 || Double.isFinite(m2) == false) {
8395
return driverContext.blockFactory().newConstantNullBlock(1);
8496
}
85-
return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(), 1);
97+
return driverContext.blockFactory().newConstantDoubleBlockWith(evaluateFinal(variation), 1);
8698
}
8799
}
88100

@@ -178,7 +190,7 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive
178190
}
179191
}
180192

181-
public Block evaluateFinal(IntVector selected, DriverContext driverContext) {
193+
public Block evaluateFinal(IntVector selected, DriverContext driverContext, Variation variation) {
182194
try (DoubleBlock.Builder builder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) {
183195
for (int i = 0; i < selected.getPositionCount(); i++) {
184196
final var groupId = selected.getInt(i);
@@ -189,7 +201,13 @@ public Block evaluateFinal(IntVector selected, DriverContext driverContext) {
189201
if (count == 0 || Double.isFinite(m2) == false) {
190202
builder.appendNull();
191203
} else {
192-
builder.appendDouble(st.evaluate());
204+
double result = switch (variation) {
205+
case SAMPLE -> st.evaluateSample();
206+
case POPULATION -> st.evaluatePopulation();
207+
case SAMPLE_VARIANCE -> st.evaluateSampleVariance();
208+
case POPULATION_VARIANCE -> st.evaluatePopulationVariance();
209+
};
210+
builder.appendDouble(result);
193211
}
194212
} else {
195213
builder.appendNull();

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/WelfordAlgorithm.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,19 @@ public void add(double meanValue, double m2Value, long countValue) {
7373
count += countValue;
7474
}
7575

76-
public double evaluate() {
77-
return count < 2 ? 0 : Math.sqrt(m2 / count);
76+
public double evaluatePopulation() {
77+
return Math.sqrt(evaluatePopulationVariance());
78+
}
79+
80+
public double evaluateSample() {
81+
return Math.sqrt(evaluateSampleVariance());
82+
}
83+
84+
public double evaluatePopulationVariance() {
85+
return count < 2 ? 0 : m2 / count;
86+
}
87+
88+
public double evaluateSampleVariance() {
89+
return count < 2 ? 0 : m2 / (count - 1);
7890
}
7991
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-StdDevAggregator.java.st

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.elasticsearch.compute.operator.DriverContext;
2626
@IntermediateState(name = "count", type = "LONG") }
2727
)
2828
@GroupingAggregator
29-
public class StdDev$Type$Aggregator {
29+
public class $Variation$$Type$Aggregator {
3030

3131
public static StdDevStates.SingleState initSingle() {
3232
return new StdDevStates.SingleState();
@@ -41,7 +41,7 @@ public class StdDev$Type$Aggregator {
4141
}
4242

4343
public static Block evaluateFinal(StdDevStates.SingleState state, DriverContext driverContext) {
44-
return state.evaluateFinal(driverContext);
44+
return state.evaluateFinal(driverContext, StdDevStates.Variation.$EnumName$);
4545
}
4646

4747
public static StdDevStates.GroupingState initGrouping(BigArrays bigArrays) {
@@ -61,6 +61,6 @@ public class StdDev$Type$Aggregator {
6161
}
6262

6363
public static Block evaluateFinal(StdDevStates.GroupingState state, IntVector selected, DriverContext driverContext) {
64-
return state.evaluateFinal(selected, driverContext);
64+
return state.evaluateFinal(selected, driverContext, StdDevStates.Variation.$EnumName$);
6565
}
6666
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,12 @@ public enum Cap {
12151215
/**
12161216
* (Re)Added EXPLAIN command
12171217
*/
1218-
EXPLAIN(Build.current().isSnapshot());
1218+
EXPLAIN(Build.current().isSnapshot()),
1219+
1220+
/**
1221+
* Support for {@code STD_DEV_SAMPLE}, {@code VARIANCE_POPULATION}, and {@code VARIANCE_SAMPLE} aggregations.
1222+
*/
1223+
STD_DEV_VARIATIONS;
12191224

12201225
private final boolean enabled;
12211226

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@
3939
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
4040
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
4141
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
42+
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDevSample;
4243
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
4344
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
4445
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
4546
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
47+
import org.elasticsearch.xpack.esql.expression.function.aggregate.VariancePopulation;
48+
import org.elasticsearch.xpack.esql.expression.function.aggregate.VarianceSample;
4649
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
4750
import org.elasticsearch.xpack.esql.expression.function.fulltext.Kql;
4851
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
@@ -317,9 +320,12 @@ private static FunctionDefinition[][] functions() {
317320
def(Percentile.class, bi(Percentile::new), "percentile"),
318321
def(Sample.class, bi(Sample::new), "sample"),
319322
def(StdDev.class, uni(StdDev::new), "std_dev"),
323+
def(StdDevSample.class, uni(StdDevSample::new), "std_dev_sample"),
320324
def(Sum.class, uni(Sum::new), "sum"),
321325
def(Top.class, tri(Top::new), "top"),
322326
def(Values.class, uni(Values::new), "values"),
327+
def(VariancePopulation.class, uni(VariancePopulation::new), "variance_population"),
328+
def(VarianceSample.class, uni(VarianceSample::new), "variance_sample"),
323329
def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg") },
324330
// math
325331
new FunctionDefinition[] {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
2828
SpatialCentroid.ENTRY,
2929
SpatialExtent.ENTRY,
3030
StdDev.ENTRY,
31+
StdDevSample.ENTRY,
3132
Sum.ENTRY,
3233
Top.ENTRY,
3334
Values.ENTRY,
35+
VariancePopulation.ENTRY,
36+
VarianceSample.ENTRY,
3437
MinOverTime.ENTRY,
3538
MaxOverTime.ENTRY,
3639
AvgOverTime.ENTRY,
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
13+
import org.elasticsearch.compute.aggregation.$Variation$DoubleAggregatorFunctionSupplier;
14+
import org.elasticsearch.compute.aggregation.$Variation$IntAggregatorFunctionSupplier;
15+
import org.elasticsearch.compute.aggregation.$Variation$LongAggregatorFunctionSupplier;
16+
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
17+
import org.elasticsearch.xpack.esql.core.expression.Expression;
18+
import org.elasticsearch.xpack.esql.core.expression.Literal;
19+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
20+
import org.elasticsearch.xpack.esql.core.tree.Source;
21+
import org.elasticsearch.xpack.esql.core.type.DataType;
22+
import org.elasticsearch.xpack.esql.expression.function.Example;
23+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
24+
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
25+
import org.elasticsearch.xpack.esql.expression.function.Param;
26+
import org.elasticsearch.xpack.esql.planner.ToAggregator;
27+
28+
import java.io.IOException;
29+
import java.util.List;
30+
31+
import static java.util.Collections.emptyList;
32+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
33+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
34+
35+
public class $Variation$ extends AggregateFunction implements ToAggregator {
36+
$if(StdDev)$
37+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StdDev", StdDev::new);
38+
39+
$else$
40+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
41+
Expression.class,
42+
"$Variation$",
43+
$Variation$::new
44+
);
45+
$endif$
46+
47+
@FunctionInfo(
48+
returnType = "double",
49+
description = "The $lower_case$ of a numeric field.",
50+
type = FunctionType.AGGREGATE,
51+
examples = {
52+
@Example(file = "stats", tag = "$Variation$"),
53+
@Example(
54+
description = "The expression can use inline functions. For example, to calculate the "
55+
+ "$lower_case$ of each employee’s maximum salary changes, "
56+
+ "first use `MV_MAX` on each row, and then use `$CAPS$` on the result",
57+
file = "stats",
58+
tag = "docsStats$Variation$NestedExpression"
59+
) }
60+
)
61+
public $Variation$(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
62+
this(source, field, Literal.TRUE);
63+
}
64+
65+
public $Variation$(Source source, Expression field, Expression filter) {
66+
super(source, field, filter, emptyList());
67+
}
68+
69+
private $Variation$(StreamInput in) throws IOException {
70+
super(in);
71+
}
72+
73+
@Override
74+
public String getWriteableName() {
75+
return ENTRY.name;
76+
}
77+
78+
@Override
79+
public DataType dataType() {
80+
return DataType.DOUBLE;
81+
}
82+
83+
@Override
84+
protected Expression.TypeResolution resolveType() {
85+
return isType(
86+
field(),
87+
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
88+
sourceText(),
89+
DEFAULT,
90+
"numeric except unsigned_long or counter types"
91+
);
92+
}
93+
94+
@Override
95+
protected NodeInfo<$Variation$> info() {
96+
return NodeInfo.create(this, $Variation$::new, field(), filter());
97+
}
98+
99+
@Override
100+
public $Variation$ replaceChildren(List<Expression> newChildren) {
101+
return new $Variation$(source(), newChildren.get(0), newChildren.get(1));
102+
}
103+
104+
public $Variation$ withFilter(Expression filter) {
105+
return new $Variation$(source(), field(), filter);
106+
}
107+
108+
@Override
109+
public final AggregatorFunctionSupplier supplier() {
110+
DataType type = field().dataType();
111+
if (type == DataType.LONG) {
112+
return new $Variation$LongAggregatorFunctionSupplier();
113+
}
114+
if (type == DataType.INTEGER) {
115+
return new $Variation$IntAggregatorFunctionSupplier();
116+
}
117+
if (type == DataType.DOUBLE) {
118+
return new $Variation$DoubleAggregatorFunctionSupplier();
119+
}
120+
throw EsqlIllegalArgumentException.illegalDataType(type);
121+
}
122+
}

0 commit comments

Comments
 (0)