diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/decay.md b/docs/reference/query-languages/esql/_snippets/functions/description/decay.md new file mode 100644 index 0000000000000..9d6f305d5661f --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/decay.md @@ -0,0 +1,16 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Calculates a relevance score that decays based on the distance of a numeric, spatial or date type value from a target origin, using configurable decay functions. + +`DECAY` calculates a score between 0 and 1 based on how far a field value is from a specified origin point (called distance). +The distance can be a numeric distance, spatial distance or temporal distance depending on the specific data type. + +`DECAY` can use [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params) to specify additional `options` +for the decay function. + +For spatial queries, scale and offset for geo points use distance units (e.g., "10km", "5mi"), +while cartesian points use numeric values. For date queries, scale and offset use time_duration values. +For numeric queries you also use numeric values. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/decay.md b/docs/reference/query-languages/esql/_snippets/functions/examples/decay.md new file mode 100644 index 0000000000000..b3efac50b7ae7 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/decay.md @@ -0,0 +1,9 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +null +``` + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/decay.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/decay.md new file mode 100644 index 0000000000000..8d72318865141 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/decay.md @@ -0,0 +1,13 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported function named parameters** + +`offset` +: (double, integer, long, time_duration, keyword, text) Distance from the origin where no decay occurs. + +`type` +: (keyword) Decay function to use: linear, exponential or gaussian. + +`decay` +: (double) Multiplier value returned at the scale distance from the origin. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/decay.md b/docs/reference/query-languages/esql/_snippets/functions/layout/decay.md new file mode 100644 index 0000000000000..2e43297523822 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/decay.md @@ -0,0 +1,30 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `DECAY` [esql-decay] +```{applies_to} +stack: preview 9.2.0 +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/decay.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/decay.md +::: + +:::{include} ../description/decay.md +::: + +:::{include} ../types/decay.md +::: + +:::{include} ../functionNamedParams/decay.md +::: + +:::{include} ../examples/decay.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/decay.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/decay.md new file mode 100644 index 0000000000000..dae2fea0fb426 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/decay.md @@ -0,0 +1,16 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`value` +: The input value to apply decay scoring to. + +`origin` +: Central point from which the distances are calculated. + +`scale` +: Distance from the origin where the function returns the decay value. + +`options` +: + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/decay.md b/docs/reference/query-languages/esql/_snippets/functions/types/decay.md new file mode 100644 index 0000000000000..2b64fee072ddc --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/decay.md @@ -0,0 +1,15 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported types** + +| value | origin | scale | options | result | +| --- | --- | --- | --- | --- | +| cartesian_point | cartesian_point | double | named parameters | double | +| date | date | time_duration | named parameters | double | +| date_nanos | date_nanos | time_duration | named parameters | double | +| double | double | double | named parameters | double | +| geo_point | geo_point | keyword | named parameters | double | +| geo_point | geo_point | text | named parameters | double | +| integer | integer | integer | named parameters | double | +| long | long | long | named parameters | double | + diff --git a/docs/reference/query-languages/esql/_snippets/lists/search-functions.md b/docs/reference/query-languages/esql/_snippets/lists/search-functions.md index 76b0929065a13..71c9ec005985c 100644 --- a/docs/reference/query-languages/esql/_snippets/lists/search-functions.md +++ b/docs/reference/query-languages/esql/_snippets/lists/search-functions.md @@ -2,4 +2,5 @@ * [`MATCH`](../../functions-operators/search-functions.md#esql-match) * [`MATCH_PHRASE`](../../functions-operators/search-functions.md#esql-match_phrase) * [`QSTR`](../../functions-operators/search-functions.md#esql-qstr) +% * [preview] [`DECAY`](../../functions-operators/search-functions.md#esql-decay) % * [preview] [`TERM`](../../functions-operators/search-functions.md#esql-term) diff --git a/docs/reference/query-languages/esql/functions-operators/search-functions.md b/docs/reference/query-languages/esql/functions-operators/search-functions.md index 597f61cfc5003..d72b30d0efdb1 100644 --- a/docs/reference/query-languages/esql/functions-operators/search-functions.md +++ b/docs/reference/query-languages/esql/functions-operators/search-functions.md @@ -13,7 +13,7 @@ our [hands-on tutorial](/reference/query-languages/esql/esql-search-tutorial.md) For a high-level overview of search functionalities in {{esql}}, and to learn about relevance scoring, refer to [{{esql}} for search](docs-content://solutions/search/esql-for-search.md#esql-for-search-scoring). ::: -{{esql}} provides a set of functions for performing searching on text fields. +{{esql}} provides a set of functions for performing searching on text fields. Use these functions for [full-text search](docs-content://solutions/search/full-text.md) @@ -36,6 +36,7 @@ for information on the limitations of full text search. :::{include} ../_snippets/lists/search-functions.md ::: + :::{include} ../_snippets/functions/layout/kql.md ::: @@ -54,3 +55,8 @@ lists/search-functions.md % :::{include} ../_snippets/functions/layout/term.md % ::: +% DECAY is currently a hidden feature +% To make it visible again, uncomment this and the line in +lists/search-functions.md +% :::{include} ../_snippets/functions/layout/decay.md +% ::: diff --git a/docs/reference/query-languages/esql/images/functions/decay.svg b/docs/reference/query-languages/esql/images/functions/decay.svg new file mode 100644 index 0000000000000..176ef68b5b730 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/decay.svg @@ -0,0 +1 @@ +DECAY(value,origin,scale,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/decay.json b/docs/reference/query-languages/esql/kibana/definition/functions/decay.json new file mode 100644 index 0000000000000..56ca96d77d071 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/decay.json @@ -0,0 +1,261 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "decay", + "description" : "Calculates a relevance score that decays based on the distance of a numeric, spatial or date type value from a target origin, using configurable decay functions.", + "signatures" : [ + { + "params" : [ + { + "name" : "value", + "type" : "cartesian_point", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "cartesian_point", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "double", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "date", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "date", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "time_duration", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "date_nanos", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "date_nanos", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "time_duration", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "double", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "double", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "double", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "geo_point", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "geo_point", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "keyword", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "geo_point", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "geo_point", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "text", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "integer", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "integer", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "integer", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "value", + "type" : "long", + "optional" : false, + "description" : "The input value to apply decay scoring to." + }, + { + "name" : "origin", + "type" : "long", + "optional" : false, + "description" : "Central point from which the distances are calculated." + }, + { + "name" : "scale", + "type" : "long", + "optional" : false, + "description" : "Distance from the origin where the function returns the decay value." + }, + { + "name" : "options", + "type" : "function_named_parameters", + "mapParams" : "{name='offset', values=[], description='Distance from the origin where no decay occurs.'}, {name='type', values=[], description='Decay function to use: linear, exponential or gaussian.'}, {name='decay', values=[], description='Multiplier value returned at the scale distance from the origin.'}", + "optional" : true, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + null + ], + "preview" : true, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/decay.md b/docs/reference/query-languages/esql/kibana/docs/functions/decay.md new file mode 100644 index 0000000000000..1f1550cbf1d9e --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/decay.md @@ -0,0 +1,8 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### DECAY +Calculates a relevance score that decays based on the distance of a numeric, spatial or date type value from a target origin, using configurable decay functions. + +```esql +null +``` diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java index d6f74144a9717..0bccb5080ef56 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java @@ -22,13 +22,17 @@ import org.elasticsearch.xpack.versionfield.Version; import java.io.IOException; +import java.time.Duration; import java.util.Collection; import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT; import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION; import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN; @@ -208,6 +212,26 @@ public static Literal keyword(Source source, String literal) { return new Literal(source, BytesRefs.toBytesRef(literal), KEYWORD); } + public static Literal text(Source source, String literal) { + return new Literal(source, BytesRefs.toBytesRef(literal), TEXT); + } + + public static Literal timeDuration(Source source, Duration literal) { + return new Literal(source, literal, DataType.TIME_DURATION); + } + + public static Literal integer(Source source, Integer literal) { + return new Literal(source, literal, INTEGER); + } + + public static Literal fromDouble(Source source, Double literal) { + return new Literal(source, literal, DOUBLE); + } + + public static Literal fromLong(Source source, Long literal) { + return new Literal(source, literal, LONG); + } + /** * Not all literal values are currently supported in StreamInput/StreamOutput as generic values. * This mapper allows for addition of new and interesting values without (yet) adding to StreamInput/Output. diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index 3c36884874454..528f9ac2f57ea 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -521,6 +521,14 @@ public static boolean isDateTime(DataType type) { return type == DATETIME; } + public static boolean isTimeDuration(DataType t) { + return t == TIME_DURATION; + } + + public static boolean isDateNanos(DataType t) { + return t == DATE_NANOS; + } + public static boolean isNullOrTimeDuration(DataType t) { return t == TIME_DURATION || isNull(t); } @@ -580,7 +588,15 @@ public static boolean isCounter(DataType t) { } public static boolean isSpatialPoint(DataType t) { - return t == GEO_POINT || t == CARTESIAN_POINT; + return isGeoPoint(t) || isCartesianPoint(t); + } + + public static boolean isGeoPoint(DataType t) { + return t == GEO_POINT; + } + + public static boolean isCartesianPoint(DataType t) { + return t == CARTESIAN_POINT; } public static boolean isSpatialShape(DataType t) { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/decay.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/decay.csv-spec new file mode 100644 index 0000000000000..313404d5a33af --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/decay.csv-spec @@ -0,0 +1,285 @@ +############################################### +# Tests for DecayLinear function +# + +intLinear +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = decay(value, 10, 10, {"offset": 0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +intLinear2 +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = decay(value, 5 + 5, 5 + 5, {"offset": 0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +intExp +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = round(decay(value, 10, 10, {"offset": 0, "decay": 0.5, "type": "exp"}), 7) +| KEEP decay_result; + +decay_result:double +0.7071068 +; + +intGauss +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = round(decay(value, 10, 10, {"offset": 0, "decay": 0.5, "type": "gauss"}), 7) +| KEEP decay_result; + +decay_result:double +0.8408964 +; + +intLinearWithOffset +required_capability: decay_function + +ROW value = 95 +| EVAL decay_result = decay(value, 100, 50, {"offset": 10, "decay": 0.3, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +1.0 +; + +intExpWithOffset +required_capability: decay_function + +ROW value = 120 +| EVAL decay_result = round(decay(value, 100, 50, {"offset": 5, "decay": 0.3, "type": "exp"}), 7) +| KEEP decay_result; + +decay_result:double +0.6968453 +; + +intGaussWithOffset +required_capability: decay_function + +ROW value = 120 +| EVAL decay_result = round(decay(value, 100, 50, {"offset": 5, "decay": 0.3, "type": "gauss"}), 7) +| KEEP decay_result; + +decay_result:double +0.8973067 +; + +intWithoutOptions +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = decay(value, 10, 10) +| KEEP decay_result; + +decay_result:double +0.75 +; + +intOnlyWithOffset +required_capability: decay_function + +ROW value = 5 +| EVAL decay_result = decay(value, 10, 10, {"offset": 100}) +| KEEP decay_result; + +decay_result:double +1.0 +; + +intMultipleRows +required_capability: decay_function + +FROM employees +| EVAL decay_result = decay(salary, 0, 100000, {"offset": 5, "decay": 0.5, "type": "linear"}) +| KEEP decay_result +| SORT decay_result DESC +| LIMIT 5; + +decay_result:double +0.873405 +0.8703 +0.870145 +0.867845 +0.86395 +; + +intOriginReference +required_capability: decay_function + +ROW value = 5, origin = 10 +| EVAL decay_result = decay(value, origin, 10, {"offset": 0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +intScaleReference +required_capability: decay_function + +ROW value = 5, scale = 10 +| EVAL decay_result = decay(value, 10, scale, {"offset": 0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +intScaleAndOriginReference +required_capability: decay_function + +ROW value = 5, origin = 10, scale = 10 +| EVAL decay_result = decay(value, origin, scale, {"offset": 0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +doubleLinear +required_capability: decay_function + +ROW value = 5.0 +| EVAL decay_result = decay(value, 10.0, 10.0, {"offset": 0.0, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +doubleExp +required_capability: decay_function + +ROW value = 5.0 +| EVAL decay_result = round(decay(value, 10.0, 10.0, {"offset": 0.0, "decay": 0.5, "type": "exp"}), 7) +| KEEP decay_result; + +decay_result:double +0.7071068 +; + +doubleGauss +required_capability: decay_function + +ROW value = 5.0 +| EVAL decay_result = round(decay(value, 10.0, 10.0, {"offset": 0.0, "decay": 0.5, "type": "gauss"}), 7) +| KEEP decay_result; + +decay_result:double +0.8408964 +; + +longLinear +required_capability: decay_function + +ROW value = 15::long +| EVAL decay_result = decay(value, 10::long, 10::long, {"offset": 10000000000, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +1.0 +; + +cartesianPointLinear1 +required_capability: decay_function + +ROW value = TO_CARTESIANPOINT("POINT(5 5)") +| EVAL decay_result = decay(value, TO_CARTESIANPOINT("POINT(0 0)"), 10.0, {"offset": 0.0, "decay": 0.25, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.46966991411008935 +; + +cartesianPointLinear2 +required_capability: decay_function + +ROW value = TO_CARTESIANPOINT("POINT(10 0)") +| EVAL decay_result = ROUND(decay(value, TO_CARTESIANPOINT("POINT(0 0)"), 10.0, {"offset": 0.0, "decay": 0.25, "type": "linear"}), 7) +| KEEP decay_result; + +decay_result:double +0.25 +; + +cartesianPointLinearWithOffset +required_capability: decay_function + +ROW value = TO_CARTESIANPOINT("POINT(10 0)") +| EVAL decay_result = ROUND(decay(value, TO_CARTESIANPOINT("POINT(0 0)"), 10.0, {"offset": 5.0, "decay": 0.25, "type": "linear"}), 7) +| KEEP decay_result; + +decay_result:double +0.625 +; + + +geoPointLinear +required_capability: decay_function + +ROW value = TO_GEOPOINT("POINT(0 0)") +| EVAL decay_result = decay(value, TO_GEOPOINT("POINT(1 1)"), "200km", {"offset": "0km", "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.606876005579706 +; + +datetimeLinear1 +required_capability: decay_function + +ROW value = TO_DATETIME("2023-01-01T00:00:00Z") +| EVAL decay_result = decay(value, TO_DATETIME("2023-01-01T00:00:00Z"), 24 hours, {"offset": 0 seconds, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +1.0 +; + +datetimeLinear2 +required_capability: decay_function + +ROW value = TO_DATETIME("2023-01-01T12:00:00Z") +| EVAL decay_result = decay(value, TO_DATETIME("2023-01-01T00:00:00Z"), 24 hours, {"offset": 0 seconds, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; + +dateNanosLinear1 +required_capability: decay_function + +ROW value = TO_DATE_NANOS("2023-01-01T00:00:00Z") +| EVAL decay_result = decay(value, TO_DATE_NANOS("2023-01-01T00:00:00Z"), 24 hours, {"offset": 0 seconds, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +1.0 +; + +dateNanosLinear2 +required_capability: decay_function + +ROW value = TO_DATE_NANOS("2023-01-01T12:00:00Z") +| EVAL decay_result = decay(value, TO_DATE_NANOS("2023-01-01T00:00:00Z"), 24 hours, {"offset": 0 seconds, "decay": 0.5, "type": "linear"}) +| KEEP decay_result; + +decay_result:double +0.75 +; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayCartesianPointEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayCartesianPointEvaluator.java new file mode 100644 index 0000000000000..25c91b23ebf2f --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayCartesianPointEvaluator.java @@ -0,0 +1,169 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayCartesianPointEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayCartesianPointEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final BytesRef origin; + + private final double scale; + + private final double offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayCartesianPointEvaluator(Source source, EvalOperator.ExpressionEvaluator value, + BytesRef origin, double scale, double offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (BytesRefBlock valueBlock = (BytesRefBlock) value.eval(page)) { + BytesRefVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector).asBlock(); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, BytesRefBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + BytesRef valueScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendDouble(Decay.processCartesianPoint(valueBlock.getBytesRef(valueBlock.getFirstValueIndex(p), valueScratch), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + public DoubleVector eval(int positionCount, BytesRefVector valueVector) { + try(DoubleVector.FixedBuilder result = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + BytesRef valueScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + result.appendDouble(p, Decay.processCartesianPoint(valueVector.getBytesRef(p, valueScratch), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayCartesianPointEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final BytesRef origin; + + private final double scale; + + private final double offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, BytesRef origin, + double scale, double offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayCartesianPointEvaluator get(DriverContext context) { + return new DecayCartesianPointEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayCartesianPointEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDateNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDateNanosEvaluator.java new file mode 100644 index 0000000000000..223e11c6150aa --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDateNanosEvaluator.java @@ -0,0 +1,176 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayDateNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayDateNanosEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayDateNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock valueBlock = (LongBlock) value.eval(page)) { + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, LongBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + try { + result.appendDouble(Decay.processDateNanos(valueBlock.getLong(valueBlock.getFirstValueIndex(p)), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } catch (InvalidArgumentException | IllegalArgumentException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + public DoubleBlock eval(int positionCount, LongVector valueVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + try { + result.appendDouble(Decay.processDateNanos(valueVector.getLong(p), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } catch (InvalidArgumentException | IllegalArgumentException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayDateNanosEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayDateNanosEvaluator get(DriverContext context) { + return new DecayDateNanosEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayDateNanosEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDatetimeEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDatetimeEvaluator.java new file mode 100644 index 0000000000000..f5618ec2f95e8 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDatetimeEvaluator.java @@ -0,0 +1,176 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayDatetimeEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayDatetimeEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayDatetimeEvaluator(Source source, EvalOperator.ExpressionEvaluator value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock valueBlock = (LongBlock) value.eval(page)) { + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, LongBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + try { + result.appendDouble(Decay.processDatetime(valueBlock.getLong(valueBlock.getFirstValueIndex(p)), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } catch (InvalidArgumentException | IllegalArgumentException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + public DoubleBlock eval(int positionCount, LongVector valueVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + try { + result.appendDouble(Decay.processDatetime(valueVector.getLong(p), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } catch (InvalidArgumentException | IllegalArgumentException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayDatetimeEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayDatetimeEvaluator get(DriverContext context) { + return new DecayDatetimeEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayDatetimeEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDoubleEvaluator.java new file mode 100644 index 0000000000000..a9fe2cbe0c416 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayDoubleEvaluator.java @@ -0,0 +1,164 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayDoubleEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayDoubleEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final double origin; + + private final double scale; + + private final double offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator value, double origin, + double scale, double offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (DoubleBlock valueBlock = (DoubleBlock) value.eval(page)) { + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector).asBlock(); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, DoubleBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendDouble(Decay.process(valueBlock.getDouble(valueBlock.getFirstValueIndex(p)), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + public DoubleVector eval(int positionCount, DoubleVector valueVector) { + try(DoubleVector.FixedBuilder result = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendDouble(p, Decay.process(valueVector.getDouble(p), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayDoubleEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final double origin; + + private final double scale; + + private final double offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, double origin, + double scale, double offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayDoubleEvaluator get(DriverContext context) { + return new DecayDoubleEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayDoubleEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayGeoPointEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayGeoPointEvaluator.java new file mode 100644 index 0000000000000..139d58d6d7e37 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayGeoPointEvaluator.java @@ -0,0 +1,169 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayGeoPointEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayGeoPointEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final BytesRef origin; + + private final BytesRef scale; + + private final BytesRef offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayGeoPointEvaluator(Source source, EvalOperator.ExpressionEvaluator value, + BytesRef origin, BytesRef scale, BytesRef offset, double decay, + Decay.DecayFunction decayFunction, DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (BytesRefBlock valueBlock = (BytesRefBlock) value.eval(page)) { + BytesRefVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector).asBlock(); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, BytesRefBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + BytesRef valueScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendDouble(Decay.process(valueBlock.getBytesRef(valueBlock.getFirstValueIndex(p), valueScratch), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + public DoubleVector eval(int positionCount, BytesRefVector valueVector) { + try(DoubleVector.FixedBuilder result = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + BytesRef valueScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + result.appendDouble(p, Decay.process(valueVector.getBytesRef(p, valueScratch), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayGeoPointEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final BytesRef origin; + + private final BytesRef scale; + + private final BytesRef offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, BytesRef origin, + BytesRef scale, BytesRef offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayGeoPointEvaluator get(DriverContext context) { + return new DecayGeoPointEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayGeoPointEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayIntEvaluator.java new file mode 100644 index 0000000000000..0a28fc15c0e2a --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayIntEvaluator.java @@ -0,0 +1,166 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayIntEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayIntEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final int origin; + + private final int scale; + + private final int offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayIntEvaluator(Source source, EvalOperator.ExpressionEvaluator value, int origin, + int scale, int offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (IntBlock valueBlock = (IntBlock) value.eval(page)) { + IntVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector).asBlock(); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, IntBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendDouble(Decay.process(valueBlock.getInt(valueBlock.getFirstValueIndex(p)), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + public DoubleVector eval(int positionCount, IntVector valueVector) { + try(DoubleVector.FixedBuilder result = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendDouble(p, Decay.process(valueVector.getInt(p), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayIntEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final int origin; + + private final int scale; + + private final int offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, int origin, + int scale, int offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayIntEvaluator get(DriverContext context) { + return new DecayIntEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayIntEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayLongEvaluator.java new file mode 100644 index 0000000000000..3cc9b4da8f7a1 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayLongEvaluator.java @@ -0,0 +1,166 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Decay}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class DecayLongEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DecayLongEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + private final DriverContext driverContext; + + private Warnings warnings; + + public DecayLongEvaluator(Source source, EvalOperator.ExpressionEvaluator value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction, + DriverContext driverContext) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock valueBlock = (LongBlock) value.eval(page)) { + LongVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector).asBlock(); + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += value.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, LongBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendDouble(Decay.process(valueBlock.getLong(valueBlock.getFirstValueIndex(p)), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + public DoubleVector eval(int positionCount, LongVector valueVector) { + try(DoubleVector.FixedBuilder result = driverContext.blockFactory().newDoubleVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendDouble(p, Decay.process(valueVector.getLong(p), this.origin, this.scale, this.offset, this.decay, this.decayFunction)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "DecayLongEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final long origin; + + private final long scale; + + private final long offset; + + private final double decay; + + private final Decay.DecayFunction decayFunction; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, long origin, + long scale, long offset, double decay, Decay.DecayFunction decayFunction) { + this.source = source; + this.value = value; + this.origin = origin; + this.scale = scale; + this.offset = offset; + this.decay = decay; + this.decayFunction = decayFunction; + } + + @Override + public DecayLongEvaluator get(DriverContext context) { + return new DecayLongEvaluator(source, value.get(context), origin, scale, offset, decay, decayFunction, context); + } + + @Override + public String toString() { + return "DecayLongEvaluator[" + "value=" + value + ", origin=" + origin + ", scale=" + scale + ", offset=" + offset + ", decay=" + decay + ", decayFunction=" + decayFunction + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 36ee0536e393b..57bc9382fc446 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1388,6 +1388,11 @@ public enum Cap { */ CATEGORIZE_OPTIONS, + /** + * Decay function for custom scoring + */ + DECAY_FUNCTION(Build.current().isSnapshot()), + /** * FIRST and LAST aggregate functions. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index f4d20dcafd1a0..08c17be59ba8c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -145,6 +145,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvZip; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.expression.function.scalar.score.Decay; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialContains; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialDisjoint; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects; @@ -261,7 +262,7 @@ public class EsqlFunctionRegistry { } // Translation table for error messaging in the following function - private static final String[] NUM_NAMES = { "zero", "one", "two", "three", "four", "five", }; + private static final String[] NUM_NAMES = { "zero", "one", "two", "three", "four", "five", "six" }; // list of functions grouped by type of functions (aggregate, statistics, math etc) and ordered alphabetically inside each group // a single function will have one entry for itself with its name associated to its instance and, also, one entry for each alias @@ -478,6 +479,7 @@ private static FunctionDefinition[][] functions() { def(Split.class, Split::new, "split") }, // fulltext functions new FunctionDefinition[] { + def(Decay.class, quad(Decay::new), "decay"), def(Kql.class, uni(Kql::new), "kql"), def(Match.class, tri(Match::new), "match"), def(MultiMatch.class, MultiMatch::new, "multi_match"), @@ -987,7 +989,6 @@ public static FunctionDefinition def(Class function, Bin Strings.format("function %s expects exactly two arguments, it received %d", Arrays.toString(names), children.size()) ); } - return ctorRef.build(source, children.get(0), children.size() == 2 ? children.get(1) : null); }; return def(function, builder, names); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java index a4df48834eb27..113c40166eace 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Options.java @@ -18,8 +18,10 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; +import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Consumer; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; @@ -35,7 +37,13 @@ public static Expression.TypeResolution resolve( TypeResolutions.ParamOrdinal paramOrdinal, Map allowedOptions ) { - return resolve(options, source, paramOrdinal, allowedOptions, null); + return resolve( + options, + source, + paramOrdinal, + null, + (opts, optsMap) -> populateMap(opts, optsMap, source, paramOrdinal, allowedOptions) + ); } public static Expression.TypeResolution resolve( @@ -44,6 +52,37 @@ public static Expression.TypeResolution resolve( TypeResolutions.ParamOrdinal paramOrdinal, Map allowedOptions, Consumer> verifyOptions + ) { + return resolve( + options, + source, + paramOrdinal, + verifyOptions, + (opts, optsMap) -> populateMap(opts, optsMap, source, paramOrdinal, allowedOptions) + ); + } + + public static Expression.TypeResolution resolveWithMultipleDataTypesAllowed( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Map> allowedOptions + ) { + return resolve( + options, + source, + paramOrdinal, + null, + (opts, optsMap) -> populateMapWithExpressionsMultipleDataTypesAllowed(opts, optsMap, source, paramOrdinal, allowedOptions) + ); + } + + private static Expression.TypeResolution resolve( + Expression options, + Source source, + TypeResolutions.ParamOrdinal paramOrdinal, + Consumer> verifyOptions, + BiConsumer> populateMap ) { if (options != null) { Expression.TypeResolution resolution = isNotNull(options, source.text(), paramOrdinal); @@ -57,7 +96,7 @@ public static Expression.TypeResolution resolve( } try { Map optionsMap = new HashMap<>(); - populateMap((MapExpression) options, optionsMap, source, paramOrdinal, allowedOptions); + populateMap.accept((MapExpression) options, optionsMap); if (verifyOptions != null) { verifyOptions.accept(optionsMap); } @@ -112,4 +151,54 @@ public static void populateMap( } } } + + public static void populateMapWithExpressionsMultipleDataTypesAllowed( + final MapExpression options, + final Map optionsMap, + final Source source, + final TypeResolutions.ParamOrdinal paramOrdinal, + final Map> allowedOptions + ) throws InvalidArgumentException { + if (options == null) { + return; + } + + for (EntryExpression entry : options.entryExpressions()) { + Expression optionExpr = entry.key(); + Expression valueExpr = entry.value(); + + Expression.TypeResolution optionNameResolution = isFoldable(optionExpr, source.text(), paramOrdinal); + if (optionNameResolution.unresolved()) { + throw new InvalidArgumentException(optionNameResolution.message()); + } + + Object optionExprLiteral = ((Literal) optionExpr).value(); + String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); + Collection allowedDataTypes = allowedOptions.get(optionName); + + // valueExpr could be a MapExpression, but for now functions only accept literal values in options + if ((valueExpr instanceof Literal) == false) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], expected a [{}] value", optionName, source.text(), allowedDataTypes) + ); + } + + // validate the optionExpr is supported + if (allowedDataTypes == null || allowedDataTypes.isEmpty()) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, source.text(), allowedOptions.keySet()) + ); + } + + Literal valueExprLiteral = ((Literal) valueExpr); + // validate that the literal has one of the allowed data types + if (allowedDataTypes.contains(valueExprLiteral.dataType()) == false) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], allowed types [{}]", optionName, source.text(), allowedDataTypes) + ); + } + + optionsMap.put(optionName, valueExprLiteral); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java index 657017a76b1db..c846da085040b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.expression.function.scalar.score.Decay; import java.util.ArrayList; import java.util.Collections; @@ -31,6 +32,9 @@ public static List getNamedWriteables() { if (EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled()) { entries.add(Score.ENTRY); } + if (EsqlCapabilities.Cap.DECAY_FUNCTION.isEnabled()) { + entries.add(Decay.ENTRY); + } return Collections.unmodifiableList(entries); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/Decay.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/Decay.java new file mode 100644 index 0000000000000..f691f6032d2dc --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/Decay.java @@ -0,0 +1,646 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.geometry.Point; +import org.elasticsearch.script.ScoreScriptUtils; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Options; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.common.Failure.fail; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; +import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT; +import static org.elasticsearch.xpack.esql.core.type.DataType.TIME_DURATION; +import static org.elasticsearch.xpack.esql.core.type.DataType.isDateNanos; +import static org.elasticsearch.xpack.esql.core.type.DataType.isGeoPoint; +import static org.elasticsearch.xpack.esql.core.type.DataType.isMillisOrNanos; +import static org.elasticsearch.xpack.esql.core.type.DataType.isSpatialPoint; +import static org.elasticsearch.xpack.esql.core.type.DataType.isTimeDuration; + +/** + * Decay a numeric, spatial or date type value based on the distance of it to an origin. + * + * This function uses the same {@link ScoreScriptUtils} implementations as Painless scripts, + * ensuring consistent decay calculations across ES|QL and script contexts. The decay + * functions support linear, exponential, and gaussian decay types for: + * - Numeric types (int, long, double) + * - Spatial types (geo_point, cartesian_point) + * - Temporal types (datetime, date_nanos) + */ +public class Decay extends EsqlScalarFunction implements OptionalArgument, PostOptimizationVerificationAware { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Decay", Decay::new); + + public static final String ORIGIN = "origin"; + public static final String SCALE = "scale"; + public static final String OFFSET = "offset"; + public static final String DECAY = "decay"; + public static final String TYPE = "type"; + + private static final Map> ALLOWED_OPTIONS = Map.of( + OFFSET, + Set.of(TIME_DURATION, INTEGER, LONG, DOUBLE, KEYWORD, TEXT), + DECAY, + Set.of(DOUBLE), + TYPE, + Set.of(KEYWORD) + ); + + // Default offsets + private static final Integer DEFAULT_INTEGER_OFFSET = 0; + private static final Long DEFAULT_LONG_OFFSET = 0L; + private static final Double DEFAULT_DOUBLE_OFFSET = 0.0; + private static final BytesRef DEFAULT_GEO_POINT_OFFSET = new BytesRef("0m"); + private static final Double DEFAULT_CARTESIAN_POINT_OFFSET = 0.0; + private static final Long DEFAULT_TEMPORAL_OFFSET = 0L; + + private static final Double DEFAULT_DECAY = 0.5; + + private static final BytesRef DEFAULT_FUNCTION = new BytesRef("linear"); + + private final Expression origin; + private final Expression value; + private final Expression scale; + private final Expression options; + + private final Map resolvedOptions; + + @FunctionInfo( + returnType = "double", + preview = true, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW, version = "9.2.0") }, + description = "Calculates a relevance score that decays based on the distance of a numeric, spatial or date type value " + + "from a target origin, using configurable decay functions.", + detailedDescription = """ + `DECAY` calculates a score between 0 and 1 based on how far a field value is from a specified origin point (called distance). + The distance can be a numeric distance, spatial distance or temporal distance depending on the specific data type. + + `DECAY` can use <> to specify additional `options` + for the decay function. + + For spatial queries, scale and offset for geo points use distance units (e.g., "10km", "5mi"), + while cartesian points use numeric values. For date queries, scale and offset use time_duration values. + For numeric queries you also use numeric values. + """, + examples = { @Example(file = "decay", tag = "decay") } + ) + public Decay( + Source source, + @Param( + name = "value", + type = { "double", "integer", "long", "date", "date_nanos", "geo_point", "cartesian_point" }, + description = "The input value to apply decay scoring to." + ) Expression value, + @Param( + name = ORIGIN, + type = { "double", "integer", "long", "date", "date_nanos", "geo_point", "cartesian_point" }, + description = "Central point from which the distances are calculated." + ) Expression origin, + @Param( + name = SCALE, + type = { "double", "integer", "long", "time_duration", "keyword", "text" }, + description = "Distance from the origin where the function returns the decay value." + ) Expression scale, + @MapParam( + name = "options", + params = { + @MapParam.MapParamEntry( + name = OFFSET, + type = { "double", "integer", "long", "time_duration", "keyword", "text" }, + description = "Distance from the origin where no decay occurs." + ), + @MapParam.MapParamEntry( + name = DECAY, + type = { "double" }, + description = "Multiplier value returned at the scale distance from the origin." + ), + @MapParam.MapParamEntry( + name = TYPE, + type = { "keyword" }, + description = "Decay function to use: linear, exponential or gaussian." + ) }, + optional = true + ) Expression options + ) { + super(source, options != null ? List.of(value, origin, scale, options) : List.of(value, origin, scale)); + this.value = value; + this.origin = origin; + this.scale = scale; + this.options = options; + this.resolvedOptions = new HashMap<>(); + } + + private Decay(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class), + in.readOptionalNamedWriteable(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(value); + out.writeNamedWriteable(origin); + out.writeNamedWriteable(scale); + out.writeOptionalNamedWriteable(options); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + return validateValue().and(() -> validateOriginAndScale(value.dataType())) + .and(() -> Options.resolveWithMultipleDataTypesAllowed(options, source(), FOURTH, ALLOWED_OPTIONS)); + } + + private TypeResolution validateValue() { + return isNotNull(value, sourceText(), FIRST).and( + isType(value, dt -> dt.isNumeric() || dt.isDate() || isSpatialPoint(dt), sourceText(), FIRST, "numeric, date or spatial point") + ); + } + + private TypeResolution validateOriginAndScale(DataType valueType) { + if (isSpatialPoint(valueType)) { + boolean isGeoPoint = isGeoPoint(valueType); + + return validateOriginAndScale( + DataType::isSpatialPoint, + "spatial point", + isGeoPoint ? DataType::isString : DataType::isNumeric, + isGeoPoint ? "keyword or text" : "numeric" + ); + } else if (isMillisOrNanos(valueType)) { + return validateOriginAndScale(DataType::isMillisOrNanos, "datetime or date_nanos", DataType::isTimeDuration, "time_duration"); + } else { + return validateOriginAndScale(DataType::isNumeric, "numeric", DataType::isNumeric, "numeric"); + } + } + + private TypeResolution validateOriginAndScale( + Predicate originPredicate, + String originDesc, + Predicate scalePredicate, + String scaleDesc + ) { + return isNotNull(origin, sourceText(), SECOND).and(isType(origin, originPredicate, sourceText(), SECOND, originDesc)) + .and(isNotNull(scale, sourceText(), THIRD)) + .and(isType(scale, scalePredicate, sourceText(), THIRD, scaleDesc)); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Decay(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), options != null ? newChildren.get(3) : null); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Decay::new, children().get(0), children().get(1), children().get(2), children().get(3)); + } + + @Override + public DataType dataType() { + return DOUBLE; + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + DataType valueDataType = value.dataType(); + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + (MapExpression) options, + resolvedOptions, + source(), + FOURTH, + ALLOWED_OPTIONS + ); + + EvalOperator.ExpressionEvaluator.Factory valueFactory = toEvaluator.apply(value); + + Expression offsetExpr = (Expression) resolvedOptions.get(OFFSET); + Expression decayExpr = (Expression) resolvedOptions.get(DECAY); + Expression typeExpr = (Expression) resolvedOptions.get(TYPE); + + FoldContext foldCtx = toEvaluator.foldCtx(); + + // Constants + Object originFolded = origin.fold(foldCtx); + Object scaleFolded = getFoldedScale(foldCtx, valueDataType); + Object offsetFolded = getOffset(foldCtx, valueDataType, offsetExpr); + Double decayFolded = decayExpr != null ? (Double) decayExpr.fold(foldCtx) : DEFAULT_DECAY; + DecayFunction decayFunction = DecayFunction.fromBytesRef(typeExpr != null ? (BytesRef) typeExpr.fold(foldCtx) : DEFAULT_FUNCTION); + + return switch (valueDataType) { + case INTEGER -> new DecayIntEvaluator.Factory( + source(), + valueFactory, + (Integer) originFolded, + (Integer) scaleFolded, + (Integer) offsetFolded, + decayFolded, + decayFunction + ); + case DOUBLE -> new DecayDoubleEvaluator.Factory( + source(), + valueFactory, + (Double) originFolded, + (Double) scaleFolded, + (Double) offsetFolded, + decayFolded, + decayFunction + ); + case LONG -> new DecayLongEvaluator.Factory( + source(), + valueFactory, + (Long) originFolded, + (Long) scaleFolded, + (Long) offsetFolded, + decayFolded, + decayFunction + ); + case GEO_POINT -> new DecayGeoPointEvaluator.Factory( + source(), + valueFactory, + (BytesRef) originFolded, + (BytesRef) scaleFolded, + (BytesRef) offsetFolded, + decayFolded, + decayFunction + ); + case CARTESIAN_POINT -> new DecayCartesianPointEvaluator.Factory( + source(), + valueFactory, + (BytesRef) originFolded, + (Double) scaleFolded, + (Double) offsetFolded, + decayFolded, + decayFunction + ); + case DATETIME -> new DecayDatetimeEvaluator.Factory( + source(), + valueFactory, + (Long) originFolded, + (Long) scaleFolded, + (Long) offsetFolded, + decayFolded, + decayFunction + ); + case DATE_NANOS -> new DecayDateNanosEvaluator.Factory( + source(), + valueFactory, + (Long) originFolded, + (Long) scaleFolded, + (Long) offsetFolded, + decayFolded, + decayFunction + ); + default -> throw new UnsupportedOperationException("Unsupported data typeExpr: " + valueDataType); + }; + } + + @Override + public void postOptimizationVerification(Failures failures) { + // Verify that "origin" and "scale" are literal values + Map.of(ORIGIN, origin, SCALE, scale).forEach((exprName, expr) -> { + if ((expr instanceof Literal) == false) { + failures.add(fail(expr, "Function [{}] has non-literal value [{}].", sourceText(), exprName)); + } + }); + } + + @Evaluator(extraName = "Int") + static double process( + int value, + @Fixed int origin, + @Fixed int scale, + @Fixed int offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + return decayFunction.numericDecay(value, origin, scale, offset, decay); + } + + @Evaluator(extraName = "Double") + static double process( + double value, + @Fixed double origin, + @Fixed double scale, + @Fixed double offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + return decayFunction.numericDecay(value, origin, scale, offset, decay); + } + + @Evaluator(extraName = "Long") + static double process( + long value, + @Fixed long origin, + @Fixed long scale, + @Fixed long offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + return decayFunction.numericDecay(value, origin, scale, offset, decay); + + } + + @Evaluator(extraName = "GeoPoint") + static double process( + BytesRef value, + @Fixed BytesRef origin, + @Fixed BytesRef scale, + @Fixed BytesRef offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + Point valuePoint = SpatialCoordinateTypes.UNSPECIFIED.wkbAsPoint(value); + GeoPoint valueGeoPoint = new GeoPoint(valuePoint.getY(), valuePoint.getX()); + + Point originPoint = SpatialCoordinateTypes.UNSPECIFIED.wkbAsPoint(origin); + GeoPoint originGeoPoint = new GeoPoint(originPoint.getY(), originPoint.getX()); + + String originStr = originGeoPoint.getX() + "," + originGeoPoint.getY(); + String scaleStr = scale.utf8ToString(); + String offsetStr = offset.utf8ToString(); + + return decayFunction.geoPointDecay(valueGeoPoint, originStr, scaleStr, offsetStr, decay); + } + + @Evaluator(extraName = "CartesianPoint") + static double processCartesianPoint( + BytesRef value, + @Fixed BytesRef origin, + @Fixed double scale, + @Fixed double offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + Point valuePoint = SpatialCoordinateTypes.UNSPECIFIED.wkbAsPoint(value); + Point originPoint = SpatialCoordinateTypes.UNSPECIFIED.wkbAsPoint(origin); + + // Euclidean distance + double dx = valuePoint.getX() - originPoint.getX(); + double dy = valuePoint.getY() - originPoint.getY(); + double distance = Math.sqrt(dx * dx + dy * dy); + + distance = Math.max(0.0, distance - offset); + + return decayFunction.cartesianDecay(distance, scale, offset, decay); + } + + @Evaluator(extraName = "Datetime", warnExceptions = { InvalidArgumentException.class, IllegalArgumentException.class }) + static double processDatetime( + long value, + @Fixed long origin, + @Fixed long scale, + @Fixed long offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + return decayFunction.temporalDecay(value, origin, scale, offset, decay); + } + + @Evaluator(extraName = "DateNanos", warnExceptions = { InvalidArgumentException.class, IllegalArgumentException.class }) + static double processDateNanos( + long value, + @Fixed long origin, + @Fixed long scale, + @Fixed long offset, + @Fixed double decay, + @Fixed DecayFunction decayFunction + ) { + return decayFunction.temporalDecay(value, origin, scale, offset, decay); + + } + + public enum DecayFunction { + LINEAR("linear") { + @Override + public double numericDecay(double value, double origin, double scale, double offset, double decay) { + return new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value); + } + + @Override + public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) { + return new ScoreScriptUtils.DecayGeoLinear(origin, scale, offset, decay).decayGeoLinear(value); + } + + @Override + public double cartesianDecay(double distance, double scale, double offset, double decay) { + double scaling = scale / (1.0 - decay); + return Math.max(0.0, (scaling - distance) / scaling); + } + + @Override + public double temporalDecay(long value, long origin, long scale, long offset, double decay) { + return decayDateLinear(origin, scale, offset, decay, value); + } + }, + + EXPONENTIAL("exp") { + @Override + public double numericDecay(double value, double origin, double scale, double offset, double decay) { + return new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value); + } + + @Override + public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) { + return new ScoreScriptUtils.DecayGeoExp(origin, scale, offset, decay).decayGeoExp(value); + } + + @Override + public double cartesianDecay(double distance, double scale, double offset, double decay) { + double scaling = Math.log(decay) / scale; + return Math.exp(scaling * distance); + } + + @Override + public double temporalDecay(long value, long origin, long scale, long offset, double decay) { + return decayDateExp(origin, scale, offset, decay, value); + } + }, + + GAUSSIAN("gauss") { + @Override + public double numericDecay(double value, double origin, double scale, double offset, double decay) { + return new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value); + } + + @Override + public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) { + return new ScoreScriptUtils.DecayGeoGauss(origin, scale, offset, decay).decayGeoGauss(value); + } + + @Override + public double cartesianDecay(double distance, double scale, double offset, double decay) { + double sigmaSquared = -Math.pow(scale, 2.0) / (2.0 * Math.log(decay)); + return Math.exp(-Math.pow(distance, 2.0) / (2.0 * sigmaSquared)); + } + + @Override + public double temporalDecay(long value, long origin, long scale, long offset, double decay) { + return decayDateGauss(origin, scale, offset, decay, value); + } + }; + + private final String functionName; + private static final Map BY_NAME = Arrays.stream(values()) + .collect(Collectors.toMap(df -> df.functionName, df -> df)); + + DecayFunction(String functionName) { + this.functionName = functionName; + } + + public abstract double numericDecay(double value, double origin, double scale, double offset, double decay); + + public abstract double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay); + + public abstract double cartesianDecay(double distance, double scale, double offset, double decay); + + public abstract double temporalDecay(long value, long origin, long scale, long offset, double decay); + + public static DecayFunction fromBytesRef(BytesRef functionType) { + return BY_NAME.getOrDefault(functionType.utf8ToString(), LINEAR); + } + } + + private static double decayDateLinear(long origin, long scale, long offset, double decay, long value) { + double scaling = scale / (1.0 - decay); + + long diff = (value >= origin) ? (value - origin) : (origin - value); + long distance = Math.max(0, diff - offset); + return Math.max(0.0, (scaling - distance) / scaling); + } + + private static double decayDateExp(long origin, long scale, long offset, double decay, long value) { + double scaling = Math.log(decay) / scale; + + long diff = (value >= origin) ? (value - origin) : (origin - value); + long distance = Math.max(0, diff - offset); + return Math.exp(scaling * distance); + } + + private static double decayDateGauss(long origin, long scale, long offset, double decay, long value) { + double scaling = 0.5 * Math.pow(scale, 2.0) / Math.log(decay); + + long diff = (value >= origin) ? (value - origin) : (origin - value); + long distance = Math.max(0, diff - offset); + return Math.exp(0.5 * Math.pow(distance, 2.0) / scaling); + } + + private Object getOffset(FoldContext foldCtx, DataType valueDataType, Expression offset) { + if (offset == null) { + return getDefaultOffset(valueDataType); + } + + if (isTimeDuration(offset.dataType()) == false) { + return offset.fold(foldCtx); + } + + if (isDateNanos(valueDataType)) { + return getTemporalOffsetAsNanos(foldCtx, offset); + } + + return getTemporalOffsetAsMillis(foldCtx, offset); + } + + private Object getFoldedScale(FoldContext foldCtx, DataType valueDataType) { + Object foldedScale = scale.fold(foldCtx); + + if (isTimeDuration(scale.dataType()) == false) { + return foldedScale; + } + + if (isDateNanos(valueDataType)) { + return ((Duration) foldedScale).toNanos(); + } + + return ((Duration) foldedScale).toMillis(); + } + + private Long getTemporalOffsetAsMillis(FoldContext foldCtx, Expression offset) { + Object foldedOffset = offset.fold(foldCtx); + return ((Duration) foldedOffset).toMillis(); + } + + private Long getTemporalOffsetAsNanos(FoldContext foldCtx, Expression offset) { + Object foldedOffset = offset.fold(foldCtx); + Duration offsetDuration = (Duration) foldedOffset; + return offsetDuration.toNanos(); + } + + private Object getDefaultOffset(DataType valueDataType) { + return switch (valueDataType) { + case INTEGER -> DEFAULT_INTEGER_OFFSET; + case LONG -> DEFAULT_LONG_OFFSET; + case DOUBLE -> DEFAULT_DOUBLE_OFFSET; + case GEO_POINT -> DEFAULT_GEO_POINT_OFFSET; + case CARTESIAN_POINT -> DEFAULT_CARTESIAN_POINT_OFFSET; + case DATETIME, DATE_NANOS -> DEFAULT_TEMPORAL_OFFSET; + default -> throw new UnsupportedOperationException("Unsupported data type: " + valueDataType); + }; + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index d5d3507928e84..6288cc9c23d87 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2334,6 +2334,40 @@ public void testRemoteLookupJoinIsDisabled() { assertThat(e.getMessage(), containsString("remote clusters are not supported with LOOKUP JOIN")); } + public void testDecayFunctionNullArgs() { + assumeTrue("Decay function not enabled", EsqlCapabilities.Cap.DECAY_FUNCTION.isEnabled()); + + // First arg cannot be null + assertEquals( + "2:23: first argument of [decay(null, origin, scale, " + + "{\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})] cannot be null, received [null]", + error( + "row origin = 10, scale = 10\n" + + "| eval decay_result = decay(null, origin, scale, {\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})" + ) + ); + + // Second arg cannot be null + assertEquals( + "2:23: second argument of [decay(value, null, scale, " + + "{\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})] cannot be null, received [null]", + error( + "row value = 10, scale = 10\n" + + "| eval decay_result = decay(value, null, scale, {\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})" + ) + ); + + // Third arg cannot be null + assertEquals( + "2:23: third argument of [decay(value, origin, null, " + + "{\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})] cannot be null, received [null]", + error( + "row value = 10, origin = 10\n" + + "| eval decay_result = decay(value, origin, null, {\"offset\": 0, \"decay\": 0.5, \"type\": \"linear\"})" + ) + ); + } + private void checkFullTextFunctionsInStats(String functionInvocation) { query("from test | stats c = max(id) where " + functionInvocation, fullTextAnalyzer); query("from test | stats c = max(id) where " + functionInvocation + " or length(title) > 10", fullTextAnalyzer); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/OptionsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/OptionsTests.java new file mode 100644 index 0000000000000..b3f639f50a88a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/OptionsTests.java @@ -0,0 +1,546 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OptionsTests extends ESTestCase { + + public void testNullOptions_SingleDataTypeAllowed() { + Map allowedOptions = Map.of("keyword_option", DataType.KEYWORD); + Expression.TypeResolution resolution = Options.resolve(null, Source.EMPTY, TypeResolutions.ParamOrdinal.DEFAULT, allowedOptions); + + assertTrue(resolution.resolved()); + } + + public void testSingleEntryOptions_SingleDataTypeAllowed_ShouldResolve() { + Map allowedOptions = Map.of("keyword_option", DataType.KEYWORD); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + } + + public void testSingleEntryOptions_SingleDataTypeAllowed_UnknownOption_ShouldNotResolve() { + Map allowedOptions = Map.of("keyword_option", DataType.KEYWORD); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "unknown_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testMultipleEntryOptions_SingleDataTypeAllowed_ShouldResolve() { + Map allowedOptions = Map.of( + "keyword_option", + DataType.KEYWORD, + "int_option", + DataType.INTEGER, + "double_option", + DataType.DOUBLE + ); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "keyword_option"), + Literal.keyword(Source.EMPTY, randomAlphaOfLength(10)), + Literal.keyword(Source.EMPTY, "int_option"), + Literal.integer(Source.EMPTY, 1), + Literal.keyword(Source.EMPTY, "double_option"), + Literal.fromDouble(Source.EMPTY, 1.0) + ) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + } + + public void testMultipleEntryOptions_SingleDataTypeAllowed_UnknownOption_ShouldNotResolve() { + Map allowedOptions = Map.of( + "keyword_option", + DataType.KEYWORD, + "int_option", + DataType.INTEGER, + "double_option", + DataType.DOUBLE + ); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "unknown_option"), + Literal.keyword(Source.EMPTY, randomAlphaOfLength(10)), + Literal.keyword(Source.EMPTY, "int_option"), + Literal.integer(Source.EMPTY, 1), + Literal.keyword(Source.EMPTY, "double_option"), + Literal.fromDouble(Source.EMPTY, 1.0) + ) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testSingleEntryOptions_NullDataType_ShouldNotResolve() { + Map allowedOptions = new HashMap<>(); + allowedOptions.put("keyword_option", null); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testSingleEntryOptions_SingleDataTypeAllowed_MapExpressionAsValue_ShouldNotResolve() { + Map allowedOptions = Map.of("map_option", DataType.OBJECT); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "map_option"), + new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "some_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ) + ) + ); + Expression.TypeResolution resolution = Options.resolve( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testNullOptions_MultipleDataTypesAllowed() { + Map> allowedOptions = Map.of("keyword_text_option", List.of(DataType.KEYWORD)); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + null, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + } + + public void testSingleEntryOptions_MultipleDataTypesAllowed_ShouldResolve() { + Map> allowedOptions = Map.of("keyword_text_option", List.of(DataType.KEYWORD, DataType.TEXT)); + + // Keyword resolution + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_text_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + + // Text resolution + mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_text_option"), Literal.text(Source.EMPTY, randomAlphaOfLength(10))) + ); + resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + } + + public void testSingleEntryOptions_MultipleDataTypesAllowed_UnknownOption_ShouldNotResolve() { + Map> allowedOptions = Map.of("keyword_string_option", List.of(DataType.KEYWORD, DataType.TEXT)); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "unknown_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testMultipleEntryOptions_MultipleDataTypesAllowed_ShouldResolve() { + Map> allowedOptions = Map.of( + "keyword_text_option", + List.of(DataType.KEYWORD, DataType.TEXT), + "double_int_option", + List.of(DataType.DOUBLE, DataType.INTEGER) + ); + + // Keyword & double resolution + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "keyword_text_option"), + Literal.keyword(Source.EMPTY, randomAlphaOfLength(10)), + Literal.keyword(Source.EMPTY, "double_int_option"), + Literal.integer(Source.EMPTY, randomInt()) + ) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + + // Text & double resolution + mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "keyword_text_option"), + Literal.text(Source.EMPTY, randomAlphaOfLength(10)), + Literal.keyword(Source.EMPTY, "double_int_option"), + Literal.fromDouble(Source.EMPTY, randomDouble()) + ) + ); + resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.resolved()); + } + + public void testMultipleEntryOptions_MultipleDataTypesAllowed_UnknownOption_ShouldNotResolve() { + Map> allowedOptions = Map.of( + "keyword_text_option", + List.of(DataType.KEYWORD, DataType.TEXT), + "double_int_option", + List.of(DataType.DOUBLE, DataType.INTEGER) + ); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "unknown_option"), + Literal.keyword(Source.EMPTY, randomAlphaOfLength(10)), + Literal.keyword(Source.EMPTY, "double_int_option"), + Literal.integer(Source.EMPTY, randomInt()) + ) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testSingleEntryOptions_MultipleDataTypesAllowed_NullDataType_ShouldNotResolve() { + Collection allowedDataTypes = new ArrayList<>(); + allowedDataTypes.add(null); + + Map> allowedOptions = Map.of("null_option", allowedDataTypes); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "null_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testSingleEntryOptions_MultipleDataTypeAllowed_MapExpressionAsValue_ShouldNotResolve() { + Map> allowedOptions = Map.of("map_option", List.of(DataType.OBJECT, DataType.TEXT)); + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "map_option"), + new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "some_option"), Literal.keyword(Source.EMPTY, randomAlphaOfLength(10))) + ) + ) + ); + Expression.TypeResolution resolution = Options.resolveWithMultipleDataTypesAllowed( + mapExpression, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertTrue(resolution.unresolved()); + } + + public void testPopulateMapWithExpressions_SingleEntry_KeywordDataType() throws InvalidArgumentException { + Map> allowedOptions = Map.of("keyword_option", List.of(DataType.KEYWORD)); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_option"), Literal.keyword(Source.EMPTY, "test_value")) + ); + + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertEquals(1, optionsMap.size()); + assertTrue(optionsMap.containsKey("keyword_option")); + assertTrue(optionsMap.get("keyword_option") instanceof Literal); + Literal storedLiteral = (Literal) optionsMap.get("keyword_option"); + assertEquals(DataType.KEYWORD, storedLiteral.dataType()); + assertEquals("test_value", ((BytesRef) storedLiteral.value()).utf8ToString()); + } + + public void testPopulateMapWithExpressions_SingleEntry_MultipleAllowedDataTypes_Keyword() throws InvalidArgumentException { + Map> allowedOptions = Map.of("keyword_text_option", List.of(DataType.KEYWORD, DataType.TEXT)); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_text_option"), Literal.keyword(Source.EMPTY, "keyword_value")) + ); + + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertEquals(1, optionsMap.size()); + assertTrue(optionsMap.containsKey("keyword_text_option")); + Literal storedLiteral = (Literal) optionsMap.get("keyword_text_option"); + assertEquals(DataType.KEYWORD, storedLiteral.dataType()); + assertEquals("keyword_value", ((BytesRef) storedLiteral.value()).utf8ToString()); + } + + public void testPopulateMapWithExpressions_MultipleEntries() throws InvalidArgumentException { + Map> allowedOptions = Map.of( + "keyword_text_option", + List.of(DataType.KEYWORD, DataType.TEXT), + "double_int_option", + List.of(DataType.DOUBLE, DataType.INTEGER) + ); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of( + Literal.keyword(Source.EMPTY, "keyword_text_option"), + Literal.keyword(Source.EMPTY, "keyword_value"), + Literal.keyword(Source.EMPTY, "double_int_option"), + Literal.integer(Source.EMPTY, 42) + ) + ); + + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + + assertEquals(2, optionsMap.size()); + + // Check first option + assertTrue(optionsMap.containsKey("keyword_text_option")); + Literal firstLiteral = (Literal) optionsMap.get("keyword_text_option"); + assertEquals(DataType.KEYWORD, firstLiteral.dataType()); + assertEquals("keyword_value", ((BytesRef) firstLiteral.value()).utf8ToString()); + + // Check second option + assertTrue(optionsMap.containsKey("double_int_option")); + Literal secondLiteral = (Literal) optionsMap.get("double_int_option"); + assertEquals(DataType.INTEGER, secondLiteral.dataType()); + assertEquals(42, secondLiteral.value()); + } + + public void testPopulateMapWithExpressions_UnknownOption_ShouldThrowException() { + Map> allowedOptions = Map.of("known_option", List.of(DataType.KEYWORD)); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "unknown_option"), Literal.keyword(Source.EMPTY, "value")) + ); + + InvalidArgumentException exception = assertThrows(InvalidArgumentException.class, () -> { + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + }); + + assertTrue(exception.getMessage().contains("Invalid option [unknown_option]")); + assertTrue(exception.getMessage().contains("expected one of [known_option]")); + } + + public void testPopulateMapWithExpressions_WrongDataType_ShouldThrowException() { + Map> allowedOptions = Map.of("keyword_only_option", List.of(DataType.KEYWORD)); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "keyword_only_option"), Literal.text(Source.EMPTY, "text_value")) + ); + + InvalidArgumentException exception = assertThrows(InvalidArgumentException.class, () -> { + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + }); + + assertTrue(exception.getMessage().contains("Invalid option [keyword_only_option]")); + assertTrue(exception.getMessage().contains("allowed types")); + } + + public void testPopulateMapWithExpressions_EmptyAllowedDataTypes_ShouldThrowException() { + Map> allowedOptions = Map.of("empty_option", List.of()); + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "empty_option"), Literal.keyword(Source.EMPTY, "value")) + ); + + InvalidArgumentException exception = assertThrows(InvalidArgumentException.class, () -> { + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + }); + + assertTrue(exception.getMessage().contains("Invalid option [empty_option]")); + } + + public void testPopulateMapWithExpressions_NullAllowedDataTypes_ShouldThrowException() { + Map> allowedOptions = new HashMap<>(); + allowedOptions.put("null_option", null); + + Map optionsMap = new HashMap<>(); + + MapExpression mapExpression = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "null_option"), Literal.keyword(Source.EMPTY, "value")) + ); + + InvalidArgumentException exception = assertThrows(InvalidArgumentException.class, () -> { + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + }); + + assertTrue(exception.getMessage().contains("Invalid option [null_option]")); + } + + public void testPopulateMapWithExpressions_NonLiteralValue_ShouldThrowException() { + Map> allowedOptions = Map.of("map_option", List.of(DataType.OBJECT)); + Map optionsMap = new HashMap<>(); + + MapExpression nestedMap = new MapExpression( + Source.EMPTY, + List.of(Literal.keyword(Source.EMPTY, "nested_key"), Literal.keyword(Source.EMPTY, "nested_value")) + ); + + MapExpression mapExpression = new MapExpression(Source.EMPTY, List.of(Literal.keyword(Source.EMPTY, "map_option"), nestedMap)); + + InvalidArgumentException exception = assertThrows(InvalidArgumentException.class, () -> { + Options.populateMapWithExpressionsMultipleDataTypesAllowed( + mapExpression, + optionsMap, + Source.EMPTY, + TypeResolutions.ParamOrdinal.DEFAULT, + allowedOptions + ); + }); + + assertTrue(exception.getMessage().contains("Invalid option [map_option]")); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java new file mode 100644 index 0000000000000..fda42d74a0e30 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java @@ -0,0 +1,1225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.score; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.unit.DistanceUnit; +import org.elasticsearch.script.ScoreScriptUtils; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN; +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.startsWith; + +public class DecayTests extends AbstractScalarFunctionTestCase { + + public DecayTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List testCaseSuppliers = new ArrayList<>(); + + // Int Linear + testCaseSuppliers.addAll(intTestCase(0, 0, 10, 5, 0.5, "linear", 1.0)); + testCaseSuppliers.addAll(intTestCase(10, 0, 10, 5, 0.5, "linear", 0.75)); + testCaseSuppliers.addAll(intTestCase(50, 5, 100, 10, 0.25, "linear", 0.7375)); + testCaseSuppliers.addAll(intTestCase(100, 17, 156, 23, 0.123, "linear", 0.6626923076923077)); + testCaseSuppliers.addAll(intTestCase(2500, 0, 10, 0, 0.5, "linear", 0.0)); + + // Int Exponential + testCaseSuppliers.addAll(intTestCase(0, 0, 10, 5, 0.5, "exp", 1.0)); + testCaseSuppliers.addAll(intTestCase(10, 0, 10, 5, 0.5, "exp", 0.7071067811865475)); + testCaseSuppliers.addAll(intTestCase(50, 5, 100, 10, 0.25, "exp", 0.6155722066724582)); + testCaseSuppliers.addAll(intTestCase(100, 17, 156, 23, 0.123, "exp", 0.4466460570185927)); + testCaseSuppliers.addAll(intTestCase(2500, 0, 10, 0, 0.5, "exp", 5.527147875260539E-76)); + + // Int Gaussian + testCaseSuppliers.addAll(intTestCase(0, 0, 10, 5, 0.5, "gauss", 1.0)); + testCaseSuppliers.addAll(intTestCase(10, 0, 10, 5, 0.5, "gauss", 0.8408964152537146)); + testCaseSuppliers.addAll(intTestCase(50, 5, 100, 10, 0.25, "gauss", 0.8438157961300179)); + testCaseSuppliers.addAll(intTestCase(100, 17, 156, 23, 0.123, "gauss", 0.7334501109633149)); + testCaseSuppliers.addAll(intTestCase(2500, 0, 10, 0, 0.5, "gauss", 0.0)); + + // Int defaults + testCaseSuppliers.addAll(intTestCase(10, 0, 10, null, null, null, 0.5)); + + // Int random + testCaseSuppliers.addAll(intRandomTestCases()); + + // Long Linear + testCaseSuppliers.addAll(longTestCase(0L, 10L, 10000000L, 200L, 0.33, "linear", 1.0)); + testCaseSuppliers.addAll(longTestCase(10L, 10L, 10000000L, 200L, 0.33, "linear", 1.0)); + testCaseSuppliers.addAll(longTestCase(50000L, 10L, 10000000L, 200L, 0.33, "linear", 0.99666407)); + testCaseSuppliers.addAll(longTestCase(300000L, 10L, 10000000L, 200L, 0.33, "linear", 0.97991407)); + testCaseSuppliers.addAll(longTestCase(123456789112123L, 10L, 10000000L, 200L, 0.33, "linear", 0.0)); + + // Long Exponential + testCaseSuppliers.addAll(longTestCase(0L, 10L, 10000000L, 200L, 0.33, "exp", 1.0)); + testCaseSuppliers.addAll(longTestCase(10L, 10L, 10000000L, 200L, 0.33, "exp", 1.0)); + testCaseSuppliers.addAll(longTestCase(50000L, 10L, 10000000L, 200L, 0.33, "exp", 0.9944951761701727)); + testCaseSuppliers.addAll(longTestCase(300000L, 10L, 10000000L, 200L, 0.33, "exp", 0.9673096701204178)); + testCaseSuppliers.addAll(longTestCase(123456789112123L, 10L, 10000000L, 200L, 0.33, "exp", 0.0)); + + // Long Gaussian + testCaseSuppliers.addAll(longTestCase(0L, 10L, 10000000L, 200L, 0.33, "gauss", 1.0)); + testCaseSuppliers.addAll(longTestCase(10L, 10L, 10000000L, 200L, 0.33, "gauss", 1.0)); + testCaseSuppliers.addAll(longTestCase(50000L, 10L, 10000000L, 200L, 0.33, "gauss", 0.999972516142306)); + testCaseSuppliers.addAll(longTestCase(300000L, 10L, 10000000L, 200L, 0.33, "gauss", 0.9990040963055015)); + testCaseSuppliers.addAll(longTestCase(123456789112123L, 10L, 10000000L, 200L, 0.33, "gauss", 0.0)); + + // Long defaults + testCaseSuppliers.addAll(longTestCase(10L, 0L, 10L, null, null, null, 0.5)); + + // Long random + testCaseSuppliers.addAll(longRandomTestCases()); + + // Double Linear + testCaseSuppliers.addAll(doubleTestCase(0.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(10.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(50000.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 0.99626575)); + testCaseSuppliers.addAll(doubleTestCase(300000.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 0.97751575)); + testCaseSuppliers.addAll(doubleTestCase(123456789112.123, 10.0, 10000000.0, 200.0, 0.25, "linear", 0.0)); + + // Double Exponential + testCaseSuppliers.addAll(doubleTestCase(0.0, 10.0, 10000000.0, 200.0, 0.25, "exp", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(10.0, 10.0, 10000000.0, 200.0, 0.25, "exp", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(50000.0, 10.0, 10000000.0, 200.0, 0.25, "exp", 0.9931214069469289)); + testCaseSuppliers.addAll(doubleTestCase(300000.0, 10.0, 10000000.0, 200.0, 0.25, "exp", 0.959292046002994)); + testCaseSuppliers.addAll(doubleTestCase(123456789112.123, 10.0, 10000000.0, 200.0, 0.25, "exp", 0.0)); + + // Double Gaussian + testCaseSuppliers.addAll(doubleTestCase(0.0, 10.0, 10000000.0, 200.0, 0.25, "gauss", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(10.0, 10.0, 10000000.0, 200.0, 0.25, "gauss", 1.0)); + testCaseSuppliers.addAll(doubleTestCase(50000.0, 10.0, 10000000.0, 200.0, 0.25, "gauss", 0.9999656337419655)); + testCaseSuppliers.addAll(doubleTestCase(300000.0, 10.0, 10000000.0, 200.0, 0.25, "gauss", 0.9987548570291238)); + testCaseSuppliers.addAll(doubleTestCase(123456789112.123, 10.0, 10000000.0, 200.0, 0.25, "gauss", 0.0)); + + // Double defaults + testCaseSuppliers.addAll(doubleTestCase(10.0, 0.0, 10.0, null, null, null, 0.5)); + + // Double random + testCaseSuppliers.addAll(doubleRandomTestCases()); + + // GeoPoint Linear + testCaseSuppliers.addAll(geoPointTestCase("POINT (1.0 1.0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 1.0)); + testCaseSuppliers.addAll(geoPointTestCase("POINT (0 0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 0.9901342769495362)); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (12.3 45.6)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 0.6602313771587869) + ); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (180.0 90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 0.33761373954395957) + ); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (-180.0 -90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 0.32271359885955425) + ); + + // GeoPoint Exponential + testCaseSuppliers.addAll(geoPointTestCase("POINT (1.0 1.0)", "POINT (1 1)", "10000km", "10km", 0.33, "exp", 1.0)); + testCaseSuppliers.addAll(geoPointTestCase("POINT (0 0)", "POINT (1 1)", "10000km", "10km", 0.33, "exp", 0.983807518295976)); + testCaseSuppliers.addAll(geoPointTestCase("POINT (12.3 45.6)", "POINT (1 1)", "10000km", "10km", 0.33, "exp", 0.5699412181941212)); + testCaseSuppliers.addAll(geoPointTestCase("POINT (180.0 90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "exp", 0.3341838411351592)); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (-180.0 -90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "exp", 0.32604509444656576) + ); + + // GeoPoint Gaussian + testCaseSuppliers.addAll(geoPointTestCase("POINT (1.0 1.0)", "POINT (1 1)", "10000km", "10km", 0.33, "gauss", 1.0)); + testCaseSuppliers.addAll(geoPointTestCase("POINT (0 0)", "POINT (1 1)", "10000km", "10km", 0.33, "gauss", 0.9997596437370099)); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (12.3 45.6)", "POINT (1 1)", "10000km", "10km", 0.33, "gauss", 0.7519296165431535) + ); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (180.0 90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "gauss", 0.33837227875395753) + ); + testCaseSuppliers.addAll( + geoPointTestCase("POINT (-180.0 -90.0)", "POINT (1 1)", "10000km", "10km", 0.33, "gauss", 0.3220953501115956) + ); + + // GeoPoint offset & scale as keywords + testCaseSuppliers.addAll(geoPointTestCaseKeywordScale("POINT (1 1)", "POINT (1 1)", "200km", "0km", 0.5, "linear", 1.0)); + testCaseSuppliers.addAll(geoPointOffsetKeywordTestCase("POINT (1 1)", "POINT (1 1)", "200km", "0km", 0.5, "linear", 1.0)); + + // GeoPoint defaults + testCaseSuppliers.addAll(geoPointTestCase("POINT (12.3 45.6)", "POINT (1 1)", "10000km", null, null, null, 0.7459413262379005)); + + // GeoPoint random + testCaseSuppliers.addAll(geoPointRandomTestCases()); + + // CartesianPoint Linear + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (0 0)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 1.0)); + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (1 1)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 1.0)); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (1000 2000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 0.8509433324420796) + ); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (-2000 1000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 0.8508234552350306) + ); + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (10000 20000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 0.0)); + + // CartesianPoint Exponential + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (0 0)", "POINT (1 1)", 10000.0, 10.0, 0.33, "exp", 1.0)); + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (1 1)", "POINT (1 1)", 10000.0, 10.0, 0.33, "exp", 1.0)); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (1000 2000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "exp", 0.7814164075951677) + ); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (-2000 1000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "exp", 0.7812614186677811) + ); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (10000 20000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "exp", 0.0839287052363121) + ); + + // CartesianPoint Gaussian + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (0 0)", "POINT (1 1)", 10000.0, 10.0, 0.33, "gauss", 1.0)); + testCaseSuppliers.addAll(cartesianPointTestCase("POINT (1 1)", "POINT (1 1)", 10000.0, 10.0, 0.33, "gauss", 1.0)); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (1000 2000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "gauss", 0.9466060873472042) + ); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (-2000 1000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "gauss", 0.9465225092376659) + ); + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (10000 20000)", "POINT (1 1)", 10000.0, 10.0, 0.33, "gauss", 0.003935602627423666) + ); + + // CartesianPoint defaults + testCaseSuppliers.addAll( + cartesianPointTestCase("POINT (1000.0 2000.0)", "POINT (0 0)", 10000.0, null, null, null, 0.8881966011250104) + ); + + // Datetime Linear + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 1.0 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.49569100000000005 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.37334900000000004 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.28202800000000006 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.0 + ) + ); + + // Datetime Exponential + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 1.0 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.4340956586740692 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.3545406919498116 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.30481724812400407 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.01813481247808857 + ) + ); + + // Datetime Gaussian + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 1.0 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.5335935393743785 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.3791426943809958 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.27996050542437345 + ) + ); + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 5.025924031342025E-7 + ) + ); + + // Datetime Defaults + testCaseSuppliers.addAll( + datetimeTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + null, + null, + null, + 0.62315 + ) + ); + + // Datetime random + testCaseSuppliers.addAll(datetimeRandomTestCases()); + + // Datenanos Linear + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 1.0 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.49569100000000005 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.37334900000000004 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.28202800000000006 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "linear", + 0.0 + ) + ); + + // Datenanos Exponential + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 1.0 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.4340956586740692 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.3545406919498116 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.30481724812400407 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "exp", + 0.01813481247808857 + ) + ); + + // Datenanos Gaussian + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 1.0 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2020, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.5335935393743785 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.3791426943809958 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1970, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 0.27996050542437345 + ) + ); + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(1900, 12, 12, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + Duration.ofDays(10), + 0.33, + "gauss", + 5.025924031342025E-7 + ) + ); + + // Datenanos default + testCaseSuppliers.addAll( + dateNanosTestCase( + LocalDateTime.of(2025, 8, 20, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + LocalDateTime.of(2000, 1, 1, 0, 0, 0).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), + Duration.ofDays(10000), + null, + null, + null, + 0.53185 + ) + ); + + // Datenanos random + testCaseSuppliers.addAll(dateNanosRandomTestCases()); + + return parameterSuppliersFromTypedData(testCaseSuppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Decay(source, args.get(0), args.get(1), args.get(2), args.get(3) != null ? args.get(3) : null); + } + + @Override + public void testFold() { + // Decay cannot be folded + } + + private static List intTestCase( + int value, + int origin, + int scale, + Integer offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.INTEGER, DataType.INTEGER, DataType.INTEGER, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(value, DataType.INTEGER, "value"), + new TestCaseSupplier.TypedData(origin, DataType.INTEGER, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.INTEGER, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayIntEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List intRandomTestCases() { + return List.of(new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER, DataType.INTEGER, DataType.SOURCE), () -> { + int randomValue = randomInt(); + int randomOrigin = randomInt(); + int randomScale = randomInt(); + int randomOffset = randomInt(); + double randomDecay = randomDouble(); + String randomType = getRandomType(); + + double scoreScriptNumericResult = intDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale, + randomOffset, + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomValue, DataType.INTEGER, "value"), + new TestCaseSupplier.TypedData(randomOrigin, DataType.INTEGER, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.INTEGER, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayIntEvaluator["), + DataType.DOUBLE, + closeTo(scoreScriptNumericResult, Math.ulp(scoreScriptNumericResult)) + ); + })); + } + + private static String getRandomType() { + return randomFrom("linear", "gauss", "exp"); + } + + private static double intDecayWithScoreScript(int value, int origin, int scale, int offset, double decay, String type) { + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value); + case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value); + case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static List longTestCase( + long value, + long origin, + long scale, + Long offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.LONG, DataType.LONG, DataType.LONG, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(value, DataType.LONG, "value"), + new TestCaseSupplier.TypedData(origin, DataType.LONG, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.LONG, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayLongEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List longRandomTestCases() { + return List.of(new TestCaseSupplier(List.of(DataType.LONG, DataType.LONG, DataType.LONG, DataType.SOURCE), () -> { + long randomValue = randomLong(); + long randomOrigin = randomLong(); + long randomScale = randomLong(); + long randomOffset = randomLong(); + double randomDecay = randomDouble(); + String randomType = randomFrom("linear", "gauss", "exp"); + + double scoreScriptNumericResult = longDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale, + randomOffset, + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomValue, DataType.LONG, "value"), + new TestCaseSupplier.TypedData(randomOrigin, DataType.LONG, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.LONG, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayLongEvaluator["), + DataType.DOUBLE, + equalTo(scoreScriptNumericResult) + ); + })); + } + + private static double longDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) { + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value); + case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value); + case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static List doubleTestCase( + double value, + double origin, + double scale, + Double offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.DOUBLE, DataType.DOUBLE, DataType.DOUBLE, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(value, DataType.DOUBLE, "value"), + new TestCaseSupplier.TypedData(origin, DataType.DOUBLE, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.DOUBLE, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDoubleEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List doubleRandomTestCases() { + return List.of(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE, DataType.DOUBLE, DataType.SOURCE), () -> { + double randomValue = randomLong(); + double randomOrigin = randomLong(); + double randomScale = randomLong(); + double randomOffset = randomLong(); + double randomDecay = randomDouble(); + String randomType = randomFrom("linear", "gauss", "exp"); + + double scoreScriptNumericResult = doubleDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale, + randomOffset, + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomValue, DataType.DOUBLE, "value"), + new TestCaseSupplier.TypedData(randomOrigin, DataType.DOUBLE, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.DOUBLE, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDoubleEvaluator["), + DataType.DOUBLE, + closeTo(scoreScriptNumericResult, Math.ulp(scoreScriptNumericResult)) + ); + })); + } + + private static double doubleDecayWithScoreScript(double value, double origin, double scale, double offset, double decay, String type) { + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value); + case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value); + case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static List geoPointTestCase( + String valueWkt, + String originWkt, + String scale, + String offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.GEO_POINT, DataType.GEO_POINT, DataType.TEXT, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(GEO.wktToWkb(valueWkt), DataType.GEO_POINT, "value"), + new TestCaseSupplier.TypedData(GEO.wktToWkb(originWkt), DataType.GEO_POINT, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.TEXT, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayGeoPointEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List geoPointTestCaseKeywordScale( + String valueWkt, + String originWkt, + String scale, + String offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.GEO_POINT, DataType.GEO_POINT, DataType.KEYWORD, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(GEO.wktToWkb(valueWkt), DataType.GEO_POINT, "value"), + new TestCaseSupplier.TypedData(GEO.wktToWkb(originWkt), DataType.GEO_POINT, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.KEYWORD, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayGeoPointEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List geoPointRandomTestCases() { + return List.of(new TestCaseSupplier(List.of(DataType.GEO_POINT, DataType.GEO_POINT, DataType.KEYWORD, DataType.SOURCE), () -> { + GeoPoint randomValue = randomGeoPoint(); + GeoPoint randomOrigin = randomGeoPoint(); + String randomScale = randomDistance(); + String randomOffset = randomDistance(); + double randomDecay = randomDouble(); + String randomType = randomDecayType(); + + double scoreScriptNumericResult = geoPointDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale, + randomOffset, + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(GEO.wktToWkb(randomValue.toWKT()), DataType.GEO_POINT, "value"), + new TestCaseSupplier.TypedData(GEO.wktToWkb(randomOrigin.toWKT()), DataType.GEO_POINT, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.KEYWORD, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayGeoPointEvaluator["), + DataType.DOUBLE, + closeTo(scoreScriptNumericResult, Math.ulp(scoreScriptNumericResult)) + ); + })); + } + + private static String randomDecayType() { + return randomFrom("linear", "gauss", "exp"); + } + + private static GeoPoint randomGeoPoint() { + return new GeoPoint(randomLatitude(), randomLongitude()); + } + + private static double randomLongitude() { + return randomDoubleBetween(-180.0, 180.0, true); + } + + private static double randomLatitude() { + return randomDoubleBetween(-90.0, 90.0, true); + } + + private static String randomDistance() { + return String.format( + Locale.ROOT, + "%d%s", + randomNonNegativeInt(), + randomFrom( + DistanceUnit.INCH, + DistanceUnit.YARD, + DistanceUnit.FEET, + DistanceUnit.KILOMETERS, + DistanceUnit.NAUTICALMILES, + DistanceUnit.MILLIMETERS, + DistanceUnit.CENTIMETERS, + DistanceUnit.MILES, + DistanceUnit.METERS + ) + ); + } + + private static double geoPointDecayWithScoreScript( + GeoPoint value, + GeoPoint origin, + String scale, + String offset, + double decay, + String type + ) { + String originStr = origin.getX() + "," + origin.getY(); + + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayGeoLinear(originStr, scale, offset, decay).decayGeoLinear(value); + case "gauss" -> new ScoreScriptUtils.DecayGeoGauss(originStr, scale, offset, decay).decayGeoGauss(value); + case "exp" -> new ScoreScriptUtils.DecayGeoExp(originStr, scale, offset, decay).decayGeoExp(value); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static List geoPointOffsetKeywordTestCase( + String valueWkt, + String originWkt, + String scale, + String offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.GEO_POINT, DataType.GEO_POINT, DataType.TEXT, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(GEO.wktToWkb(valueWkt), DataType.GEO_POINT, "value"), + new TestCaseSupplier.TypedData(GEO.wktToWkb(originWkt), DataType.GEO_POINT, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.TEXT, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayGeoPointEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List cartesianPointTestCase( + String valueWkt, + String originWkt, + double scale, + Double offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.CARTESIAN_POINT, DataType.CARTESIAN_POINT, DataType.DOUBLE, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(CARTESIAN.wktToWkb(valueWkt), DataType.CARTESIAN_POINT, "value"), + new TestCaseSupplier.TypedData(CARTESIAN.wktToWkb(originWkt), DataType.CARTESIAN_POINT, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.DOUBLE, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayCartesianPointEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ) + ) + ); + } + + private static List datetimeTestCase( + long value, + long origin, + Duration scale, + Duration offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.DATETIME, DataType.DATETIME, DataType.TIME_DURATION, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(value, DataType.DATETIME, "value"), + new TestCaseSupplier.TypedData(origin, DataType.DATETIME, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.TIME_DURATION, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDatetimeEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ).withoutEvaluator() + ) + ); + } + + private static List datetimeRandomTestCases() { + return List.of(new TestCaseSupplier(List.of(DataType.DATETIME, DataType.DATETIME, DataType.TIME_DURATION, DataType.SOURCE), () -> { + // 1970-01-01 + long minEpoch = 0L; + // 2070-01-01 + long maxEpoch = 3155673600000L; + long randomValue = randomLongBetween(minEpoch, maxEpoch); + long randomOrigin = randomLongBetween(minEpoch, maxEpoch); + + // Max 1 year + long randomScaleMillis = randomNonNegativeLong() % (365L * 24 * 60 * 60 * 1000); + // Max 30 days + long randomOffsetMillis = randomNonNegativeLong() % (30L * 24 * 60 * 60 * 1000); + Duration randomScale = Duration.ofMillis(randomScaleMillis); + Duration randomOffset = Duration.ofMillis(randomOffsetMillis); + double randomDecay = randomDouble(); + String randomType = randomFrom("linear", "gauss", "exp"); + + double scoreScriptNumericResult = datetimeDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale.toMillis(), + randomOffset.toMillis(), + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomValue, DataType.DATETIME, "value"), + new TestCaseSupplier.TypedData(randomOrigin, DataType.DATETIME, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.TIME_DURATION, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDatetimeEvaluator["), + DataType.DOUBLE, + closeTo(scoreScriptNumericResult, Math.ulp(scoreScriptNumericResult)) + ); + })); + } + + private static double datetimeDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) { + String originStr = String.valueOf(origin); + String scaleStr = scale + "ms"; + String offsetStr = offset + "ms"; + + ZonedDateTime valueDateTime = Instant.ofEpochMilli(value).atZone(ZoneId.of("UTC")); + + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayDateLinear(originStr, scaleStr, offsetStr, decay).decayDateLinear(valueDateTime); + case "gauss" -> new ScoreScriptUtils.DecayDateGauss(originStr, scaleStr, offsetStr, decay).decayDateGauss(valueDateTime); + case "exp" -> new ScoreScriptUtils.DecayDateExp(originStr, scaleStr, offsetStr, decay).decayDateExp(valueDateTime); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static List dateNanosTestCase( + long value, + long origin, + Duration scale, + Duration offset, + Double decay, + String functionType, + double expected + ) { + return List.of( + new TestCaseSupplier( + List.of(DataType.DATE_NANOS, DataType.DATE_NANOS, DataType.TIME_DURATION, DataType.SOURCE), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(value, DataType.DATE_NANOS, "value"), + new TestCaseSupplier.TypedData(origin, DataType.DATE_NANOS, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(scale, DataType.TIME_DURATION, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(offset, decay, functionType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDateNanosEvaluator["), + DataType.DOUBLE, + closeTo(expected, Math.ulp(expected)) + ).withoutEvaluator() + ) + ); + } + + private static List dateNanosRandomTestCases() { + return List.of( + new TestCaseSupplier(List.of(DataType.DATE_NANOS, DataType.DATE_NANOS, DataType.TIME_DURATION, DataType.SOURCE), () -> { + // 1970-01-01 in nanos + long minEpochNanos = 0L; + // 2070-01-01 in nanos + long maxEpochNanos = 3155673600000L * 1_000_000L; + long randomValue = randomLongBetween(minEpochNanos, maxEpochNanos); + long randomOrigin = randomLongBetween(minEpochNanos, maxEpochNanos); + + // Max 1 year in milliseconds + long randomScaleMillis = randomNonNegativeLong() % (365L * 24 * 60 * 60 * 1000); + // Max 30 days in milliseconds + long randomOffsetMillis = randomNonNegativeLong() % (30L * 24 * 60 * 60 * 1000); + Duration randomScale = Duration.ofMillis(randomScaleMillis); + Duration randomOffset = Duration.ofMillis(randomOffsetMillis); + + double randomDecay = randomDouble(); + String randomType = randomFrom("linear", "gauss", "exp"); + + double scoreScriptNumericResult = dateNanosDecayWithScoreScript( + randomValue, + randomOrigin, + randomScale.toMillis(), + randomOffset.toMillis(), + randomDecay, + randomType + ); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomValue, DataType.DATE_NANOS, "value"), + new TestCaseSupplier.TypedData(randomOrigin, DataType.DATE_NANOS, "origin").forceLiteral(), + new TestCaseSupplier.TypedData(randomScale, DataType.TIME_DURATION, "scale").forceLiteral(), + new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options") + .forceLiteral() + ), + startsWith("DecayDateNanosEvaluator["), + DataType.DOUBLE, + closeTo(scoreScriptNumericResult, 1e-10) + ); + }) + ); + } + + private static double dateNanosDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) { + long valueMillis = value / 1_000_000L; + long originMillis = origin / 1_000_000L; + + String originStr = String.valueOf(originMillis); + String scaleStr = scale + "ms"; + String offsetStr = offset + "ms"; + + ZonedDateTime valueDateTime = Instant.ofEpochMilli(valueMillis).atZone(ZoneId.of("UTC")); + + return switch (type) { + case "linear" -> new ScoreScriptUtils.DecayDateLinear(originStr, scaleStr, offsetStr, decay).decayDateLinear(valueDateTime); + case "gauss" -> new ScoreScriptUtils.DecayDateGauss(originStr, scaleStr, offsetStr, decay).decayDateGauss(valueDateTime); + case "exp" -> new ScoreScriptUtils.DecayDateExp(originStr, scaleStr, offsetStr, decay).decayDateExp(valueDateTime); + default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]"); + }; + } + + private static MapExpression createOptionsMap(Object offset, Double decay, String functionType) { + List keyValuePairs = new ArrayList<>(); + + // Offset + if (Objects.nonNull(offset)) { + keyValuePairs.add(Literal.keyword(Source.EMPTY, "offset")); + switch (offset) { + case Integer value -> keyValuePairs.add(Literal.integer(Source.EMPTY, value)); + case Long value -> keyValuePairs.add(Literal.fromLong(Source.EMPTY, value)); + case Double value -> keyValuePairs.add(Literal.fromDouble(Source.EMPTY, value)); + case String value -> keyValuePairs.add(Literal.text(Source.EMPTY, value)); + case Duration value -> keyValuePairs.add(Literal.timeDuration(Source.EMPTY, value)); + default -> { + } + } + } + + // Decay + if (Objects.nonNull(decay)) { + keyValuePairs.add(Literal.keyword(Source.EMPTY, "decay")); + keyValuePairs.add(Literal.fromDouble(Source.EMPTY, decay)); + } + + // Type + if (Objects.nonNull(functionType)) { + keyValuePairs.add(Literal.keyword(Source.EMPTY, "type")); + keyValuePairs.add(Literal.keyword(Source.EMPTY, functionType)); + } + + return new MapExpression(Source.EMPTY, keyValuePairs); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 82119fb7baa82..4fded76b09076 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -168,6 +168,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultAnalyzer; import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; @@ -8994,4 +8995,36 @@ public void testTranslateMetricsGroupedByTBucketInTSMode() { Bucket bucket = as(Alias.unwrap(eval.fields().get(0)), Bucket.class); assertThat(Expressions.attribute(bucket.field()).name(), equalTo("@timestamp")); } + + public void testDecayOriginMustBeLiteral() { + assumeTrue("requires DECAY_FUNCTION capability enabled", EsqlCapabilities.Cap.DECAY_FUNCTION.isEnabled()); + + var query = """ + FROM employees + | EVAL decay_result = decay(salary, salary, 10, {"offset": 5, "decay": 0.5, "type": "linear"}) + | KEEP decay_result + | LIMIT 5"""; + + Exception e = expectThrows( + VerificationException.class, + () -> logicalOptimizer.optimize(defaultAnalyzer().analyze(parser.createStatement(query, EsqlTestUtils.TEST_CFG))) + ); + assertThat(e.getMessage(), containsString("has non-literal value [origin]")); + } + + public void testDecayScaleMustBeLiteral() { + assumeTrue("requires DECAY_FUNCTION capability enabled", EsqlCapabilities.Cap.DECAY_FUNCTION.isEnabled()); + + var query = """ + FROM employees + | EVAL decay_result = decay(salary, 10, salary, {"offset": 5, "decay": 0.5, "type": "linear"}) + | KEEP decay_result + | LIMIT 5"""; + + Exception e = expectThrows( + VerificationException.class, + () -> logicalOptimizer.optimize(defaultAnalyzer().analyze(parser.createStatement(query, EsqlTestUtils.TEST_CFG))) + ); + assertThat(e.getMessage(), containsString("has non-literal value [scale]")); + } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index f670b7e639764..a2840648b444e 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -129,7 +129,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 171} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 172} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version":