Skip to content

Commit 4806e24

Browse files
committed
Implement PredictLinear function
1 parent b6ba4fe commit 4806e24

File tree

11 files changed

+278
-26
lines changed

11 files changed

+278
-26
lines changed

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java

Lines changed: 12 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java

Lines changed: 12 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java

Lines changed: 12 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,25 @@
66
*/
77
package org.elasticsearch.compute.aggregation;
88

9-
class SimpleLinearRegressionWithTimeseries {
9+
public class SimpleLinearRegressionWithTimeseries {
1010
long count;
1111
double sumVal;
1212
long sumTs;
1313
double sumTsVal;
1414
long sumTsSq;
15+
long maxTs;
16+
17+
public interface SimpleLinearModelFunction {
18+
double predict(SimpleLinearRegressionWithTimeseries model);
19+
}
1520

1621
SimpleLinearRegressionWithTimeseries() {
1722
this.count = 0;
1823
this.sumVal = 0.0;
1924
this.sumTs = 0;
2025
this.sumTsVal = 0.0;
2126
this.sumTsSq = 0;
27+
this.maxTs = Long.MIN_VALUE;
2228
}
2329

2430
void add(long ts, double val) {
@@ -27,9 +33,16 @@ void add(long ts, double val) {
2733
sumTs += ts;
2834
sumTsVal += ts * val;
2935
sumTsSq += ts * ts;
36+
if (ts > maxTs) {
37+
maxTs = ts;
38+
}
39+
}
40+
41+
public double lastTimestamp() {
42+
return maxTs;
3043
}
3144

32-
double slope() {
45+
public double slope() {
3346
if (count <= 1) {
3447
return Double.NaN;
3548
}
@@ -41,7 +54,7 @@ void add(long ts, double val) {
4154
return numerator / denominator;
4255
}
4356

44-
double intercept() {
57+
public double intercept() {
4558
if (count == 0) {
4659
return 0.0; // or handle as needed
4760
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-DerivGroupingAggregatorFunction.java.st

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,23 @@ public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregator
4040
private final List<Integer> channels;
4141
private final DriverContext driverContext;
4242
private ObjectArray<SimpleLinearRegressionWithTimeseries> states;
43+
private SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn;
4344

44-
public Deriv$Type$GroupingAggregatorFunction(List<Integer> channels, DriverContext driverContext) {
45+
public Deriv$Type$GroupingAggregatorFunction(List<Integer> channels, DriverContext driverContext, SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn) {
4546
this.states = driverContext.bigArrays().newObjectArray(256);
4647
this.channels = channels;
4748
this.driverContext = driverContext;
49+
this.fn = fn;
4850
}
4951

5052
public static class Supplier implements AggregatorFunctionSupplier {
5153

54+
SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn;
55+
56+
public Supplier(SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn) {
57+
this.fn = fn;
58+
}
59+
5260
@Override
5361
public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
5462
throw new UnsupportedOperationException("DerivGroupingAggregatorFunction does not support non-grouping aggregation");
@@ -66,7 +74,7 @@ public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregator
6674

6775
@Override
6876
public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List<Integer> channels) {
69-
return new Deriv$Type$GroupingAggregatorFunction(channels, driverContext);
77+
return new Deriv$Type$GroupingAggregatorFunction(channels, driverContext, fn);
7078
}
7179

7280
@Override
@@ -103,7 +111,7 @@ public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregator
103111
state = new SimpleLinearRegressionWithTimeseries();
104112
states.set(groupId, state);
105113
}
106-
state.add(ts, (double) vValue); // TODO - value needs to be converted to double
114+
state.add(ts, (double) vValue);
107115
}
108116
}
109117

@@ -283,8 +291,7 @@ public class Deriv$Type$GroupingAggregatorFunction implements GroupingAggregator
283291
if (state == null) {
284292
resultBuilder.appendNull();
285293
} else {
286-
double deriv = state.slope();
287-
resultBuilder.appendDouble(deriv);
294+
resultBuilder.appendDouble(fn.predict(state));
288295
}
289296
}
290297
blocks[offset] = resultBuilder.build();

x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,57 @@ max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword
763763
0.0 | 6.980661 | 2024-05-10T00:20:00.000Z | prod
764764

765765
;
766+
767+
predict_linear_with_long
768+
required_capability: ts_command_v0
769+
770+
TS k8s
771+
| STATS predicted_cost = values(round(predict_linear(to_double(network.total_cost), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster
772+
| KEEP predicted_cost, time_bucket, cluster
773+
| SORT cluster, time_bucket
774+
| LIMIT 5
775+
;
776+
777+
predicted_cost:double | time_bucket:datetime | cluster:keyword
778+
[-6.602588168740215E14, -8.877497236448354E15, 2.2311938303322465E15] | 2024-05-10T00:00:00.000Z | prod
779+
[-8.155468708896549E14, -9.51830721515923E14, -1.145502894094415E15] | 2024-05-10T00:05:00.000Z | prod
780+
[5.463777793231174E14, -2.0359960651035278E14, -1.8388879144239432E15] | 2024-05-10T00:10:00.000Z | prod
781+
[-1.3995586522213082E15, 1.0825172194782352E15, 4.0520717676617645E15] | 2024-05-10T00:15:00.000Z | prod
782+
[-8.47847162491646E14, -9.533695970876474E14, 0.0] | 2024-05-10T00:20:00.000Z | prod
783+
;
784+
785+
predict_linear_with_timedelta
786+
required_capability: ts_command_v0
787+
788+
TS k8s
789+
| STATS predicted_cost = values(round(predict_linear(to_double(network.total_cost), 10 minutes), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster
790+
| KEEP predicted_cost, time_bucket, cluster
791+
| SORT cluster, time_bucket
792+
| LIMIT 5
793+
;
794+
795+
predicted_cost:double | time_bucket:datetime | cluster:keyword
796+
[-6.602588168740214E14, -8.877497236448353E15, 2.2311938303322462E15] | 2024-05-10T00:00:00.000Z | prod
797+
[-8.155468708896546E14, -9.51830721515923E14, -1.1455028940944148E15] | 2024-05-10T00:05:00.000Z | prod
798+
[5.463777793231174E14, -2.0359960651035275E14, -1.8388879144239428E15] | 2024-05-10T00:10:00.000Z | prod
799+
[-1.399558652221308E15, 1.082517219478235E15, 4.0520717676617635E15] | 2024-05-10T00:15:00.000Z | prod
800+
[-8.478471624916458E14, -9.533695970876471E14, 0.0] | 2024-05-10T00:20:00.000Z | prod
801+
;
802+
803+
predict_linear_with_int_and_timedelta
804+
required_capability: ts_command_v0
805+
806+
TS k8s
807+
| STATS predicted_cost_int = values(round(predict_linear(to_long(network.total_bytes_in), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster
808+
| KEEP predicted_cost_int, time_bucket, cluster
809+
| SORT cluster, time_bucket
810+
| LIMIT 5
811+
;
812+
813+
predicted_cost_int:double | time_bucket:datetime | cluster:keyword
814+
[-4.773636702420472E16, -7.885984558091773E17, -2.2361970129287098E17] | 2024-05-10T00:00:00.000Z | prod
815+
[-2.8368940048287676E16, -1.9090238696317244E16, -4.5500448741014824E16] | 2024-05-10T00:05:00.000Z | prod
816+
[-8.229547268699878E16, -4.188543503897301E16, -8.065656914863902E16] | 2024-05-10T00:10:00.000Z | prod
817+
[-1.5332165779386784E17, -1.0946919353828525E17, -1.2534486684657331E17] | 2024-05-10T00:15:00.000Z | prod
818+
[8.510411223069503E17, -8.325264742678522E16, 0.0] | 2024-05-10T00:20:00.000Z | prod
819+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,7 @@ public enum Cap {
14781478
PERCENTILE_OVER_TIME,
14791479
VARIANCE_STDDEV_OVER_TIME,
14801480
TS_LINREG_DERIVATIVE,
1481+
TS_LINREG_PREDICT,
14811482
/**
14821483
* INLINE STATS fix incorrect prunning of null filtering
14831484
* https://github.com/elastic/elasticsearch/pull/135011

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime;
4444
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
4545
import org.elasticsearch.xpack.esql.expression.function.aggregate.PercentileOverTime;
46+
import org.elasticsearch.xpack.esql.expression.function.aggregate.PredictLinear;
4647
import org.elasticsearch.xpack.esql.expression.function.aggregate.Present;
4748
import org.elasticsearch.xpack.esql.expression.function.aggregate.PresentOverTime;
4849
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
@@ -533,6 +534,7 @@ private static FunctionDefinition[][] functions() {
533534
def(Idelta.class, uni(Idelta::new), "idelta"),
534535
def(Delta.class, uni(Delta::new), "delta"),
535536
def(Deriv.class, uni(Deriv::new), "deriv"),
537+
def(PredictLinear.class, bi(PredictLinear::new), "predict_linear"),
536538
def(Increase.class, uni(Increase::new), "increase"),
537539
def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"),
538540
def(MinOverTime.class, uni(MinOverTime::new), "min_over_time"),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
3030
Increase.ENTRY,
3131
Delta.ENTRY,
3232
Deriv.ENTRY,
33+
PredictLinear.ENTRY,
3334
Sample.ENTRY,
3435
SpatialCentroid.ENTRY,
3536
SpatialExtent.ENTRY,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.compute.aggregation.DerivDoubleGroupingAggregatorFunction;
1313
import org.elasticsearch.compute.aggregation.DerivIntGroupingAggregatorFunction;
1414
import org.elasticsearch.compute.aggregation.DerivLongGroupingAggregatorFunction;
15+
import org.elasticsearch.compute.aggregation.SimpleLinearRegressionWithTimeseries;
1516
import org.elasticsearch.xpack.esql.core.expression.Expression;
1617
import org.elasticsearch.xpack.esql.core.expression.Literal;
1718
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
@@ -99,10 +100,12 @@ public String getWriteableName() {
99100
@Override
100101
public AggregatorFunctionSupplier supplier() {
101102
final DataType type = field().dataType();
103+
104+
SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn = (SimpleLinearRegressionWithTimeseries model) -> model.slope();
102105
return switch (type) {
103-
case INTEGER -> new DerivIntGroupingAggregatorFunction.Supplier();
104-
case LONG -> new DerivLongGroupingAggregatorFunction.Supplier();
105-
case DOUBLE -> new DerivDoubleGroupingAggregatorFunction.Supplier();
106+
case INTEGER -> new DerivIntGroupingAggregatorFunction.Supplier(fn);
107+
case LONG -> new DerivLongGroupingAggregatorFunction.Supplier(fn);
108+
case DOUBLE -> new DerivDoubleGroupingAggregatorFunction.Supplier(fn);
106109
default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type);
107110
};
108111
}

0 commit comments

Comments
 (0)