diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/predict_linear.md b/docs/reference/query-languages/esql/_snippets/functions/description/predict_linear.md new file mode 100644 index 0000000000000..47d0c80e2cda5 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/predict_linear.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Predicts the value of a time series at `t` seconds in the future. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/predict_linear.md b/docs/reference/query-languages/esql/_snippets/functions/examples/predict_linear.md new file mode 100644 index 0000000000000..e252d31da3e7b --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/predict_linear.md @@ -0,0 +1,16 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +TS k8s +| STATS predicted_cost_int = MAX(ROUND(PREDICT_LINEAR(TO_LONG(network.total_bytes_in), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster +``` + +| predicted_cost_int:double | time_bucket:datetime | cluster:keyword | +| --- | --- | --- | +| 3105.0 | 2024-05-10T00:00:00.000Z | prod | +| 4218.33168 | 2024-05-10T00:05:00.000Z | prod | +| 6029.22491 | 2024-05-10T00:10:00.000Z | prod | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md index 09036fe8314d1..aac90accaa12b 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/deriv.md @@ -1,6 +1,10 @@ % This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. ## `DERIV` [esql-deriv] +```{applies_to} +stack: preview 9.3.0 +serverless: preview +``` **Syntax** diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/predict_linear.md b/docs/reference/query-languages/esql/_snippets/functions/layout/predict_linear.md new file mode 100644 index 0000000000000..365fc8c0b71f0 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/predict_linear.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `PREDICT_LINEAR` [esql-predict_linear] +```{applies_to} +stack: preview 9.3.0 +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/predict_linear.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/predict_linear.md +::: + +:::{include} ../description/predict_linear.md +::: + +:::{include} ../types/predict_linear.md +::: + +:::{include} ../examples/predict_linear.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/predict_linear.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/predict_linear.md new file mode 100644 index 0000000000000..530ba74f0de33 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/predict_linear.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field` +: the expression to use for the prediction + +`t` +: how long in the fututre to predict in seconds for numeric, or in time delta + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/predict_linear.md b/docs/reference/query-languages/esql/_snippets/functions/types/predict_linear.md new file mode 100644 index 0000000000000..8e4a8fdc41167 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/predict_linear.md @@ -0,0 +1,19 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported types** + +| field | t | result | +| --- | --- | --- | +| double | double | double | +| double | integer | double | +| double | long | double | +| double | time_duration | double | +| integer | double | double | +| integer | integer | double | +| integer | long | double | +| integer | time_duration | double | +| long | double | double | +| long | integer | double | +| long | long | double | +| long | time_duration | double | + diff --git a/docs/reference/query-languages/esql/images/functions/predict_linear.svg b/docs/reference/query-languages/esql/images/functions/predict_linear.svg new file mode 100644 index 0000000000000..248a9dc9037a1 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/predict_linear.svg @@ -0,0 +1 @@ +PREDICT_LINEAR(field,t) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json index f84745f4de37a..8109bdbc3ad0e 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/deriv.json @@ -44,6 +44,6 @@ "examples" : [ "TS k8s\n| WHERE pod == \"three\"\n| STATS max_deriv = MAX(DERIV(network.cost)) BY time_bucket = BUCKET(@timestamp,5minute), pod" ], - "preview" : false, + "preview" : true, "snapshot_only" : false } diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/predict_linear.json b/docs/reference/query-languages/esql/kibana/definition/functions/predict_linear.json new file mode 100644 index 0000000000000..9dc94aeae6671 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/predict_linear.json @@ -0,0 +1,229 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "time_series_agg", + "name" : "predict_linear", + "description" : "Predicts the value of a time series at `t` seconds in the future.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "double", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "integer", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "long", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "time_duration", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "double", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "integer", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "long", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "time_duration", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "double", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "integer", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "long", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "the expression to use for the prediction" + }, + { + "name" : "t", + "type" : "time_duration", + "optional" : false, + "description" : "how long in the fututre to predict in seconds for numeric, or in time delta" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "TS k8s\n| STATS predicted_cost_int = MAX(ROUND(PREDICT_LINEAR(TO_LONG(network.total_bytes_in), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster" + ], + "preview" : true, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/predict_linear.md b/docs/reference/query-languages/esql/kibana/docs/functions/predict_linear.md new file mode 100644 index 0000000000000..ecbe7512a2846 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/predict_linear.md @@ -0,0 +1,9 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### PREDICT LINEAR +Predicts the value of a time series at `t` seconds in the future. + +```esql +TS k8s +| STATS predicted_cost_int = MAX(ROUND(PREDICT_LINEAR(TO_LONG(network.total_bytes_in), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster +``` diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java index 6698c7284d352..63b70ce65db53 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunction.java @@ -29,7 +29,9 @@ public final class DerivDoubleAggregatorFunction implements AggregatorFunction { new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DriverContext driverContext; @@ -37,16 +39,24 @@ public final class DerivDoubleAggregatorFunction implements AggregatorFunction { private final List channels; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivDoubleAggregatorFunction(DriverContext driverContext, List channels, - SimpleLinearRegressionWithTimeseries state) { + SimpleLinearRegressionWithTimeseries state, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.driverContext = driverContext; this.channels = channels; this.state = state; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivDoubleAggregatorFunction create(DriverContext driverContext, - List channels) { - return new DerivDoubleAggregatorFunction(driverContext, channels, DerivDoubleAggregator.initSingle(driverContext)); + List channels, SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos) { + return new DerivDoubleAggregatorFunction(driverContext, channels, DerivDoubleAggregator.initSingle(driverContext, fn, dateNanos), fn, dateNanos); } public static List intermediateStateDesc() { @@ -206,7 +216,19 @@ public void addIntermediateInput(Page page) { } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); assert sumTsSq.getPositionCount() == 1; - DerivDoubleAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + assert maxTs.getPositionCount() == 1; + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert valueAtMaxTs.getPositionCount() == 1; + DerivDoubleAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0), maxTs.getLong(0), valueAtMaxTs.getDouble(0)); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java index 440c03cd403f8..b2ffbfcc2af79 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleAggregatorFunctionSupplier.java @@ -15,7 +15,14 @@ * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. */ public final class DerivDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { - public DerivDoubleAggregatorFunctionSupplier() { + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + + public DerivDoubleAggregatorFunctionSupplier( + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + this.fn = fn; + this.dateNanos = dateNanos; } @Override @@ -31,13 +38,13 @@ public List groupingIntermediateStateDesc() { @Override public DerivDoubleAggregatorFunction aggregator(DriverContext driverContext, List channels) { - return DerivDoubleAggregatorFunction.create(driverContext, channels); + return DerivDoubleAggregatorFunction.create(driverContext, channels, fn, dateNanos); } @Override public DerivDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return DerivDoubleGroupingAggregatorFunction.create(channels, driverContext); + return DerivDoubleGroupingAggregatorFunction.create(channels, driverContext, fn, dateNanos); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java index adf77eff37fa8..06f6e8fe78310 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivDoubleGroupingAggregatorFunction.java @@ -31,7 +31,9 @@ public final class DerivDoubleGroupingAggregatorFunction implements GroupingAggr new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DerivDoubleAggregator.GroupingState state; @@ -39,16 +41,24 @@ public final class DerivDoubleGroupingAggregatorFunction implements GroupingAggr private final DriverContext driverContext; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivDoubleGroupingAggregatorFunction(List channels, - DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + DerivDoubleAggregator.GroupingState state, DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.channels = channels; this.state = state; this.driverContext = driverContext; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivDoubleGroupingAggregatorFunction create(List channels, - DriverContext driverContext) { - return new DerivDoubleGroupingAggregatorFunction(channels, DerivDoubleAggregator.initGrouping(driverContext), driverContext); + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + return new DerivDoubleGroupingAggregatorFunction(channels, DerivDoubleAggregator.initGrouping(driverContext, fn, dateNanos), driverContext, fn, dateNanos); } public static List intermediateStateDesc() { @@ -214,7 +224,17 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -224,7 +244,7 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -308,7 +328,17 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -318,7 +348,7 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -388,11 +418,21 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); int valuesPosition = groupPosition + positionOffset; - DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivDoubleAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java index 1c936d5de9135..0e1352a3be9b9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunction.java @@ -31,7 +31,9 @@ public final class DerivIntAggregatorFunction implements AggregatorFunction { new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DriverContext driverContext; @@ -39,16 +41,24 @@ public final class DerivIntAggregatorFunction implements AggregatorFunction { private final List channels; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivIntAggregatorFunction(DriverContext driverContext, List channels, - SimpleLinearRegressionWithTimeseries state) { + SimpleLinearRegressionWithTimeseries state, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.driverContext = driverContext; this.channels = channels; this.state = state; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivIntAggregatorFunction create(DriverContext driverContext, - List channels) { - return new DerivIntAggregatorFunction(driverContext, channels, DerivIntAggregator.initSingle(driverContext)); + List channels, SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos) { + return new DerivIntAggregatorFunction(driverContext, channels, DerivIntAggregator.initSingle(driverContext, fn, dateNanos), fn, dateNanos); } public static List intermediateStateDesc() { @@ -207,7 +217,19 @@ public void addIntermediateInput(Page page) { } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); assert sumTsSq.getPositionCount() == 1; - DerivIntAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + assert maxTs.getPositionCount() == 1; + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert valueAtMaxTs.getPositionCount() == 1; + DerivIntAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0), maxTs.getLong(0), valueAtMaxTs.getDouble(0)); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java index ecd4e4bf8dbd8..f65664342c0ff 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntAggregatorFunctionSupplier.java @@ -15,7 +15,14 @@ * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. */ public final class DerivIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { - public DerivIntAggregatorFunctionSupplier() { + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + + public DerivIntAggregatorFunctionSupplier( + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + this.fn = fn; + this.dateNanos = dateNanos; } @Override @@ -31,13 +38,13 @@ public List groupingIntermediateStateDesc() { @Override public DerivIntAggregatorFunction aggregator(DriverContext driverContext, List channels) { - return DerivIntAggregatorFunction.create(driverContext, channels); + return DerivIntAggregatorFunction.create(driverContext, channels, fn, dateNanos); } @Override public DerivIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return DerivIntGroupingAggregatorFunction.create(channels, driverContext); + return DerivIntGroupingAggregatorFunction.create(channels, driverContext, fn, dateNanos); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java index ab1e3d602be25..b63d11903afa3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivIntGroupingAggregatorFunction.java @@ -32,7 +32,9 @@ public final class DerivIntGroupingAggregatorFunction implements GroupingAggrega new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DerivDoubleAggregator.GroupingState state; @@ -40,16 +42,24 @@ public final class DerivIntGroupingAggregatorFunction implements GroupingAggrega private final DriverContext driverContext; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivIntGroupingAggregatorFunction(List channels, - DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + DerivDoubleAggregator.GroupingState state, DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.channels = channels; this.state = state; this.driverContext = driverContext; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivIntGroupingAggregatorFunction create(List channels, - DriverContext driverContext) { - return new DerivIntGroupingAggregatorFunction(channels, DerivIntAggregator.initGrouping(driverContext), driverContext); + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + return new DerivIntGroupingAggregatorFunction(channels, DerivIntAggregator.initGrouping(driverContext, fn, dateNanos), driverContext, fn, dateNanos); } public static List intermediateStateDesc() { @@ -215,7 +225,17 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -225,7 +245,7 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -309,7 +329,17 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -319,7 +349,7 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -389,11 +419,21 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); int valuesPosition = groupPosition + positionOffset; - DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivIntAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java index b18118d21c08a..e8f034842ed6c 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunction.java @@ -29,7 +29,9 @@ public final class DerivLongAggregatorFunction implements AggregatorFunction { new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DriverContext driverContext; @@ -37,16 +39,24 @@ public final class DerivLongAggregatorFunction implements AggregatorFunction { private final List channels; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivLongAggregatorFunction(DriverContext driverContext, List channels, - SimpleLinearRegressionWithTimeseries state) { + SimpleLinearRegressionWithTimeseries state, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.driverContext = driverContext; this.channels = channels; this.state = state; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivLongAggregatorFunction create(DriverContext driverContext, - List channels) { - return new DerivLongAggregatorFunction(driverContext, channels, DerivLongAggregator.initSingle(driverContext)); + List channels, SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos) { + return new DerivLongAggregatorFunction(driverContext, channels, DerivLongAggregator.initSingle(driverContext, fn, dateNanos), fn, dateNanos); } public static List intermediateStateDesc() { @@ -206,7 +216,19 @@ public void addIntermediateInput(Page page) { } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); assert sumTsSq.getPositionCount() == 1; - DerivLongAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0)); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + assert maxTs.getPositionCount() == 1; + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert valueAtMaxTs.getPositionCount() == 1; + DerivLongAggregator.combineIntermediate(state, count.getLong(0), sumVal.getDouble(0), sumTs.getLong(0), sumTsVal.getDouble(0), sumTsSq.getLong(0), maxTs.getLong(0), valueAtMaxTs.getDouble(0)); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java index 259eb756cb1b2..d18df07541a83 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongAggregatorFunctionSupplier.java @@ -15,7 +15,14 @@ * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. */ public final class DerivLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { - public DerivLongAggregatorFunctionSupplier() { + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + + public DerivLongAggregatorFunctionSupplier( + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + this.fn = fn; + this.dateNanos = dateNanos; } @Override @@ -31,13 +38,13 @@ public List groupingIntermediateStateDesc() { @Override public DerivLongAggregatorFunction aggregator(DriverContext driverContext, List channels) { - return DerivLongAggregatorFunction.create(driverContext, channels); + return DerivLongAggregatorFunction.create(driverContext, channels, fn, dateNanos); } @Override public DerivLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List channels) { - return DerivLongGroupingAggregatorFunction.create(channels, driverContext); + return DerivLongGroupingAggregatorFunction.create(channels, driverContext, fn, dateNanos); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java index cb97638df4ea4..114275872105d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/DerivLongGroupingAggregatorFunction.java @@ -31,7 +31,9 @@ public final class DerivLongGroupingAggregatorFunction implements GroupingAggreg new IntermediateStateDesc("sumVal", ElementType.DOUBLE), new IntermediateStateDesc("sumTs", ElementType.LONG), new IntermediateStateDesc("sumTsVal", ElementType.DOUBLE), - new IntermediateStateDesc("sumTsSq", ElementType.LONG) ); + new IntermediateStateDesc("sumTsSq", ElementType.LONG), + new IntermediateStateDesc("maxTs", ElementType.LONG), + new IntermediateStateDesc("valueAtMaxTs", ElementType.DOUBLE) ); private final DerivDoubleAggregator.GroupingState state; @@ -39,16 +41,24 @@ public final class DerivLongGroupingAggregatorFunction implements GroupingAggreg private final DriverContext driverContext; + private final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + + private final boolean dateNanos; + public DerivLongGroupingAggregatorFunction(List channels, - DerivDoubleAggregator.GroupingState state, DriverContext driverContext) { + DerivDoubleAggregator.GroupingState state, DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { this.channels = channels; this.state = state; this.driverContext = driverContext; + this.fn = fn; + this.dateNanos = dateNanos; } public static DerivLongGroupingAggregatorFunction create(List channels, - DriverContext driverContext) { - return new DerivLongGroupingAggregatorFunction(channels, DerivLongAggregator.initGrouping(driverContext), driverContext); + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { + return new DerivLongGroupingAggregatorFunction(channels, DerivLongAggregator.initGrouping(driverContext, fn, dateNanos), driverContext, fn, dateNanos); } public static List intermediateStateDesc() { @@ -214,7 +224,17 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -224,7 +244,7 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -308,7 +328,17 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { continue; @@ -318,7 +348,7 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); int valuesPosition = groupPosition + positionOffset; - DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } } @@ -388,11 +418,21 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } LongVector sumTsSq = ((LongBlock) sumTsSqUncast).asVector(); - assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount(); + Block maxTsUncast = page.getBlock(channels.get(5)); + if (maxTsUncast.areAllValuesNull()) { + return; + } + LongVector maxTs = ((LongBlock) maxTsUncast).asVector(); + Block valueAtMaxTsUncast = page.getBlock(channels.get(6)); + if (valueAtMaxTsUncast.areAllValuesNull()) { + return; + } + DoubleVector valueAtMaxTs = ((DoubleBlock) valueAtMaxTsUncast).asVector(); + assert count.getPositionCount() == sumVal.getPositionCount() && count.getPositionCount() == sumTs.getPositionCount() && count.getPositionCount() == sumTsVal.getPositionCount() && count.getPositionCount() == sumTsSq.getPositionCount() && count.getPositionCount() == maxTs.getPositionCount() && count.getPositionCount() == valueAtMaxTs.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); int valuesPosition = groupPosition + positionOffset; - DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition)); + DerivLongAggregator.combineIntermediate(state, groupId, count.getLong(valuesPosition), sumVal.getDouble(valuesPosition), sumTs.getLong(valuesPosition), sumTsVal.getDouble(valuesPosition), sumTsSq.getLong(valuesPosition), maxTs.getLong(valuesPosition), valueAtMaxTs.getDouble(valuesPosition)); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java index e44df4d56236b..8c44c20a21e84 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivDoubleAggregator.java @@ -26,13 +26,19 @@ @IntermediateState(name = "sumVal", type = "DOUBLE"), @IntermediateState(name = "sumTs", type = "LONG"), @IntermediateState(name = "sumTsVal", type = "DOUBLE"), - @IntermediateState(name = "sumTsSq", type = "LONG") } + @IntermediateState(name = "sumTsSq", type = "LONG"), + @IntermediateState(name = "maxTs", type = "LONG"), + @IntermediateState(name = "valueAtMaxTs", type = "DOUBLE"), } ) @GroupingAggregator class DerivDoubleAggregator { - public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { - return new SimpleLinearRegressionWithTimeseries(); + public static SimpleLinearRegressionWithTimeseries initSingle( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new SimpleLinearRegressionWithTimeseries(fn, dateNanos); } public static void combine(SimpleLinearRegressionWithTimeseries current, double value, long timestamp) { @@ -45,26 +51,36 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { state.count += count; state.sumVal += sumVal; state.sumTs += sumTs; state.sumTsVal += sumTsVal; state.sumTsSq += sumTsSq; + if (state.maxTs < maxTs) { + state.maxTs = maxTs; + state.valueAtMaxTs = valueAtMaxTs; + } } public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { BlockFactory blockFactory = driverContext.blockFactory(); - var slope = state.slope(); + var slope = state.fn.predict(state); if (Double.isNaN(slope)) { return blockFactory.newConstantNullBlock(1); } return blockFactory.newConstantDoubleBlockWith(slope, 1); } - public static GroupingState initGrouping(DriverContext driverContext) { - return new GroupingState(driverContext.bigArrays()); + public static GroupingState initGrouping( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new GroupingState(driverContext.bigArrays(), fn, dateNanos); } public static void combine(GroupingState state, int groupId, double value, long timestamp) { @@ -78,9 +94,11 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { - combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq, maxTs, valueAtMaxTs); } public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, GroupingAggregatorEvaluationContext ctx) { @@ -92,7 +110,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, builder.appendNull(); continue; } - double result = slr.slope(); + double result = slr.fn.predict(slr); if (Double.isNaN(result)) { builder.appendNull(); continue; @@ -105,10 +123,14 @@ public static Block evaluateFinal(GroupingState state, IntVector selectedGroups, public static final class GroupingState extends AbstractArrayState { private ObjectArray states; + final SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn; + final boolean dateNanos; - GroupingState(BigArrays bigArrays) { + GroupingState(BigArrays bigArrays, SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, boolean dateNanos) { super(bigArrays); states = bigArrays.newObjectArray(1); + this.fn = fn; + this.dateNanos = dateNanos; } SimpleLinearRegressionWithTimeseries get(int groupId) { @@ -124,7 +146,7 @@ SimpleLinearRegressionWithTimeseries getAndGrow(int groupId) { } SimpleLinearRegressionWithTimeseries slr = states.get(groupId); if (slr == null) { - slr = new SimpleLinearRegressionWithTimeseries(); + slr = new SimpleLinearRegressionWithTimeseries(fn, dateNanos); states.set(groupId, slr); } return slr; @@ -142,7 +164,9 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive DoubleBlock.Builder sumValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); LongBlock.Builder sumTsBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); DoubleBlock.Builder sumTsValBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()); - LongBlock.Builder sumTsSqBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()) + LongBlock.Builder sumTsSqBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + LongBlock.Builder lastTsBuilder = driverContext.blockFactory().newLongBlockBuilder(selected.getPositionCount()); + DoubleBlock.Builder valueAtLastTsBuilder = driverContext.blockFactory().newDoubleBlockBuilder(selected.getPositionCount()) ) { for (int i = 0; i < selected.getPositionCount(); i++) { int groupId = selected.getInt(i); @@ -153,12 +177,16 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive sumTsBuilder.appendNull(); sumTsValBuilder.appendNull(); sumTsSqBuilder.appendNull(); + lastTsBuilder.appendNull(); + valueAtLastTsBuilder.appendNull(); } else { countBuilder.appendLong(slr.count); sumValBuilder.appendDouble(slr.sumVal); sumTsBuilder.appendLong(slr.sumTs); sumTsValBuilder.appendDouble(slr.sumTsVal); sumTsSqBuilder.appendLong(slr.sumTsSq); + lastTsBuilder.appendLong(slr.maxTs); + valueAtLastTsBuilder.appendDouble(slr.valueAtMaxTs); } } blocks[offset] = countBuilder.build(); @@ -166,6 +194,8 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset + 2] = sumTsBuilder.build(); blocks[offset + 3] = sumTsValBuilder.build(); blocks[offset + 4] = sumTsSqBuilder.build(); + blocks[offset + 5] = lastTsBuilder.build(); + blocks[offset + 6] = valueAtLastTsBuilder.build(); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java index 0b1dcc912ce45..23f5dca6f84d3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivIntAggregator.java @@ -20,13 +20,19 @@ @IntermediateState(name = "sumVal", type = "DOUBLE"), @IntermediateState(name = "sumTs", type = "LONG"), @IntermediateState(name = "sumTsVal", type = "DOUBLE"), - @IntermediateState(name = "sumTsSq", type = "LONG") } + @IntermediateState(name = "sumTsSq", type = "LONG"), + @IntermediateState(name = "maxTs", type = "LONG"), + @IntermediateState(name = "valueAtMaxTs", type = "DOUBLE"), } ) @GroupingAggregator class DerivIntAggregator { - public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { - return new SimpleLinearRegressionWithTimeseries(); + public static SimpleLinearRegressionWithTimeseries initSingle( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new SimpleLinearRegressionWithTimeseries(fn, dateNanos); } public static void combine(SimpleLinearRegressionWithTimeseries current, int value, long timestamp) { @@ -39,17 +45,23 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { - DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq, maxTs, valueAtMaxTs); } public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { return DerivDoubleAggregator.evaluateFinal(state, driverContext); } - public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { - return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + public static DerivDoubleAggregator.GroupingState initGrouping( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays(), fn, dateNanos); } public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, int value, long timestamp) { @@ -63,9 +75,11 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { - combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq, maxTs, valueAtMaxTs); } public static Block evaluateFinal( diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java index b1572be80618c..7236ecda22dff 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/DerivLongAggregator.java @@ -20,13 +20,19 @@ @IntermediateState(name = "sumVal", type = "DOUBLE"), @IntermediateState(name = "sumTs", type = "LONG"), @IntermediateState(name = "sumTsVal", type = "DOUBLE"), - @IntermediateState(name = "sumTsSq", type = "LONG") } + @IntermediateState(name = "sumTsSq", type = "LONG"), + @IntermediateState(name = "maxTs", type = "LONG"), + @IntermediateState(name = "valueAtMaxTs", type = "DOUBLE"), } ) @GroupingAggregator class DerivLongAggregator { - public static SimpleLinearRegressionWithTimeseries initSingle(DriverContext driverContext) { - return new SimpleLinearRegressionWithTimeseries(); + public static SimpleLinearRegressionWithTimeseries initSingle( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new SimpleLinearRegressionWithTimeseries(fn, dateNanos); } public static void combine(SimpleLinearRegressionWithTimeseries current, long value, long timestamp) { @@ -39,17 +45,23 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { - DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq); + DerivDoubleAggregator.combineIntermediate(state, count, sumVal, sumTs, sumTsVal, sumTsSq, maxTs, valueAtMaxTs); } public static Block evaluateFinal(SimpleLinearRegressionWithTimeseries state, DriverContext driverContext) { return DerivDoubleAggregator.evaluateFinal(state, driverContext); } - public static DerivDoubleAggregator.GroupingState initGrouping(DriverContext driverContext) { - return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays()); + public static DerivDoubleAggregator.GroupingState initGrouping( + DriverContext driverContext, + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn, + boolean dateNanos + ) { + return new DerivDoubleAggregator.GroupingState(driverContext.bigArrays(), fn, dateNanos); } public static void combine(DerivDoubleAggregator.GroupingState state, int groupId, long value, long timestamp) { @@ -63,9 +75,11 @@ public static void combineIntermediate( double sumVal, long sumTs, double sumTsVal, - long sumTsSq + long sumTsSq, + long maxTs, + double valueAtMaxTs ) { - combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq); + combineIntermediate(state.getAndGrow(groupId), count, sumVal, sumTs, sumTsVal, sumTsSq, maxTs, valueAtMaxTs); } public static Block evaluateFinal( diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java index 9ed3a5bf2b081..dcafd440746d8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SimpleLinearRegressionWithTimeseries.java @@ -9,7 +9,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.operator.DriverContext; -class SimpleLinearRegressionWithTimeseries implements AggregatorState { +public class SimpleLinearRegressionWithTimeseries implements AggregatorState { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset + 0] = driverContext.blockFactory().newConstantLongBlockWith(count, 1); @@ -29,24 +29,49 @@ public void close() { long sumTs; double sumTsVal; long sumTsSq; + long maxTs; + double valueAtMaxTs; + double dateFactor; + final SimpleLinearModelFunction fn; - SimpleLinearRegressionWithTimeseries() { + public interface SimpleLinearModelFunction { + double predict(SimpleLinearRegressionWithTimeseries model); + } + + SimpleLinearRegressionWithTimeseries(SimpleLinearModelFunction fn, boolean dateNanos) { this.count = 0; this.sumVal = 0.0; this.sumTs = 0; this.sumTsVal = 0.0; this.sumTsSq = 0; + this.maxTs = Long.MIN_VALUE; + this.valueAtMaxTs = Double.NaN; + this.dateFactor = dateNanos ? 1_000_000_000.0 : 1_000.0; + this.fn = fn; } void add(long ts, double val) { + ts = ts / (long) dateFactor; count++; sumVal += val; sumTs += ts; sumTsVal += ts * val; sumTsSq += ts * ts; + if (ts > maxTs) { + maxTs = ts; + valueAtMaxTs = val; + } + } + + public double lastTimestamp() { + return maxTs; + } + + public double valueAtLastTimestamp() { + return valueAtMaxTs; } - double slope() { + public double slope() { if (count <= 1) { return Double.NaN; } @@ -55,10 +80,10 @@ void add(long ts, double val) { if (denominator == 0) { return Double.NaN; } - return numerator / denominator * 1000.0; // per second + return numerator / denominator; } - double intercept() { + public double intercept() { if (count == 0) { return 0.0; // or handle as needed } @@ -68,5 +93,4 @@ void add(long ts, double val) { } return (sumVal - slp * sumTs) / count; } - } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec index 8bc6538447fc8..876ab11c93f8d 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-timeseries.csv-spec @@ -783,18 +783,18 @@ required_capability: TS_LINREG_DERIVATIVE TS k8s | STATS max_deriv = max(deriv(to_long(network.total_bytes_in))), max_rate = max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster -| EVAL max_deriv = ROUND(max_deriv,6), max_rate = ROUND(max_rate,6) +| EVAL max_deriv = ROUND(max_deriv,5), max_rate = ROUND(max_rate,5) | KEEP max_deriv, max_rate, time_bucket, cluster | SORT cluster, time_bucket | LIMIT 5 ; max_deriv:double | max_rate:double | time_bucket:datetime | cluster:keyword -85.5 | 8.120833 | 2024-05-10T00:00:00.000Z | prod -4.933168 | 6.451737 | 2024-05-10T00:05:00.000Z | prod -8.922491 | 11.56274 | 2024-05-10T00:10:00.000Z | prod +85.5 | 8.12083 | 2024-05-10T00:00:00.000Z | prod +4.93317 | 6.45174 | 2024-05-10T00:05:00.000Z | prod +8.92249 | 11.56274 | 2024-05-10T00:10:00.000Z | prod 16.62316 | 11.86081 | 2024-05-10T00:15:00.000Z | prod -9.026268 | 6.980661 | 2024-05-10T00:20:00.000Z | prod +9.02627 | 6.98066 | 2024-05-10T00:20:00.000Z | prod ; bare_count_over_time_outputs_dimensions @@ -984,3 +984,69 @@ TS k8s max_rate:double 13.17372515125324 ; + +predict_linear_with_long +required_capability: ts_command_v0 +required_capability: ts_linreg_predict + +TS k8s +| STATS + predicted_cost = max(round(predict_linear(to_double(network.total_cost), 10), 5)), + hand_extrap_v2=max(round(last_over_time(to_double(network.total_cost)) + deriv(to_double(network.total_cost))*10, 5)) +BY time_bucket = bucket(@timestamp,5minute), cluster +| KEEP predicted_cost, hand_extrap_v2, time_bucket, cluster +| SORT cluster, time_bucket +| LIMIT 5 +; + +predicted_cost:double | hand_extrap_v2:double | time_bucket:datetime | cluster:keyword +24.59085 | 24.59085 | 2024-05-10T00:00:00.000Z | prod +55.61696 | 55.61696 | 2024-05-10T00:05:00.000Z | prod +97.49373 | 97.49373 | 2024-05-10T00:10:00.000Z | prod +61.0174 | 61.0174 | 2024-05-10T00:15:00.000Z | prod +74.66924 | 74.66924 | 2024-05-10T00:20:00.000Z | prod + +; + +predict_linear_with_timedelta +required_capability: ts_command_v0 +required_capability: ts_linreg_predict + +TS k8s +| STATS predicted_cost = max(round(predict_linear(to_double(network.total_cost), 10 minutes), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster +| KEEP predicted_cost, time_bucket, cluster +| SORT cluster, time_bucket +| LIMIT 5 +; + +predicted_cost:double | time_bucket:datetime | cluster:keyword +592.375 | 2024-05-10T00:00:00.000Z | prod +128.89239 | 2024-05-10T00:05:00.000Z | prod +215.12355 | 2024-05-10T00:10:00.000Z | prod +150.54425 | 2024-05-10T00:15:00.000Z | prod +128.90425 | 2024-05-10T00:20:00.000Z | prod + +; + +predict_linear_with_int_and_timedelta +required_capability: ts_command_v0 +required_capability: ts_linreg_predict + +// tag::predict_linear[] +TS k8s +| STATS predicted_cost_int = MAX(ROUND(PREDICT_LINEAR(TO_LONG(network.total_bytes_in), 10), 5)) BY time_bucket = bucket(@timestamp,5minute), cluster +// end::predict_linear[] +| KEEP predicted_cost_int, time_bucket, cluster +| SORT cluster, time_bucket +| LIMIT 5 +; + +// tag::predict_linear-result[] +predicted_cost_int:double | time_bucket:datetime | cluster:keyword +3105.0 | 2024-05-10T00:00:00.000Z | prod +4218.33168 | 2024-05-10T00:05:00.000Z | prod +6029.22491 | 2024-05-10T00:10:00.000Z | prod +// end::predict_linear-result[] +9420.23165 | 2024-05-10T00:15:00.000Z | prod +10367.26268 | 2024-05-10T00:20:00.000Z | prod +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 851e1218140ad..a0a139019b08d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1518,6 +1518,7 @@ public enum Cap { PERCENTILE_OVER_TIME, VARIANCE_STDDEV_OVER_TIME, TS_LINREG_DERIVATIVE, + TS_LINREG_PREDICT, /** * INLINE STATS fix incorrect prunning of null filtering * https://github.com/elastic/elasticsearch/pull/135011 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 584191ffc46f3..2da0fbc544223 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.PercentileOverTime; +import org.elasticsearch.xpack.esql.expression.function.aggregate.PredictLinear; import org.elasticsearch.xpack.esql.expression.function.aggregate.Present; import org.elasticsearch.xpack.esql.expression.function.aggregate.PresentOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; @@ -539,6 +540,7 @@ private static FunctionDefinition[][] functions() { defTS(Delta.class, bi(Delta::new), "delta"), defTS(Increase.class, bi(Increase::new), "increase"), defTS(Deriv.class, bi(Deriv::new), "deriv"), + defTS3(PredictLinear.class, tri(PredictLinear::new), "predict_linear"), def(MaxOverTime.class, bi(MaxOverTime::new), "max_over_time"), def(MinOverTime.class, bi(MinOverTime::new), "min_over_time"), def(SumOverTime.class, bi(SumOverTime::new), "sum_over_time"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index 1163fbde777fe..818487ead7423 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -31,6 +31,7 @@ public static List getNamedWriteables() { Increase.ENTRY, Delta.ENTRY, Deriv.ENTRY, + PredictLinear.ENTRY, Sample.ENTRY, SpatialCentroid.ENTRY, SpatialExtent.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java index 431f9e1c4a056..c8adde0f50d00 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Deriv.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.DerivLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SimpleLinearRegressionWithTimeseries; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -19,6 +20,8 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; import org.elasticsearch.xpack.esql.expression.function.Param; @@ -42,6 +45,8 @@ public class Deriv extends TimeSeriesAggregateFunction implements ToAggregator, type = FunctionType.TIME_SERIES_AGGREGATE, returnType = { "double" }, description = "Calculates the derivative over time of a numeric field using linear regression.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW, version = "9.3.0") }, + preview = true, examples = { @Example(file = "k8s-timeseries", tag = "deriv") } ) public Deriv(Source source, @Param(name = "field", type = { "long", "integer", "double" }) Expression field, Expression timestamp) { @@ -112,10 +117,12 @@ public String getWriteableName() { @Override public AggregatorFunctionSupplier supplier() { final DataType type = field().dataType(); + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn = (SimpleLinearRegressionWithTimeseries model) -> model.slope(); + final boolean isDateNanos = timestamp.dataType() == DataType.DATE_NANOS; return switch (type) { - case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(); - case LONG -> new DerivLongAggregatorFunctionSupplier(); - case INTEGER -> new DerivIntAggregatorFunctionSupplier(); + case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(fn, isDateNanos); + case LONG -> new DerivLongAggregatorFunctionSupplier(fn, isDateNanos); + case INTEGER -> new DerivIntAggregatorFunctionSupplier(fn, isDateNanos); default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinear.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinear.java new file mode 100644 index 0000000000000..e4a706fa1cf27 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinear.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.DerivLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SimpleLinearRegressionWithTimeseries; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.FunctionType; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.TimestampAware; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.ToAggregator; +import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter; + +import java.time.Duration; +import java.util.List; +import java.util.Objects; + +/** + * Calculates the derivative over time of a numeric field using linear regression. + */ +public class PredictLinear extends TimeSeriesAggregateFunction implements ToAggregator, TimestampAware { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "PredictLinear", + PredictLinear::new + ); + private final Expression timestamp; + private final Expression t; + + @FunctionInfo( + type = FunctionType.TIME_SERIES_AGGREGATE, + returnType = { "double" }, + description = "Predicts the value of a time series at `t` seconds in the future.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW, version = "9.3.0") }, + preview = true, + examples = { @Example(file = "k8s-timeseries", tag = "predict_linear") } + ) + public PredictLinear( + Source source, + @Param( + name = "field", + type = { "long", "integer", "double" }, + description = "the expression to use for the prediction" + ) Expression field, + @Param( + name = "t", + type = { "long", "integer", "time_duration", "double" }, + description = "how long in the fututre to predict in seconds for numeric, or in time delta" + ) Expression t, + Expression timestamp + ) { + this(source, field, Literal.TRUE, NO_WINDOW, timestamp, t); + } + + public PredictLinear(Source source, Expression field, Expression filter, Expression window, Expression ts, Expression t) { + super(source, field, filter, window, List.of(ts, t)); + this.timestamp = ts; + this.t = t; + } + + private PredictLinear(org.elasticsearch.common.io.stream.StreamInput in) throws java.io.IOException { + super( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteableCollectionAsList(Expression.class) + ); + this.t = children().get(4); + this.timestamp = children().get(3); + } + + @Override + public Expression timestamp() { + return timestamp; + } + + @Override + public AggregateFunction perTimeSeriesAggregation() { + return this; + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new PredictLinear(source(), field(), filter, window(), timestamp, t); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new PredictLinear( + source(), + newChildren.get(0), + newChildren.get(1), + newChildren.get(2), + newChildren.get(3), + newChildren.get(4) + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, PredictLinear::new, field(), filter(), window(), timestamp, t); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public AggregatorFunctionSupplier supplier() { + if (t.foldable() == false) { + throw new IllegalArgumentException("The 't' parameter of the 'predict_linear' function must be a constant value."); + } + final double timeDiffSeconds = switch (t.fold(FoldContext.small())) { + case Duration d -> d.toMillis() / 1000.0; + case Long l -> l.doubleValue(); + case Integer i -> i.doubleValue(); + case String s -> Duration.from(Objects.requireNonNull(EsqlDataTypeConverter.parseTemporalAmount(s, DataType.TIME_DURATION))) + .toMillis() / 1000.0; + default -> throw new IllegalArgumentException( + "The 't' parameter of the 'predict_linear' function must be of type long, integer, keyword, or timedelta. It was of type: " + + t.dataType() + ); + }; + + final DataType type = field().dataType(); + final boolean isDateNanos = timestamp.dataType() == DataType.DATE_NANOS; + SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction fn = getSimpleLinearModelFunction(timeDiffSeconds); + return switch (type) { + case LONG -> new DerivLongAggregatorFunctionSupplier(fn, isDateNanos); + case INTEGER -> new DerivIntAggregatorFunctionSupplier(fn, isDateNanos); + case DOUBLE -> new DerivDoubleAggregatorFunctionSupplier(fn, isDateNanos); + default -> throw new IllegalArgumentException("Unsupported data type for deriv aggregation: " + type); + }; + } + + private static SimpleLinearRegressionWithTimeseries.SimpleLinearModelFunction getSimpleLinearModelFunction(double timeDiffSeconds) { + return (SimpleLinearRegressionWithTimeseries model) -> { + double slope = model.slope(); + if (Double.isNaN(slope)) { + return Double.NaN; + } + return model.valueAtLastTimestamp() + slope * timeDiffSeconds; + }; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinearTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinearTests.java new file mode 100644 index 0000000000000..ff1fda418ddf4 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PredictLinearTests.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.DocsV3Support; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class PredictLinearTests extends AbstractFunctionTestCase { + public PredictLinearTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + var valuesSuppliers = List.of( + MultiRowTestCaseSupplier.longCases(1, 1000, 0, 1000_000_000, true), + MultiRowTestCaseSupplier.intCases(1, 1000, 0, 1000_000_000, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, 0, 1000_000_000, true) + ); + for (List valuesSupplier : valuesSuppliers) { + for (TestCaseSupplier.TypedDataSupplier fieldSupplier : valuesSupplier) { + List testCaseSuppliers = makeSuppliers(fieldSupplier); + suppliers.addAll(testCaseSuppliers); + } + } + List parameters = new ArrayList<>(suppliers.size()); + for (TestCaseSupplier supplier : suppliers) { + parameters.add(new Object[] { supplier }); + } + return parameters; + } + + @Override + protected Expression build(Source source, List args) { + return new PredictLinear(source, args.get(0), Literal.TRUE, AggregateFunction.NO_WINDOW, args.get(1), args.get(2)); + } + + @SuppressWarnings("unchecked") + private static List makeSuppliers(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + DataType type = fieldSupplier.type(); + return List.of(DataType.LONG, DataType.TIME_DURATION, DataType.DOUBLE, DataType.INTEGER) + .stream() + .map(tType -> new TestCaseSupplier(fieldSupplier.name(), List.of(type, DataType.DATETIME, tType), () -> { + TestCaseSupplier.TypedData fieldTypedData = fieldSupplier.get(); + List dataRows = fieldTypedData.multiRowData(); + if (randomBoolean()) { + List withNulls = new ArrayList<>(dataRows.size()); + for (Object dataRow : dataRows) { + if (randomBoolean()) { + withNulls.add(null); + } else { + withNulls.add(dataRow); + } + } + dataRows = withNulls; + } + fieldTypedData = TestCaseSupplier.TypedData.multiRow(dataRows, type, fieldTypedData.name()); + List timestamps = new ArrayList<>(); + List slices = new ArrayList<>(); + List maxTimestamps = new ArrayList<>(); + long lastTimestamp = randomLongBetween(0, 1_000_000); + for (int row = 0; row < dataRows.size(); row++) { + lastTimestamp += randomLongBetween(1, 10_000); + timestamps.add(lastTimestamp); + slices.add(0); + maxTimestamps.add(Long.MAX_VALUE); + } + TestCaseSupplier.TypedData timestampsField = TestCaseSupplier.TypedData.multiRow( + timestamps.reversed(), + DataType.DATETIME, + "timestamps" + ); + TestCaseSupplier.TypedData theTField = TestCaseSupplier.TypedData.multiRow(maxTimestamps, tType, "some_t_field"); + + List nonNullDataRows = dataRows.stream().filter(Objects::nonNull).toList(); + Matcher matcher; + if (nonNullDataRows.size() < 2) { + matcher = Matchers.nullValue(); + } else { + var lastValue = ((Number) nonNullDataRows.getFirst()).doubleValue(); + var secondLastValue = ((Number) nonNullDataRows.get(1)).doubleValue(); + var increase = lastValue >= secondLastValue ? lastValue - secondLastValue : lastValue; + var largestTimestamp = timestamps.get(0); + var secondLargestTimestamp = timestamps.get(1); + var smallestTimestamp = timestamps.getLast(); + matcher = Matchers.allOf( + Matchers.greaterThanOrEqualTo(increase / (largestTimestamp - smallestTimestamp) * 1000 * 0.9), + Matchers.lessThanOrEqualTo( + increase / (largestTimestamp - secondLargestTimestamp) * (largestTimestamp - smallestTimestamp) * 1000 + ) + ); + } + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData, timestampsField, theTField), + Matchers.stringContainsInOrder("GroupingAggregator", "PredictLinear", "GroupingAggregatorFunction"), + DataType.DOUBLE, + matcher + ); + })) + .toList(); + } + + public static List signatureTypes(List params) { + assertThat(params, hasSize(3)); + assertThat(params.get(1).dataType(), equalTo(DataType.DATETIME)); + assertThat( + params.get(2).dataType(), + Matchers.in(List.of(DataType.LONG, DataType.TIME_DURATION, DataType.DOUBLE, DataType.INTEGER)) + ); + return List.of(params.get(0), params.get(2)); + } +}