Skip to content

Commit b2e3c7e

Browse files
committed
Small refactoring.
1 parent 56d623e commit b2e3c7e

File tree

5 files changed

+44
-27
lines changed

5 files changed

+44
-27
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@
1717

1818
import java.util.List;
1919
import java.util.Objects;
20+
import java.util.function.Supplier;
2021
import java.util.stream.Stream;
2122

2223
public abstract class InferenceFunction extends Function {
23-
24-
private Expression inferenceId;
25-
24+
private final Expression inferenceId;
2625
private final Expression options;
2726

2827
@SuppressWarnings("this-escape")
2928
protected InferenceFunction(Source source, List<Expression> children, Expression options) {
3029
super(source, Stream.concat(children.stream(), Stream.of(options)).toList());
31-
this.inferenceId = parseInferenceId(options);
30+
this.inferenceId = parseInferenceId(options, this::defaultInferenceId);
3231
this.options = options;
3332
}
3433

@@ -40,22 +39,6 @@ public Expression options() {
4039
return options;
4140
}
4241

43-
protected abstract Expression parseInferenceId(Expression options);
44-
45-
public abstract List<Attribute> temporaryAttributes();
46-
47-
protected Expression readOption(String optionName, Expression options) {
48-
return readOption(optionName, options, Literal.NULL);
49-
}
50-
51-
protected Expression readOption(String optionName, Expression options, Expression defaultValue) {
52-
if (options != null && options.dataType() != DataType.NULL && options instanceof MapExpression mapOptions) {
53-
return mapOptions.getOrDefault(optionName, defaultValue);
54-
}
55-
56-
return defaultValue;
57-
}
58-
5942
@Override
6043
protected TypeResolution resolveType() {
6144
if (childrenResolved() == false) {
@@ -65,6 +48,10 @@ protected TypeResolution resolveType() {
6548
return resolveParams().and(resolveOptions());
6649
}
6750

51+
protected abstract Expression defaultInferenceId();
52+
53+
public abstract List<Attribute> temporaryAttributes();
54+
6855
protected abstract TypeResolution resolveParams();
6956

7057
protected abstract TypeResolution resolveOptions();
@@ -81,4 +68,16 @@ public boolean equals(Object o) {
8168
public int hashCode() {
8269
return Objects.hash(super.hashCode(), inferenceId, options);
8370
}
71+
72+
private static Expression parseInferenceId(Expression options, Supplier<Expression> defautlInferenceIdSupplier) {
73+
return readOption("inference_id", options, defautlInferenceIdSupplier);
74+
}
75+
76+
private static Expression readOption(String optionName, Expression options, Supplier<Expression> defaultValueSupplier) {
77+
if (options != null && options.dataType() != DataType.NULL && options instanceof MapExpression mapOptions) {
78+
return mapOptions.getOrDefault(optionName, defaultValueSupplier.get());
79+
}
80+
81+
return defaultValueSupplier.get();
82+
}
8483
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/RerankFunction.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,7 @@ public String getWriteableName() {
114114
}
115115

116116
@Override
117-
protected Expression parseInferenceId(Expression options) {
118-
return readOption("inference_id", options, defaultInferenceId());
119-
}
120-
121-
private Literal defaultInferenceId() {
117+
protected Literal defaultInferenceId() {
122118
return Literal.keyword(Source.EMPTY, org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.DEFAULT_INFERENCE_ID);
123119
}
124120

@@ -129,7 +125,7 @@ public List<Attribute> temporaryAttributes() {
129125

130126
@Override
131127
protected TypeResolution resolveParams() {
132-
return resolveField().and(resoolveQuery());
128+
return resolveField().and(resolveQuery());
133129
}
134130

135131
@Override
@@ -143,7 +139,7 @@ private TypeResolution resolveField() {
143139
);
144140
}
145141

146-
private TypeResolution resoolveQuery() {
142+
private TypeResolution resolveQuery() {
147143
return isString(query, sourceText(), TypeResolutions.ParamOrdinal.SECOND).and(
148144
isNotNull(query, sourceText(), TypeResolutions.ParamOrdinal.SECOND)
149145
).and(isFoldable(query, sourceText(), TypeResolutions.ParamOrdinal.SECOND));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/RerankFunctionErrorTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.inference;
99

10+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
1011
import org.elasticsearch.xpack.esql.core.expression.Expression;
1112
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
1213
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -15,6 +16,7 @@
1516
import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
1617
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
1718
import org.hamcrest.Matcher;
19+
import org.junit.Before;
1820

1921
import java.util.List;
2022
import java.util.Locale;
@@ -25,6 +27,12 @@
2527
import static org.hamcrest.Matchers.equalTo;
2628

2729
public class RerankFunctionErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
30+
31+
@Before
32+
public void checkCapability() {
33+
assumeTrue("RERANK_FUNCTION is not enabled", EsqlCapabilities.Cap.RERANK_FUNCTION.isEnabled());
34+
}
35+
2836
@Override
2937
protected List<TestCaseSupplier> cases() {
3038
return paramsToSuppliers(RerankFunctionTests.parameters());

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/RerankFunctionSerializationTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,22 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.inference;
99

10+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
1011
import org.elasticsearch.xpack.esql.core.expression.Expression;
1112
import org.elasticsearch.xpack.esql.core.tree.Source;
1213
import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;
14+
import org.junit.Before;
1315

1416
import java.io.IOException;
1517
import java.util.ArrayList;
1618
import java.util.List;
1719

1820
public class RerankFunctionSerializationTests extends AbstractExpressionSerializationTests<RerankFunction> {
21+
@Before
22+
public void checkCapability() {
23+
assumeTrue("RERANK_FUNCTION is not enabled", EsqlCapabilities.Cap.RERANK_FUNCTION.isEnabled());
24+
}
25+
1926
@Override
2027
protected RerankFunction createTestInstance() {
2128
Source source = randomSource();

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/RerankFunctionTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.apache.lucene.util.BytesRef;
1414
import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression;
15+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
1516
import org.elasticsearch.xpack.esql.core.expression.Expression;
1617
import org.elasticsearch.xpack.esql.core.expression.Literal;
1718
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
@@ -21,6 +22,7 @@
2122
import org.elasticsearch.xpack.esql.expression.function.FunctionName;
2223
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
2324
import org.hamcrest.Matchers;
25+
import org.junit.Before;
2426

2527
import java.util.ArrayList;
2628
import java.util.Arrays;
@@ -35,6 +37,11 @@
3537

3638
@FunctionName("rerank")
3739
public class RerankFunctionTests extends AbstractFunctionTestCase {
40+
@Before
41+
public void checkCapability() {
42+
assumeTrue("RERANK_FUNCTION is not enabled", EsqlCapabilities.Cap.RERANK_FUNCTION.isEnabled());
43+
}
44+
3845
public RerankFunctionTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
3946
this.testCase = testCaseSupplier.get();
4047
}

0 commit comments

Comments
 (0)