Skip to content

Commit c723309

Browse files
committed
Use enum for DecayFunction instead of switch statements
1 parent d93e573 commit c723309

File tree

1 file changed

+102
-44
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/score

1 file changed

+102
-44
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/score/Decay.java

Lines changed: 102 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
import java.io.IOException;
4545
import java.time.Duration;
46+
import java.util.Arrays;
4647
import java.util.Collection;
4748
import java.util.HashMap;
4849
import java.util.List;
@@ -51,6 +52,7 @@
5152
import java.util.Set;
5253
import java.util.function.BiConsumer;
5354
import java.util.function.Predicate;
55+
import java.util.stream.Collectors;
5456

5557
import static org.elasticsearch.xpack.esql.common.Failure.fail;
5658
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
@@ -405,11 +407,7 @@ static double process(
405407
@Fixed double decay,
406408
@Fixed BytesRef functionType
407409
) {
408-
return switch (functionType.utf8ToString()) {
409-
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
410-
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
411-
default -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
412-
};
410+
return DecayFunction.fromBytesRef(functionType).numericDecay(value, origin, scale, offset, decay);
413411
}
414412

415413
@Evaluator(extraName = "Double")
@@ -421,11 +419,7 @@ static double process(
421419
@Fixed double decay,
422420
@Fixed BytesRef functionType
423421
) {
424-
return switch (functionType.utf8ToString()) {
425-
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
426-
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
427-
default -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
428-
};
422+
return DecayFunction.fromBytesRef(functionType).numericDecay(value, origin, scale, offset, decay);
429423
}
430424

431425
@Evaluator(extraName = "Long")
@@ -437,11 +431,8 @@ static double process(
437431
@Fixed double decay,
438432
@Fixed BytesRef functionType
439433
) {
440-
return switch (functionType.utf8ToString()) {
441-
case "exp" -> new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
442-
case "gauss" -> new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
443-
default -> new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
444-
};
434+
return DecayFunction.fromBytesRef(functionType).numericDecay(value, origin, scale, offset, decay);
435+
445436
}
446437

447438
@Evaluator(extraName = "GeoPoint")
@@ -463,11 +454,7 @@ static double process(
463454
String scaleStr = scale.utf8ToString();
464455
String offsetStr = offset.utf8ToString();
465456

466-
return switch (functionType.utf8ToString()) {
467-
case "exp" -> new ScoreScriptUtils.DecayGeoExp(originStr, scaleStr, offsetStr, decay).decayGeoExp(valueGeoPoint);
468-
case "gauss" -> new ScoreScriptUtils.DecayGeoGauss(originStr, scaleStr, offsetStr, decay).decayGeoGauss(valueGeoPoint);
469-
default -> new ScoreScriptUtils.DecayGeoLinear(originStr, scaleStr, offsetStr, decay).decayGeoLinear(valueGeoPoint);
470-
};
457+
return DecayFunction.fromBytesRef(functionType).geoPointDecay(valueGeoPoint, originStr, scaleStr, offsetStr, decay);
471458
}
472459

473460
@Evaluator(extraName = "CartesianPoint")
@@ -489,21 +476,7 @@ static double processCartesianPoint(
489476

490477
distance = Math.max(0.0, distance - offset);
491478

492-
return switch (functionType.utf8ToString()) {
493-
case "exp" -> {
494-
double scaling = Math.log(decay) / scale;
495-
yield Math.exp(scaling * distance);
496-
}
497-
case "gauss" -> {
498-
double sigmaSquared = -Math.pow(scale, 2.0) / (2.0 * Math.log(decay));
499-
yield Math.exp(-Math.pow(distance, 2.0) / (2.0 * sigmaSquared));
500-
}
501-
// linear
502-
default -> {
503-
double scaling = scale / (1.0 - decay);
504-
yield Math.max(0.0, (scaling - distance) / scaling);
505-
}
506-
};
479+
return DecayFunction.fromBytesRef(functionType).cartesianDecay(distance, scale, offset, decay);
507480
}
508481

509482
@Evaluator(extraName = "Datetime", warnExceptions = { InvalidArgumentException.class, IllegalArgumentException.class })
@@ -515,11 +488,7 @@ static double processDatetime(
515488
@Fixed double decay,
516489
@Fixed BytesRef functionType
517490
) {
518-
return switch (functionType.utf8ToString()) {
519-
case "exp" -> decayDateExp(origin, scale, offset, decay, value);
520-
case "gauss" -> decayDateGauss(origin, scale, offset, decay, value);
521-
default -> decayDateLinear(origin, scale, offset, decay, value);
522-
};
491+
return DecayFunction.fromBytesRef(functionType).temporalDecay(value, origin, scale, offset, decay);
523492
}
524493

