Skip to content

Commit 4fe169a

Browse files
committed
Text embedding analysis and verification.
1 parent 847c998 commit 4fe169a

File tree

4 files changed

+223
-4
lines changed

4 files changed

+223
-4
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
6969
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
7070
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
71+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
7172
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
7273
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
7374
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -1329,7 +1330,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log
13291330

13301331
@Override
13311332
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1332-
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
1333+
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
1334+
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
13331335
}
13341336

13351337
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
@@ -1358,6 +1360,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext
13581360

13591361
return plan;
13601362
}
1363+
1364+
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
1365+
if (inferenceFunction.inferenceId().resolved()
1366+
&& inferenceFunction.inferenceId().foldable()
1367+
&& DataType.isString(inferenceFunction.inferenceId().dataType())) {
1368+
1369+
String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
1370+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
1371+
1372+
if (resolvedInference == null) {
1373+
String error = context.inferenceResolution().getError(inferenceId);
1374+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1375+
}
1376+
1377+
if (resolvedInference.taskType() != inferenceFunction.taskType()) {
1378+
String error = "cannot use inference endpoint ["
1379+
+ inferenceId
1380+
+ "] with task type ["
1381+
+ resolvedInference.taskType()
1382+
+ "] within a "
1383+
+ context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
1384+
+ " function. Only inference endpoints with the task type ["
1385+
+ inferenceFunction.taskType()
1386+
+ "] are supported.";
1387+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1388+
}
1389+
}
1390+
1391+
return inferenceFunction;
1392+
}
13611393
}
13621394

13631395
private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.index.IndexMode;
1111
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.ESTestCase;
1213
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
1314
import org.elasticsearch.xpack.esql.EsqlTestUtils;
1415
import org.elasticsearch.xpack.esql.core.type.EsField;
@@ -26,6 +27,7 @@
2627
import org.elasticsearch.xpack.esql.session.Configuration;
2728

2829
import java.util.ArrayList;
30+
import java.util.Arrays;
2931
import java.util.LinkedHashMap;
3032
import java.util.List;
3133
import java.util.Map;
@@ -195,14 +197,39 @@ public static EnrichResolution defaultEnrichResolution() {
195197
return enrichResolution;
196198
}
197199

200+
public static final String RERANKING_INFERENCE_ID = "reranking-inference-id";
201+
public static final String COMPLETION_INFERENCE_ID = "completion-inference-id";
202+
public static final String TEXT_EMBEDDING_INFERENCE_ID = "text-embedding-inference-id";
203+
public static final String CHAT_COMPLETION_INFERENCE_ID = "chat-completion-inference-id";
204+
public static final String SPARSE_EMBEDDING_INFERENCE_ID = "sparse-embedding-inference-id";
205+
public static final List<String> VALID_INFERENCE_IDS = List.of(
206+
RERANKING_INFERENCE_ID,
207+
COMPLETION_INFERENCE_ID,
208+
TEXT_EMBEDDING_INFERENCE_ID,
209+
CHAT_COMPLETION_INFERENCE_ID,
210+
SPARSE_EMBEDDING_INFERENCE_ID
211+
);
212+
public static final String ERROR_INFERENCE_ID = "error-inference-id";
213+
198214
public static InferenceResolution defaultInferenceResolution() {
199215
return InferenceResolution.builder()
200-
.withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
201-
.withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
202-
.withError("error-inference-id", "error with inference resolution")
216+
.withResolvedInference(new ResolvedInference(RERANKING_INFERENCE_ID, TaskType.RERANK))
217+
.withResolvedInference(new ResolvedInference(COMPLETION_INFERENCE_ID, TaskType.COMPLETION))
218+
.withResolvedInference(new ResolvedInference(TEXT_EMBEDDING_INFERENCE_ID, TaskType.TEXT_EMBEDDING))
219+
.withResolvedInference(new ResolvedInference(CHAT_COMPLETION_INFERENCE_ID, TaskType.CHAT_COMPLETION))
220+
.withResolvedInference(new ResolvedInference(SPARSE_EMBEDDING_INFERENCE_ID, TaskType.SPARSE_EMBEDDING))
221+
.withError(ERROR_INFERENCE_ID, "error with inference resolution")
203222
.build();
204223
}
205224

