Skip to content

Commit 92e2877

Browse files
committed
Add random test cases (comparing to ScoreScriptUtils) to DecayTests, where applicable
1 parent 393248d commit 92e2877

File tree

1 file changed

+292
-0
lines changed
  • x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score

1 file changed

+292
-0
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/DecayTests.java

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1212

13+
import org.elasticsearch.common.geo.GeoPoint;
14+
import org.elasticsearch.common.unit.DistanceUnit;
15+
import org.elasticsearch.script.ScoreScriptUtils;
1316
import org.elasticsearch.xpack.esql.core.expression.Expression;
1417
import org.elasticsearch.xpack.esql.core.expression.Literal;
1518
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
@@ -19,8 +22,10 @@
1922
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
2023

2124
import java.time.Duration;
25+
import java.time.Instant;
2226
import java.time.LocalDateTime;
2327
import java.time.ZoneId;
28+
import java.time.ZonedDateTime;
2429
import java.util.ArrayList;
2530
import java.util.List;
2631
import java.util.Objects;
@@ -65,6 +70,9 @@ public static Iterable<Object[]> parameters() {
6570
// Int defaults
6671
testCaseSuppliers.addAll(intTestCase(10, 0, 10, null, null, null, 0.5));
6772

73+
// Int random
74+
testCaseSuppliers.addAll(intRandomTestCases());
75+
6876
// Long Linear
6977
testCaseSuppliers.addAll(longTestCase(0L, 10L, 10000000L, 200L, 0.33, "linear", 1.0));
7078
testCaseSuppliers.addAll(longTestCase(10L, 10L, 10000000L, 200L, 0.33, "linear", 1.0));
@@ -89,6 +97,9 @@ public static Iterable<Object[]> parameters() {
8997
// Long defaults
9098
testCaseSuppliers.addAll(longTestCase(10L, 0L, 10L, null, null, null, 0.5));
9199

100+
// Long random
101+
testCaseSuppliers.addAll(longRandomTestCases());
102+
92103
// Double Linear
93104
testCaseSuppliers.addAll(doubleTestCase(0.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 1.0));
94105
testCaseSuppliers.addAll(doubleTestCase(10.0, 10.0, 10000000.0, 200.0, 0.25, "linear", 1.0));
@@ -113,6 +124,9 @@ public static Iterable<Object[]> parameters() {
113124
// Double defaults
114125
testCaseSuppliers.addAll(doubleTestCase(10.0, 0.0, 10.0, null, null, null, 0.5));
115126

127+
// Double random
128+
testCaseSuppliers.addAll(doubleRandomTestCases());
129+
116130
// GeoPoint Linear
117131
testCaseSuppliers.addAll(geoPointTestCase("POINT (1.0 1.0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 1.0));
118132
testCaseSuppliers.addAll(geoPointTestCase("POINT (0 0)", "POINT (1 1)", "10000km", "10km", 0.33, "linear", 0.9901342769495362));
@@ -155,6 +169,9 @@ public static Iterable<Object[]> parameters() {
155169
// GeoPoint defaults
156170
testCaseSuppliers.addAll(geoPointTestCase("POINT (12.3 45.6)", "POINT (1 1)", "10000km", null, null, null, 0.7459413262379005));
157171

172+
// GeoPoint random
173+
testCaseSuppliers.addAll(geoPointRandomTestCases());
174+
158175
// CartesianPoint Linear
159176
testCaseSuppliers.addAll(cartesianPointTestCase("POINT (0 0)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 1.0));
160177
testCaseSuppliers.addAll(cartesianPointTestCase("POINT (1 1)", "POINT (1 1)", 10000.0, 10.0, 0.33, "linear", 1.0));
@@ -381,6 +398,9 @@ public static Iterable<Object[]> parameters() {
381398
)
382399
);
383400

401+
// Datetime random
402+
testCaseSuppliers.addAll(datetimeRandomTestCases());
403+
384404
// Datenanos Linear
385405
testCaseSuppliers.addAll(
386406
dateNanosTestCase(
@@ -606,6 +626,52 @@ private static List<TestCaseSupplier> intTestCase(
606626
);
607627
}
608628

629+
private static List<TestCaseSupplier> intRandomTestCases() {
630+
return List.of(new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER, DataType.INTEGER, DataType.SOURCE), () -> {
631+
int randomValue = randomInt();
632+
int randomOrigin = randomInt();
633+
int randomScale = randomInt();
634+
int randomOffset = randomInt();
635+
double randomDecay = randomDouble();
636+
String randomType = getRandomType();
637+
638+
double scoreScriptNumericResult = intDecayWithScoreScript(
639+
randomValue,
640+
randomOrigin,
641+
randomScale,
642+
randomOffset,
643+
randomDecay,
644+
randomType
645+
);
646+
647+
return new TestCaseSupplier.TestCase(
648+
List.of(
649+
new TestCaseSupplier.TypedData(randomValue, DataType.INTEGER, "value"),
650+
new TestCaseSupplier.TypedData(randomOrigin, DataType.INTEGER, "origin").forceLiteral(),
651+
new TestCaseSupplier.TypedData(randomScale, DataType.INTEGER, "scale").forceLiteral(),
652+
new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options")
653+
.forceLiteral()
654+
),
655+
startsWith("DecayIntEvaluator["),
656+
DataType.DOUBLE,
657+
equalTo(scoreScriptNumericResult)
658+
);
659+
}));
660+
}
661+
662+
private static String getRandomType() {
663+
return randomFrom("linear", "gauss", "exp");
664+
}
665+
666+
private static double intDecayWithScoreScript(int value, int origin, int scale, int offset, double decay, String type) {
667+
return switch (type) {
668+
case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
669+
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
670+
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
671+
default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]");
672+
};
673+
}
674+
609675
private static List<TestCaseSupplier> longTestCase(
610676
long value,
611677
long origin,
@@ -634,6 +700,48 @@ private static List<TestCaseSupplier> longTestCase(
634700
);
635701
}
636702

703+
private static List<TestCaseSupplier> longRandomTestCases() {
704+
return List.of(new TestCaseSupplier(List.of(DataType.LONG, DataType.LONG, DataType.LONG, DataType.SOURCE), () -> {
705+
long randomValue = randomLong();
706+
long randomOrigin = randomLong();
707+
long randomScale = randomLong();
708+
long randomOffset = randomLong();
709+
double randomDecay = randomDouble();
710+
String randomType = randomFrom("linear", "gauss", "exp");
711+
712+
double scoreScriptNumericResult = longDecayWithScoreScript(
713+
randomValue,
714+
randomOrigin,
715+
randomScale,
716+
randomOffset,
717+
randomDecay,
718+
randomType
719+
);
720+
721+
return new TestCaseSupplier.TestCase(
722+
List.of(
723+
new TestCaseSupplier.TypedData(randomValue, DataType.LONG, "value"),
724+
new TestCaseSupplier.TypedData(randomOrigin, DataType.LONG, "origin").forceLiteral(),
725+
new TestCaseSupplier.TypedData(randomScale, DataType.LONG, "scale").forceLiteral(),
726+
new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options")
727+
.forceLiteral()
728+
),
729+
startsWith("DecayLongEvaluator["),
730+
DataType.DOUBLE,
731+
equalTo(scoreScriptNumericResult)
732+
);
733+
}));
734+
}
735+
736+
private static double longDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) {
737+
return switch (type) {
738+
case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
739+
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
740+
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
741+
default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]");
742+
};
743+
}
744+
637745
private static List<TestCaseSupplier> doubleTestCase(
638746
double value,
639747
double origin,
@@ -662,6 +770,48 @@ private static List<TestCaseSupplier> doubleTestCase(
662770
);
663771
}
664772

773+
private static List<TestCaseSupplier> doubleRandomTestCases() {
774+
return List.of(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE, DataType.DOUBLE, DataType.SOURCE), () -> {
775+
double randomValue = randomLong();
776+
double randomOrigin = randomLong();
777+
double randomScale = randomLong();
778+
double randomOffset = randomLong();
779+
double randomDecay = randomDouble();
780+
String randomType = randomFrom("linear", "gauss", "exp");
781+
782+
double scoreScriptNumericResult = doubleDecayWithScoreScript(
783+
randomValue,
784+
randomOrigin,
785+
randomScale,
786+
randomOffset,
787+
randomDecay,
788+
randomType
789+
);
790+
791+
return new TestCaseSupplier.TestCase(
792+
List.of(
793+
new TestCaseSupplier.TypedData(randomValue, DataType.DOUBLE, "value"),
794+
new TestCaseSupplier.TypedData(randomOrigin, DataType.DOUBLE, "origin").forceLiteral(),
795+
new TestCaseSupplier.TypedData(randomScale, DataType.DOUBLE, "scale").forceLiteral(),
796+
new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options")
797+
.forceLiteral()
798+
),
799+
startsWith("DecayDoubleEvaluator["),
800+
DataType.DOUBLE,
801+
equalTo(scoreScriptNumericResult)
802+
);
803+
}));
804+
}
805+
806+
private static double doubleDecayWithScoreScript(double value, double origin, double scale, double offset, double decay, String type) {
807+
return switch (type) {
808+
case "linear" -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
809+
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
810+
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
811+
default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]");
812+
};
813+
}
814+
665815
private static List<TestCaseSupplier> geoPointTestCase(
666816
String valueWkt,
667817
String originWkt,
@@ -718,6 +868,91 @@ private static List<TestCaseSupplier> geoPointTestCaseKeywordScale(
718868
);
719869
}
720870