525494
@Evaluator(extraName = "DateNanos", warnExceptions = { InvalidArgumentException.class, IllegalArgumentException.class })
@@ -531,11 +500,100 @@ static double processDateNanos(
531500
@Fixed double decay,
532501
@Fixed BytesRef functionType
533502
) {
534-
return switch (functionType.utf8ToString()) {
535-
case "exp" -> decayDateExp(origin, scale, offset, decay, value);
536-
case "gauss" -> decayDateGauss(origin, scale, offset, decay, value);
537-
default -> decayDateLinear(origin, scale, offset, decay, value);
503+
return DecayFunction.fromBytesRef(functionType).temporalDecay(value, origin, scale, offset, decay);
504+
505+
}
506+
507+
private enum DecayFunction {
508+
LINEAR("linear"){
509+
@Override
510+
public double numericDecay(double value, double origin, double scale, double offset, double decay) {
511+
return new ScoreScriptUtils.DecayNumericLinear(origin, scale, offset, decay).decayNumericLinear(value);
512+
}
513+
514+
@Override
515+
public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) {
516+
return new ScoreScriptUtils.DecayGeoLinear(origin, scale, offset, decay).decayGeoLinear(value);
517+
}
518+
519+
@Override
520+
public double cartesianDecay(double distance, double scale, double offset, double decay) {
521+
double scaling = scale / (1.0 - decay);
522+
return Math.max(0.0, (scaling - distance) / scaling);
523+
}
524+
525+
@Override
526+
public double temporalDecay(long value, long origin, long scale, long offset, double decay) {
527+
return decayDateLinear(origin, scale, offset, decay, value);
528+
}
529+
},
530+
531+
EXPONENTIAL("exp"){
532+
@Override
533+
public double numericDecay(double value, double origin, double scale, double offset, double decay) {
534+
return new ScoreScriptUtils.DecayNumericExp(origin, scale, offset, decay).decayNumericExp(value);
535+
}
536+
537+
@Override
538+
public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) {
539+
return new ScoreScriptUtils.DecayGeoExp(origin, scale, offset, decay).decayGeoExp(value);
540+
}
541+
542+
@Override
543+
public double cartesianDecay(double distance, double scale, double offset, double decay) {
544+
double scaling = Math.log(decay) / scale;
545+
return Math.exp(scaling * distance);
546+
}
547+
548+
@Override
549+
public double temporalDecay(long value, long origin, long scale, long offset, double decay) {
550+
return decayDateExp(origin, scale, offset, decay, value);
551+
}
552+
},
553+
554+
GAUSSIAN("gauss"){
555+
@Override
556+
public double numericDecay(double value, double origin, double scale, double offset, double decay) {
557+
return new ScoreScriptUtils.DecayNumericGauss(origin, scale, offset, decay).decayNumericGauss(value);
558+
}
559+
560+
@Override
561+
public double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay) {
562+
return new ScoreScriptUtils.DecayGeoGauss(origin, scale, offset, decay).decayGeoGauss(value);
563+
}
564+
565+
@Override
566+
public double cartesianDecay(double distance, double scale, double offset, double decay) {
567+
double sigmaSquared = -Math.pow(scale, 2.0) / (2.0 * Math.log(decay));
568+
return Math.exp(-Math.pow(distance, 2.0) / (2.0 * sigmaSquared));
569+
}
570+
571+
@Override
572+
public double temporalDecay(long value, long origin, long scale, long offset, double decay) {
573+
return decayDateGauss(origin, scale, offset, decay, value);
574+
}
538575
};
576+
577+
578+
private final String functionName;
579+
private static final Map<String, DecayFunction> BY_NAME = Arrays.stream(values())
580+
.collect(Collectors.toMap(df -> df.functionName, df -> df));
581+
582+
DecayFunction(String functionName) {
583+
this.functionName = functionName;
584+
}
585+
586+
public abstract double numericDecay(double value, double origin, double scale, double offset, double decay);
587+
588+
public abstract double geoPointDecay(GeoPoint value, String origin, String scale, String offset, double decay);
589+
590+
public abstract double cartesianDecay(double distance, double scale, double offset, double decay);
591+
592+
public abstract double temporalDecay(long value, long origin, long scale, long offset, double decay);
593+
594+
public static DecayFunction fromBytesRef(BytesRef functionType) {
595+
return BY_NAME.getOrDefault(functionType.utf8ToString(), LINEAR);
596+
}
539597
}
540598

541599
private static double decayDateLinear(long origin, long scale, long offset, double decay, long value) {

0 commit comments

Comments
 (0)