Skip to content

Commit 94fb918

Browse files
committed
First implementation of inference function.
1 parent a36e3ae commit 94fb918

File tree

12 files changed

+605
-1
lines changed

12 files changed

+605
-1
lines changed

docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,18 @@ public Expression get(Object key) {
120120
return map.get(key);
121121
} else {
122122
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
123-
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
123+
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(getKeyAsBytesRef(key));
124124
}
125125
}
126126

127+
public Expression getOrDefault(Object key, Expression defaultValue) {
128+
return containsKey(key) ? get(key) : defaultValue;
129+
}
130+
131+
public boolean containsKey(Object key) {
132+
return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(getKeyAsBytesRef(key));
133+
}
134+
127135
@Override
128136
public boolean equals(Object obj) {
129137
if (this == obj) {
@@ -142,4 +150,8 @@ public String toString() {
142150
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
143151
return "{ " + str + " }";
144152
}
153+
154+
private BytesRef getKeyAsBytesRef(Object key) {
155+
return new BytesRef(key.toString());
156+
}
145157
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
1414
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
1515
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
16+
import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
1617
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
1718
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
1819
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -119,6 +120,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
119120
entries.addAll(fullText());
120121
entries.addAll(unaryScalars());
121122
entries.addAll(vector());
123+
entries.addAll(inference());
122124
return entries;
123125
}
124126

@@ -264,4 +266,11 @@ private static List<NamedWriteableRegistry.Entry> vector() {
264266
}
265267
return List.of();
266268
}
269+
270+
private static List<NamedWriteableRegistry.Entry> inference() {
271+
if (EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()) {
272+
return List.of(DenseVectorEmbeddingFunction.ENTRY);
273+
}
274+
return List.of();
275+
}
267276
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
5353
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
5454
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
55+
import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
5556
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
5657
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
5758
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
479480
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
480481
def(Term.class, bi(Term::new), "term"),
481482
def(Knn.class, tri(Knn::new), "knn"),
483+
def(DenseVectorEmbeddingFunction.class, bi(DenseVectorEmbeddingFunction::new), "text_dense_vector_embedding"),
482484
def(StGeohash.class, StGeohash::new, "st_geohash"),
483485
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
484486
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.expression.Literal;
16+
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
17+
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
18+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
import org.elasticsearch.xpack.esql.core.type.DataType;
21+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
22+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
23+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
24+
import org.elasticsearch.xpack.esql.expression.function.MapParam;
25+
import org.elasticsearch.xpack.esql.expression.function.Param;
26+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
27+
28+
import java.io.IOException;
29+
import java.util.List;
30+
import java.util.Objects;
31+
import java.util.UUID;
32+
33+
/**
34+
* * A function that embeds input text into a dense vector representation using an inference model.
35+
*/
36+
public class DenseVectorEmbeddingFunction extends InferenceFunction {
37+
38+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
39+
Expression.class,
40+
"TextDenseVectorEmbedding",
41+
DenseVectorEmbeddingFunction::new
42+
);
43+
44+
private final Expression inputText;
45+
private final Attribute tmpAttribute;
46+
47+
@FunctionInfo(
48+
returnType = "dense_vector",
49+
preview = true,
50+
description = "Embed input text into a dense vector representation using an inference model.",
51+
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
52+
)
53+
public DenseVectorEmbeddingFunction(
54+
Source source,
55+
@Param(name = "inputText", type = { "keyword", "text" }, description = "Input text") Expression inputText,
56+
@MapParam(
57+
name = "options",
58+
params = { @MapParam.MapParamEntry(name = "inference_id", type = "keyword", description = "Inference endpoint to use.") },
59+
optional = true
60+
) Expression options
61+
) {
62+
this(source, inputText, options, new ReferenceAttribute(Source.EMPTY, ENTRY.name + "_" + UUID.randomUUID(), DataType.DOUBLE));
63+
}
64+
65+
private DenseVectorEmbeddingFunction(Source source, Expression inputText, Expression options, Attribute tmpAttribute) {
66+
super(source, List.of(inputText, tmpAttribute), options);
67+
this.inputText = inputText;
68+
this.tmpAttribute = tmpAttribute;
69+
}
70+
71+
public DenseVectorEmbeddingFunction(StreamInput in) throws IOException {
72+
this(
73+
Source.readFrom((PlanStreamInput) in),
74+
in.readNamedWriteable(Expression.class),
75+
in.readNamedWriteable(Expression.class),
76+
in.readNamedWriteable(Attribute.class)
77+
);
78+
}
79+
80+
@Override
81+
public void writeTo(StreamOutput out) throws IOException {
82+
source().writeTo(out);
83+
out.writeNamedWriteable(inputText);
84+
out.writeNamedWriteable(options());
85+
out.writeNamedWriteable(tmpAttribute);
86+
}
87+
88+
@Override
89+
public String functionName() {
90+
super.functionName();
91+
return getWriteableName();
92+
}
93+
94+
@Override
95+
public DataType dataType() {
96+
return DataType.DENSE_VECTOR;
97+
}
98+
99+
@Override
100+
public DenseVectorEmbeddingFunction replaceChildren(List<Expression> newChildren) {
101+
return new DenseVectorEmbeddingFunction(
102+
source(),
103+
newChildren.get(0),
104+
newChildren.size() > 1 ? newChildren.get(1) : null,
105+
tmpAttribute
106+
);
107+
}
108+
109+
@Override
110+
protected NodeInfo<? extends Expression> info() {
111+
return NodeInfo.create(this, DenseVectorEmbeddingFunction::new, inputText, options(), tmpAttribute);
112+
}
113+
114+
@Override
115+
public String getWriteableName() {
116+
return ENTRY.name;
117+
}
118+
119+
@Override
120+
protected Literal defaultInferenceId() {
121+
return Literal.NULL;
122+
}
123+
124+
@Override
125+
public List<Attribute> temporaryAttributes() {
126+
return List.of(tmpAttribute);
127+
}
128+
129+
@Override
130+
protected TypeResolution resolveParams() {
131+
return TypeResolutions.isString(inputText, sourceText(), TypeResolutions.ParamOrdinal.FIRST);
132+
}
133+
134+
@Override
135+
protected TypeResolutions.ParamOrdinal optionsParamsOrdinal() {
136+
return TypeResolutions.ParamOrdinal.SECOND;
137+
}
138+
139+
@Override
140+
public boolean equals(Object o) {
141+
if (this == o) return true;
142+
if (o == null || getClass() != o.getClass()) return false;
143+
if (!super.equals(o)) return false;
144+
DenseVectorEmbeddingFunction that = (DenseVectorEmbeddingFunction) o;
145+
return Objects.equals(inputText, that.inputText) && Objects.equals(tmpAttribute, that.tmpAttribute);
146+
}
147+
148+
@Override
149+
public int hashCode() {
150+
return Objects.hash(super.hashCode(), inputText, tmpAttribute);
151+
}
152+
}

0 commit comments

Comments
 (0)