871+
private static List<TestCaseSupplier> geoPointRandomTestCases() {
872+
return List.of(new TestCaseSupplier(List.of(DataType.GEO_POINT, DataType.GEO_POINT, DataType.KEYWORD, DataType.SOURCE), () -> {
873+
GeoPoint randomValue = randomGeoPoint();
874+
GeoPoint randomOrigin = randomGeoPoint();
875+
String randomScale = randomDistance();
876+
String randomOffset = randomDistance();
877+
double randomDecay = randomDouble();
878+
String randomType = randomDecayType();
879+
880+
double scoreScriptNumericResult = geoPointDecayWithScoreScript(
881+
randomValue,
882+
randomOrigin,
883+
randomScale,
884+
randomOffset,
885+
randomDecay,
886+
randomType
887+
);
888+
889+
return new TestCaseSupplier.TestCase(
890+
List.of(
891+
new TestCaseSupplier.TypedData(GEO.wktToWkb(randomValue.toWKT()), DataType.GEO_POINT, "value"),
892+
new TestCaseSupplier.TypedData(GEO.wktToWkb(randomOrigin.toWKT()), DataType.GEO_POINT, "origin").forceLiteral(),
893+
new TestCaseSupplier.TypedData(randomScale, DataType.KEYWORD, "scale").forceLiteral(),
894+
new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options")
895+
.forceLiteral()
896+
),
897+
startsWith("DecayGeoPointEvaluator["),
898+
DataType.DOUBLE,
899+
equalTo(scoreScriptNumericResult)
900+
);
901+
}));
902+
}
903+
904+
private static String randomDecayType() {
905+
return randomFrom("linear", "gauss", "exp");
906+
}
907+
908+
private static GeoPoint randomGeoPoint() {
909+
return new GeoPoint(randomLatitude(), randomLongitude());
910+
}
911+
912+
private static double randomLongitude() {
913+
return randomDoubleBetween(-180.0, 180.0, true);
914+
}
915+
916+
private static double randomLatitude() {
917+
return randomDoubleBetween(-90.0, 90.0, true);
918+
}
919+
920+
private static String randomDistance() {
921+
return String.format(
922+
"%d%s",
923+
randomNonNegativeInt(),
924+
randomFrom(
925+
DistanceUnit.INCH,
926+
DistanceUnit.YARD,
927+
DistanceUnit.FEET,
928+
DistanceUnit.KILOMETERS,
929+
DistanceUnit.NAUTICALMILES,
930+
DistanceUnit.MILLIMETERS,
931+
DistanceUnit.CENTIMETERS,
932+
DistanceUnit.MILES,
933+
DistanceUnit.METERS
934+
)
935+
);
936+
}
937+
938+
private static double geoPointDecayWithScoreScript(
939+
GeoPoint value,
940+
GeoPoint origin,
941+
String scale,
942+
String offset,
943+
double decay,
944+
String type
945+
) {
946+
String originStr = origin.getX() + "," + origin.getY();
947+
948+
return switch (type) {
949+
case "linear" -> new ScoreScriptUtils.DecayGeoLinear(originStr, scale, offset, decay).decayGeoLinear(value);
950+
case "gauss" -> new ScoreScriptUtils.DecayGeoGauss(originStr, scale, offset, decay).decayGeoGauss(value);
951+
case "exp" -> new ScoreScriptUtils.DecayGeoExp(originStr, scale, offset, decay).decayGeoExp(value);
952+
default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]");
953+
};
954+
}
955+
721956
private static List<TestCaseSupplier> geoPointOffsetKeywordTestCase(
722957
String valueWkt,
723958
String originWkt,
@@ -802,6 +1037,63 @@ private static List<TestCaseSupplier> datetimeTestCase(
8021037
);
8031038
}
8041039

