Skip to content

Commit 0b71d74

Browse files
committed
Introducing inference function.
1 parent d86d16b commit 0b71d74

File tree

9 files changed

+387
-2
lines changed

9 files changed

+387
-2
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,20 @@ public int hashCode() {
115115
return Objects.hash(entryExpressions);
116116
}
117117

118+
public Expression getOrDefault(Object key, Expression defaultValue) {
119+
return containsKey(key) ? get(key) : defaultValue;
120+
}
121+
122+
public boolean containsKey(Object key) {
123+
return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(getKeyAsBytesRef(key));
124+
}
125+
118126
public Expression get(Object key) {
119127
if (key instanceof Expression) {
120128
return map.get(key);
121129
} else {
122130
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
123-
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
131+
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(getKeyAsBytesRef(key));
124132
}
125133
}
126134

@@ -142,4 +150,8 @@ public String toString() {
142150
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
143151
return "{ " + str + " }";
144152
}
153+
154+
private BytesRef getKeyAsBytesRef(Object key) {
155+
return new BytesRef(key.toString());
156+
}
145157
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,11 @@ public enum Cap {
927927
*/
928928
RERANK(Build.current().isSnapshot()),
929929

930+
/**
931+
* Support for RERANK as a function
932+
*/
933+
RERANK_FUNCTION(Build.current().isSnapshot()),
934+
930935
/**
931936
* Support for COMPLETION command
932937
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
1414
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
1515
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
16+
import org.elasticsearch.xpack.esql.expression.function.inference.RerankFunction;
1617
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
1718
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
1819
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -119,6 +120,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
119120
entries.addAll(fullText());
120121
entries.addAll(unaryScalars());
121122
entries.addAll(vector());
123+
entries.addAll(inferenceFunctions());
122124
return entries;
123125
}
124126

@@ -264,4 +266,11 @@ private static List<NamedWriteableRegistry.Entry> vector() {
264266
}
265267
return List.of();
266268
}
269+
270+
private static List<NamedWriteableRegistry.Entry> inferenceFunctions() {
271+
if (EsqlCapabilities.Cap.RERANK_FUNCTION.isEnabled()) {
272+
return List.of(RerankFunction.ENTRY);
273+
}
274+
return List.of();
275+
}
267276
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.logging.LogManager;
11+
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
12+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
13+
import org.elasticsearch.xpack.esql.core.expression.Expression;
14+
import org.elasticsearch.xpack.esql.core.expression.Literal;
15+
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
16+
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
17+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
18+
import org.elasticsearch.xpack.esql.core.tree.Source;
19+
import org.elasticsearch.xpack.esql.core.type.DataType;
20+
21+
import java.util.List;
22+
import java.util.Objects;
23+
import java.util.stream.Stream;
24+
25+
public abstract class InferenceFunction extends Function {
26+
27+
private Expression inferenceId;
28+
29+
private final Expression options;
30+
31+
@SuppressWarnings("this-escape")
32+
protected InferenceFunction(Source source, List<Expression> children, Expression options) {
33+
super(source, Stream.concat(children.stream(), Stream.of(options)).toList());
34+
this.inferenceId = parseInferenceId(options);
35+
this.options = options;
36+
}
37+
38+
public Expression inferenceId() {
39+
return inferenceId;
40+
}
41+
42+
public Expression options() {
43+
return options;
44+
}
45+
46+
protected abstract Expression parseInferenceId(Expression options);
47+
48+
public abstract List<Attribute> temporaryAttributes();
49+
50+
protected Expression readOption(String optionName, TypeResolutions.ParamOrdinal optionParamOrd, Expression options) {
51+
return readOption(optionName, optionParamOrd, options, Literal.NULL);
52+
}
53+
54+
protected Expression readOption(String optionName, TypeResolutions.ParamOrdinal optionParamOrd, Expression options, Expression defaultValue) {
55+
if (options != null && options.dataType() != DataType.NULL && options instanceof MapExpression mapOptions) {
56+
return mapOptions.getOrDefault(optionName, defaultValue);
57+
}
58+
59+
return defaultValue;
60+
}
61+
62+
@Override
63+
protected TypeResolution resolveType() {
64+
if (childrenResolved() == false) {
65+
return new TypeResolution("Unresolved children");
66+
}
67+
68+
return resolveOptions().and(resolveParams());
69+
}
70+
71+
protected abstract TypeResolution resolveParams();
72+
73+
protected abstract TypeResolution resolveOptions();
74+
75+
@Override
76+
public boolean equals(Object o) {
77+
if (o == null || getClass() != o.getClass()) return false;
78+
if (!super.equals(o)) return false;
79+
InferenceFunction that = (InferenceFunction) o;
80+
return Objects.equals(inferenceId, that.inferenceId) && Objects.equals(options, that.options);
81+
}
82+
83+
@Override
84+
public int hashCode() {
85+
return Objects.hash(super.hashCode(), inferenceId, options);
86+
}
87+
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package org.elasticsearch.xpack.esql.expression.function.inference;
2+
3+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
4+
import org.elasticsearch.common.io.stream.StreamInput;
5+
import org.elasticsearch.common.io.stream.StreamOutput;
6+
import org.elasticsearch.xpack.esql.core.expression.*;
7+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
8+
import org.elasticsearch.xpack.esql.core.tree.Source;
9+
import org.elasticsearch.xpack.esql.core.type.DataType;
10+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
11+
import org.elasticsearch.xpack.esql.expression.function.MapParam;
12+
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
13+
import org.elasticsearch.xpack.esql.expression.function.Param;
14+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
15+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
16+
17+
import java.io.IOException;
18+
import java.util.List;
19+
import java.util.Objects;
20+
import java.util.UUID;
21+
22+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
23+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
24+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
25+
26+
public class RerankFunction extends InferenceFunction implements OptionalArgument {
27+
28+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Rerank", RerankFunction::new);
29+
30+
private final Expression query;
31+
private final Expression field;
32+
private final Attribute tmpAttribute;
33+
34+
@FunctionInfo(
35+
returnType = "double",
36+
description = "Compute text similarity score using an inference model."
37+
)
38+
public RerankFunction(
39+
Source source,
40+
@Param(name = "field", type = { "keyword", "text" }, description = "Field used as input of the reranker") Expression field,
41+
@Param(name = "query", type = { "keyword", "text" }, description = "The query") Expression query,
42+
@MapParam(
43+
name = "options",
44+
params = {
45+
@MapParam.MapParamEntry(
46+
name = "inference_id",
47+
type = "keyword",
48+
valueHint = { ".rerank-v1-elasticsearch" },
49+
description = "Reranker inference endpoint to use."
50+
)
51+
},
52+
optional = true
53+
) Expression options
54+
) {
55+
this(source, field, query, options, new ReferenceAttribute(Source.EMPTY, ENTRY.name + "_" + UUID.randomUUID(), DataType.DOUBLE));
56+
}
57+
58+
59+
private RerankFunction(Source source, Expression field, Expression query, Expression options, Attribute tmpAttribute) {
60+
super(source, List.of(field, query), options);
61+
this.query = query;
62+
this.field = field;
63+
this.tmpAttribute = tmpAttribute;
64+
}
65+
66+
public RerankFunction(StreamInput in) throws IOException {
67+
this(
68+
Source.readFrom((PlanStreamInput) in),
69+
in.readNamedWriteable(Expression.class),
70+
in.readNamedWriteable(Expression.class),
71+
in.readNamedWriteable(Expression.class),
72+
in.readNamedWriteable(Attribute.class)
73+
);
74+
}
75+
76+
@Override
77+
public void writeTo(StreamOutput out) throws IOException {
78+
source().writeTo(out);
79+
out.writeNamedWriteable(field);
80+
out.writeNamedWriteable(query);
81+
out.writeNamedWriteable(options());
82+
out.writeNamedWriteable(tmpAttribute);
83+
}
84+
85+
@Override
86+
public String functionName() {
87+
return getWriteableName();
88+
}
89+
90+
@Override
91+
public DataType dataType() {
92+
return DataType.DOUBLE;
93+
}
94+
95+
@Override
96+
public Expression replaceChildren(List<Expression> newChildren) {
97+
return new RerankFunction(source(), newChildren.get(0), newChildren.get(1), options(), tmpAttribute);
98+
}
99+
100+
@Override
101+
protected NodeInfo<? extends Expression> info() {
102+
return NodeInfo.create(this, RerankFunction::new, query, field, options(), tmpAttribute);
103+
}
104+
105+
@Override
106+
public String getWriteableName() {
107+
return ENTRY.name;
108+
}
109+
110+
@Override
111+
protected Expression parseInferenceId(Expression options) {
112+
return readOption("inference_id", TypeResolutions.ParamOrdinal.THIRD, options, defaultInferenceId());
113+
}
114+
115+
private Literal defaultInferenceId() {
116+
return new Literal(Source.EMPTY, Rerank.DEFAULT_INFERENCE_ID, DataType.KEYWORD);
117+
}
118+
119+
@Override
120+
public List<Attribute> temporaryAttributes() {
121+
return List.of(tmpAttribute);
122+
}
123+
124+
@Override
125+
protected TypeResolution resolveParams() {
126+
return resolveField().and(resoolveQueru());
127+
}
128+
129+
@Override
130+
protected TypeResolution resolveOptions() {
131+
return TypeResolution.TYPE_RESOLVED;
132+
}
133+
134+
private TypeResolution resolveField() {
135+
return isString(field, functionName(), TypeResolutions.ParamOrdinal.FIRST)
136+
.and(isNotNull(field, functionName(), TypeResolutions.ParamOrdinal.FIRST));
137+
}
138+
139+
private TypeResolution resoolveQueru() {
140+
return isString(query, functionName(), TypeResolutions.ParamOrdinal.SECOND)
141+
.and(isNotNull(query, functionName(), TypeResolutions.ParamOrdinal.SECOND))
142+
.and(isFoldable(query, functionName(), TypeResolutions.ParamOrdinal.SECOND));
143+
}
144+
145+
@Override
146+
public boolean equals(Object o) {
147+
if (this == o) return true;
148+
if (o == null || getClass() != o.getClass()) return false;
149+
if (!super.equals(o)) return false;
150+
RerankFunction that = (RerankFunction) o;
151+
return Objects.equals(query, that.query) && Objects.equals(field, that.field) && Objects.equals(tmpAttribute, that.tmpAttribute);
152+
}
153+
154+
@Override
155+
public int hashCode() {
156+
return Objects.hash(super.hashCode(), query, field, tmpAttribute);
157+
}
158+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ErrorsForCasesWithoutExamplesTestCase.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ public final void test() {
7272
List<Set<DataType>> validPerPosition = AbstractFunctionTestCase.validPerPosition(valid);
7373
Iterable<List<DataType>> testCandidates = testCandidates(cases, valid)::iterator;
7474
for (List<DataType> signature : testCandidates) {
75-
logger.debug("checking {}", signature);
7675
List<Expression> args = new ArrayList<>(signature.size());
7776
for (DataType type : signature) {
7877
args.add(randomLiteral(type));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,10 @@ public TestCase(List<TypedData> data, Matcher<String> evaluatorToString, DataTyp
14521452
this(data, evaluatorToString, expectedType, matcher, null, null, null, null, null, null);
14531453
}
14541454

1455+
/**
 * Builds a test case from its input data only. Every other argument of the
 * delegated constructor is passed as {@code null}, except the last one, which
 * is {@code false}.
 *
 * @param data the typed input values of the test case
 */
public TestCase(List<TypedData> data) {
    this(
        data,
        null,
        null,
        null,
        null,
        null,
        null,
        null,
        null,
        false
    );
}
1458+
14551459
/**
14561460
* Build a test case for type errors.
14571461
*
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.logging.LogManager;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.tree.Source;
13+
import org.elasticsearch.xpack.esql.core.type.DataType;
14+
import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
15+
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
16+
import org.hamcrest.Matcher;
17+
18+
import java.util.List;
19+
import java.util.Set;
20+
import java.util.stream.Stream;
21+
22+
import static org.hamcrest.Matchers.equalTo;
23+
24+
public class RerankFunctionErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
25+
@Override
26+
protected List<TestCaseSupplier> cases() {
27+
return paramsToSuppliers(RerankFunctionTests.parameters());
28+
}
29+
30+
@Override
31+
protected Stream<List<DataType>> testCandidates(List<TestCaseSupplier> cases, Set<List<DataType>> valid) {
32+
// Don't test null, as it is not allowed but the expected message is not a type error - so we check it separately in VerifierTests
33+
return super.testCandidates(cases, valid).filter(sig -> false == sig.contains(DataType.NULL));
34+
}
35+
36+
@Override
37+
protected Expression build(Source source, List<Expression> args) {
38+
LogManager.getLogger(RerankFunctionErrorTests.class).error("{}", args);
39+
return new RerankFunction(source, args.get(0), args.get(1), args.get(2));
40+
}
41+
42+
@Override
43+
protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
44+
return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> switch (p) {
45+
case 0, 1 -> "string";
46+
default -> "";
47+
}));
48+
}
49+
}

0 commit comments

Comments
 (0)