225+
public static String randomInferenceId() {
226+
return ESTestCase.randomFrom(VALID_INFERENCE_IDS);
227+
}
228+
229+
public static String randomInferenceId(String... excludes) {
230+
return ESTestCase.randomValueOtherThanMany(Arrays.asList(excludes)::contains, AnalyzerTestUtils::randomInferenceId);
231+
}
232+
206233
public static void loadEnrichPolicyResolution(
207234
EnrichResolution enrich,
208235
String policyType,

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
5656
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
5757
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
58+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
5859
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
5960
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
6061
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -121,13 +122,15 @@
121122
import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute;
122123
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
123124
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
125+
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID;
124126
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
125127
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer;
126128
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
127129
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
128130
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
129131
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType;
130132
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
133+
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomInferenceId;
131134
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution;
132135
import static org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG;
133136
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
@@ -3629,6 +3632,115 @@ private void assertEmptyEsRelation(LogicalPlan plan) {
36293632
assertThat(esRelation.output(), equalTo(NO_FIELDS));
36303633
}
36313634

3635+
public void testTextEmbeddingResolveInferenceId() {
3636+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3637+
3638+
LogicalPlan plan = analyze(
3639+
"""
3640+
FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(
3641+
TEXT_EMBEDDING_INFERENCE_ID
3642+
),
3643+
"mapping-books.json"
3644+
);
3645+
3646+
Eval eval = as(as(plan, Limit.class).child(), Eval.class);
3647+
assertThat(eval.fields(), hasSize(1));
3648+
Alias alias = as(eval.fields().get(0), Alias.class);
3649+
assertThat(alias.name(), equalTo("embedding"));
3650+
TextEmbedding function = as(alias.child(), TextEmbedding.class);
3651+
3652+
assertThat(function.inputText(), equalTo(string("italian food recipe")));
3653+
assertThat(function.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID)));
3654+
}
3655+
3656+
public void testTextEmbeddingFunctionResolveType() {
3657+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3658+
3659+
LogicalPlan plan = analyze(
3660+
"""
3661+
FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(
3662+
TEXT_EMBEDDING_INFERENCE_ID
3663+
),
3664+
"mapping-books.json"
3665+
);
3666+
3667+
Eval eval = as(as(plan, Limit.class).child(), Eval.class);
3668+
assertThat(eval.fields(), hasSize(1));
3669+
Alias alias = as(eval.fields().get(0), Alias.class);
3670+
assertThat(alias.name(), equalTo("embedding"));
3671+
3672+
TextEmbedding function = as(alias.child(), TextEmbedding.class);
3673+
3674+
assertThat(function.foldable(), equalTo(true));
3675+
assertThat(function.dataType(), equalTo(DENSE_VECTOR));
3676+
}
3677+
3678+
public void testTextEmbeddingFunctionMissingInferenceIdError() {
3679+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3680+
3681+
VerificationException ve = expectThrows(
3682+
VerificationException.class,
3683+
() -> analyze(
3684+
"""
3685+
FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(
3686+
"unknow-inference-id"
3687+
),
3688+
"mapping-books.json"
3689+
)
3690+
);
3691+
3692+
assertThat(ve.getMessage(), containsString("unresolved inference [unknow-inference-id]"));
3693+
}
3694+
3695+
public void testTextEmbeddingFunctionInvalidInferenceIdError() {
3696+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3697+
3698+
String inferenceId = randomInferenceId(TEXT_EMBEDDING_INFERENCE_ID);
3699+
VerificationException ve = expectThrows(
3700+
VerificationException.class,
3701+
() -> analyze(
3702+
"""
3703+
FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(inferenceId),
3704+
"mapping-books.json"
3705+
)
3706+
);
3707+
3708+
assertThat(ve.getMessage(), containsString("cannot use inference endpoint [%s] with task type".formatted(inferenceId)));
3709+
}
3710+
3711+
public void testTextEmbeddingFunctionWithoutModel() {
3712+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3713+
3714+
ParsingException ve = expectThrows(ParsingException.class, () -> analyze("""
3715+
FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe")""", "mapping-books.json"));
3716+
3717+
assertThat(
3718+
ve.getMessage(),
3719+
containsString(" error building [text_embedding]: function [text_embedding] expects exactly two arguments")
3720+
);
3721+
}
3722+
3723+
public void testKnnFunctionWithTextEmbedding() {
3724+
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled());
3725+
assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
3726+
3727+
String fieldName = randomFrom("float_vector", "byte_vector");
3728+
3729+
LogicalPlan plan = analyze("""
3730+
from test | where KNN(%s, TEXT_EMBEDDING("italian food recipe", "%s"))
3731+
""".formatted(fieldName, TEXT_EMBEDDING_INFERENCE_ID), "mapping-dense_vector.json");
3732+
3733+
Limit limit = as(plan, Limit.class);
3734+
Filter filter = as(limit.child(), Filter.class);
3735+
Knn knn = as(filter.condition(), Knn.class);
3736+
assertThat(knn.field(), instanceOf(FieldAttribute.class));
3737+
assertThat(((FieldAttribute) knn.field()).name(), equalTo(fieldName));
3738+
3739+
TextEmbedding textEmbedding = as(knn.query(), TextEmbedding.class);
3740+
assertThat(textEmbedding.inputText(), equalTo(string("italian food recipe")));
3741+
assertThat(textEmbedding.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID)));
3742+
}
3743+
36323744
public void testResolveRerankInferenceId() {
36333745
{
36343746
LogicalPlan plan = analyze("""

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
4242
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
4343
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
44+
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID;
4445
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
4546
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
4647
import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT;
@@ -2434,6 +2435,53 @@ public void testInvalidTBucketCalls() {
24342435
}
24352436
}
24362437

2438+
public void testTextEmbeddingFunctionInvalidQuery() {
2439+
assertThat(
2440+
error("from test | EVAL embedding = TEXT_EMBEDDING(null, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
2441+
equalTo("1:30: first argument of [TEXT_EMBEDDING(null, ?)] cannot be null, received [null]")
2442+
);
2443+
2444+
assertThat(
2445+
error("from test | EVAL embedding = TEXT_EMBEDDING(42, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
2446+
equalTo("1:30: first argument of [TEXT_EMBEDDING(42, ?)] must be [string], found value [42] type [integer]")
2447+
);
2448+
2449+
assertThat(
2450+
error("from test | EVAL embedding = TEXT_EMBEDDING(last_name, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
2451+
equalTo("1:30: first argument of [TEXT_EMBEDDING(last_name, ?)] must be a constant, received [last_name]")
2452+
);
2453+
}
2454+
2455+
public void testTextEmbeddingFunctionInvalidInferenceId() {
2456+
assertThat(
2457+
error("from test | EVAL embedding = TEXT_EMBEDDING(?, null)", defaultAnalyzer, "query text"),
2458+
equalTo("1:30: second argument of [TEXT_EMBEDDING(?, null)] cannot be null, received [null]")
2459+
);
2460+
2461+
assertThat(
2462+
error("from test | EVAL embedding = TEXT_EMBEDDING(?, 42)", defaultAnalyzer, "query text"),
2463+
equalTo("1:30: second argument of [TEXT_EMBEDDING(?, 42)] must be [string], found value [42] type [integer]")
2464+
);
2465+
2466+
assertThat(
2467+
error("from test | EVAL embedding = TEXT_EMBEDDING(?, last_name)", defaultAnalyzer, "query text"),
2468+
equalTo("1:30: second argument of [TEXT_EMBEDDING(?, last_name)] must be a constant, received [last_name]")
2469+
);
2470+
}
2471+
2472+
// public void testTextEmbeddingFunctionInvalidInferenceId() {
2473+
// assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
2474+
//
2475+
// ParsingException ve = expectThrows(ParsingException.class, () -> analyze("""
2476+
// FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", CONCAT("machin", title))""",
2477+
// "mapping-books.json"));
2478+
//
2479+
// assertThat(
2480+
// ve.getMessage(),
2481+
// containsString(" error building [text_embedding]: function [text_embedding] expects exactly two arguments")
2482+
// );
2483+
// }
2484+
24372485
private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception {
24382486
query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer);
24392487
}

0 commit comments

Comments
 (0)