1040+
private static List<TestCaseSupplier> datetimeRandomTestCases() {
1041+
return List.of(new TestCaseSupplier(List.of(DataType.DATETIME, DataType.DATETIME, DataType.TIME_DURATION, DataType.SOURCE), () -> {
1042+
// 1970-01-01
1043+
long minEpoch = 0L;
1044+
// 2070-01-01
1045+
long maxEpoch = 3155673600000L;
1046+
long randomValue = randomLongBetween(minEpoch, maxEpoch);
1047+
long randomOrigin = randomLongBetween(minEpoch, maxEpoch);
1048+
1049+
// Max 1 year
1050+
long randomScaleMillis = randomNonNegativeLong() % (365L * 24 * 60 * 60 * 1000);
1051+
// Max 30 days
1052+
long randomOffsetMillis = randomNonNegativeLong() % (30L * 24 * 60 * 60 * 1000);
1053+
Duration randomScale = Duration.ofMillis(randomScaleMillis);
1054+
Duration randomOffset = Duration.ofMillis(randomOffsetMillis);
1055+
double randomDecay = randomDouble();
1056+
String randomType = randomFrom("linear", "gauss", "exp");
1057+
1058+
double scoreScriptNumericResult = datetimeDecayWithScoreScript(
1059+
randomValue,
1060+
randomOrigin,
1061+
randomScale.toMillis(),
1062+
randomOffset.toMillis(),
1063+
randomDecay,
1064+
randomType
1065+
);
1066+
1067+
return new TestCaseSupplier.TestCase(
1068+
List.of(
1069+
new TestCaseSupplier.TypedData(randomValue, DataType.DATETIME, "value"),
1070+
new TestCaseSupplier.TypedData(randomOrigin, DataType.DATETIME, "origin").forceLiteral(),
1071+
new TestCaseSupplier.TypedData(randomScale, DataType.TIME_DURATION, "scale").forceLiteral(),
1072+
new TestCaseSupplier.TypedData(createOptionsMap(randomOffset, randomDecay, randomType), DataType.SOURCE, "options")
1073+
.forceLiteral()
1074+
),
1075+
startsWith("DecayDatetimeEvaluator["),
1076+
DataType.DOUBLE,
1077+
equalTo(scoreScriptNumericResult)
1078+
);
1079+
}));
1080+
}
1081+
1082+
private static double datetimeDecayWithScoreScript(long value, long origin, long scale, long offset, double decay, String type) {
1083+
String originStr = String.valueOf(origin);
1084+
String scaleStr = scale + "ms";
1085+
String offsetStr = offset + "ms";
1086+
1087+
ZonedDateTime valueDateTime = Instant.ofEpochMilli(value).atZone(ZoneId.of("UTC"));
1088+
1089+
return switch (type) {
1090+
case "linear" -> new ScoreScriptUtils.DecayDateLinear(originStr, scaleStr, offsetStr, decay).decayDateLinear(valueDateTime);
1091+
case "gauss" -> new ScoreScriptUtils.DecayDateGauss(originStr, scaleStr, offsetStr, decay).decayDateGauss(valueDateTime);
1092+
case "exp" -> new ScoreScriptUtils.DecayDateExp(originStr, scaleStr, offsetStr, decay).decayDateExp(valueDateTime);
1093+
default -> throw new IllegalArgumentException("Unknown decay function type [" + type + "]");
1094+
};
1095+
}
1096+
8051097
private static List<TestCaseSupplier> dateNanosTestCase(
8061098
long value,
8071099
long origin,

0 commit comments

Comments
 (0)