@@ -120,10 +120,18 @@ public Expression get(Object key) {
return map.get(key);
} else {
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
-            return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
+            return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(getKeyAsBytesRef(key));
}
}

+    public Expression getOrDefault(Object key, Expression defaultValue) {
+        return containsKey(key) ? get(key) : defaultValue;
+    }
+
+    public boolean containsKey(Object key) {
+        return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(getKeyAsBytesRef(key));
+    }
+
@Override
public boolean equals(Object obj) {
if (this == obj) {
@@ -142,4 +150,8 @@ public String toString() {
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
return "{ " + str + " }";
}

+    private BytesRef getKeyAsBytesRef(Object key) {
+        return new BytesRef(key.toString());
+    }
}
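
A minimal, self-contained sketch of the key-folding lookup that the new containsKey and getOrDefault helpers build on: a string literal key may already have been rewritten into a BytesRef (by ConvertStringToByteRef), so every lookup is tried with the key as given and then with its BytesRef form. FoldedKeyLookup, its String values and asBytesRef are illustrative stand-ins, not the ESQL MapExpression API.

import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.util.BytesRef;

// Illustrative stand-in for MapExpression's folded-key map: string literal keys
// may have been converted to BytesRef, so every lookup is tried in both forms.
class FoldedKeyLookup {
    private final Map<Object, String> keyFoldedMap = new HashMap<>();

    void put(Object key, String value) {
        keyFoldedMap.put(key, value);
    }

    boolean containsKey(Object key) {
        // first try the key as given, then its BytesRef form
        return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(asBytesRef(key));
    }

    String get(Object key) {
        return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(asBytesRef(key));
    }

    String getOrDefault(Object key, String defaultValue) {
        return containsKey(key) ? get(key) : defaultValue;
    }

    private static BytesRef asBytesRef(Object key) {
        return new BytesRef(key.toString());
    }
}

Extracting the conversion into a single helper, as the hunk above does with getKeyAsBytesRef, keeps get, containsKey and getOrDefault agreeing on how a missing key is retried.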
@@ -1197,6 +1197,11 @@ public enum Cap {
*/
KNN_FUNCTION_V2(Build.current().isSnapshot()),

+    /**
+     * Support for dense vector embedding function
+     */
+    DENSE_VECTOR_EMBEDDING_FUNCTION(Build.current().isSnapshot()),
+
LIKE_WITH_LIST_OF_PATTERNS,

/**
@@ -97,6 +97,7 @@
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
@@ -138,6 +139,7 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT;
@@ -406,7 +408,7 @@ protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
-            return plan;
+            return plan.withModelConfigurations(resolvedInference.modelConfigurations());
} else if (resolvedInference != null) {
String error = "cannot use inference endpoint ["
+ inferenceId
@@ -516,6 +518,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
return resolveEval(p, childrenOutput);
}

+            if (plan instanceof DenseVectorEmbedding dve) {
+                return resolveDenseVectorEmbedding(dve, childrenOutput);
+            }
+
if (plan instanceof Enrich p) {
return resolveEnrich(p, childrenOutput);
}
@@ -820,6 +826,28 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
return changed ? new Fork(fork.source(), newSubPlans, newOutput) : fork;
}

+    private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List<Attribute> childrenOutput) {
+        // Resolve the input expression
+        Expression input = p.input();
+        if (input.resolved() == false) {
+            input = input.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
+        }
+
+        // Resolve the target field (similar to Completion)
+        Attribute targetField = p.embeddingField();
+        if (targetField instanceof UnresolvedAttribute ua) {
+            targetField = new ReferenceAttribute(ua.source(), ua.name(), DENSE_VECTOR);
+        }
+
+        // Create a new DenseVectorEmbedding with resolved expressions
+        // Only create a new instance if something changed to avoid unnecessary object creation
+        if (input != p.input() || targetField != p.embeddingField()) {
+            return p.withTargetField(targetField);
+        }
+
+        return p;
+    }
+
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) {
List<Alias> newFields = new ArrayList<>();
boolean changed = false;
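resolveDenseVectorEmbedding follows the analyzer's usual copy-on-change convention for immutable plan nodes: resolve what can be resolved against the children's output, and return the original instance when nothing was resolved, avoiding needless object creation. A simplified, self-contained sketch of that convention; EmbeddingNode, Resolver and the String-based attributes are illustrative, not the ESQL classes.

import java.util.Map;

// Illustrative immutable "plan node": resolution never mutates it; it only
// produces a copy through a with*() method when something actually changed.
record EmbeddingNode(String input, String targetField, boolean targetResolved) {

    EmbeddingNode withTargetField(String resolvedField) {
        return new EmbeddingNode(input, resolvedField, true);
    }
}

class Resolver {

    // childrenOutput maps attribute names visible from the child plan to their resolved form.
    static EmbeddingNode resolve(EmbeddingNode node, Map<String, String> childrenOutput) {
        if (node.targetResolved()) {
            return node;                                   // already resolved: keep the same instance
        }
        String resolved = childrenOutput.get(node.targetField());
        return resolved != null ? node.withTargetField(resolved) : node;
    }
}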
@@ -7,7 +7,9 @@

package org.elasticsearch.xpack.esql.analysis;

+import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -21,32 +23,33 @@
import java.util.Set;

import static java.util.Collections.emptyList;
+import static java.util.Collections.emptySet;

/**
* This class is part of the planner. Acts somewhat like a linker, to find the indices and enrich policies referenced by the query.
*/
public class PreAnalyzer {

public static class PreAnalysis {
-        public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList());
+        public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptySet(), emptyList());

public final IndexMode indexMode;
public final List<IndexPattern> indices;
public final List<Enrich> enriches;
-        public final List<InferencePlan<?>> inferencePlans;
+        public final Set<String> inferenceIds;
public final List<IndexPattern> lookupIndices;

public PreAnalysis(
IndexMode indexMode,
List<IndexPattern> indices,
List<Enrich> enriches,
-            List<InferencePlan<?>> inferencePlans,
+            Set<String> inferenceIds,
List<IndexPattern> lookupIndices
) {
this.indexMode = indexMode;
this.indices = indices;
this.enriches = enriches;
-            this.inferencePlans = inferencePlans;
+            this.inferenceIds = inferenceIds;
this.lookupIndices = lookupIndices;
}
}
@@ -64,7 +67,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {

List<Enrich> unresolvedEnriches = new ArrayList<>();
List<IndexPattern> lookupIndices = new ArrayList<>();
-        List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
+        Set<String> unresolvedInferenceIds = new HashSet<>();
Holder<IndexMode> indexMode = new Holder<>();
plan.forEachUp(UnresolvedRelation.class, p -> {
if (p.indexMode() == IndexMode.LOOKUP) {
@@ -78,11 +81,28 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
});

plan.forEachUp(Enrich.class, unresolvedEnriches::add);
-        plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);

// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);

-        return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
+        return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, inferenceIds(plan), lookupIndices);
}

+    protected Set<String> inferenceIds(LogicalPlan plan) {
+        Set<String> inferenceIds = new HashSet<>();
+
+        List<InferencePlan<?>> inferencePlans = new ArrayList<>();
+        plan.forEachUp(InferencePlan.class, inferencePlans::add);
+        inferencePlans.stream().map(this::inferenceId).forEach(inferenceIds::add);
+
+        return inferenceIds;
+    }
+
+    private String inferenceId(InferencePlan<?> inferencePlan) {
+        if (inferencePlan.inferenceId() instanceof Literal literal) {
+            return BytesRefs.toString(literal.value());
+        }
+
+        throw new IllegalStateException("inferenceId is not a literal");
+    }
}
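
The PreAnalysis change narrows what pre-analysis carries: instead of whole InferencePlan nodes, it now collects only the inference endpoint ids, read from the literal inferenceId expressions, into a Set<String>. A self-contained sketch of that collect-ids-during-a-bottom-up-walk shape; Node, Leaf and Inference are illustrative stand-ins for the ESQL plan classes, not the real API.

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

// Illustrative plan tree exposing a bottom-up walk, like LogicalPlan.forEachUp.
interface Node {
    List<Node> children();

    default void forEachUp(Consumer<Node> action) {
        for (Node child : children()) {
            child.forEachUp(action);       // visit children first
        }
        action.accept(this);               // then this node
    }
}

record Leaf() implements Node {
    public List<Node> children() { return List.of(); }
}

record Inference(Node child, String inferenceId) implements Node {
    public List<Node> children() { return List.of(child); }
}

class InferenceIdCollector {

    // Mirrors the switch from List<InferencePlan<?>> to Set<String>: the walk keeps
    // only the (deduplicated) endpoint ids rather than the plan nodes themselves.
    static Set<String> inferenceIds(Node plan) {
        Set<String> ids = new HashSet<>();
        plan.forEachUp(n -> {
            if (n instanceof Inference inference) {
                ids.add(inference.inferenceId());
            }
        });
        return ids;
    }
}

Carrying plain ids is enough for pre-analysis, whose job is only to resolve which endpoints the query references; the IllegalStateException in inferenceId above records the assumption that the endpoint id is always a literal by this point.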
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
+import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -119,6 +120,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
+        entries.addAll(inference());
return entries;
}

@@ -264,4 +266,11 @@ private static List<NamedWriteableRegistry.Entry> vector() {
}
return List.of();
}

+    private static List<NamedWriteableRegistry.Entry> inference() {
+        if (EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()) {
+            return List.of(DenseVectorEmbeddingFunction.ENTRY);
+        }
+        return List.of();
+    }
}
@@ -52,6 +52,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
+import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, Knn::new, "knn"),
+            def(DenseVectorEmbeddingFunction.class, bi(DenseVectorEmbeddingFunction::new), "text_dense_vector_embedding"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
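
Both registration sites above are effectively snapshot-gated: the NamedWriteableRegistry entry is guarded by the DENSE_VECTOR_EMBEDDING_FUNCTION capability, and the function definition lives in snapshotFunctions(), so release builds neither expose the function nor need to (de)serialize its node. A generic, self-contained sketch of that double gating; FeatureGatedRegistration, FEATURE_ENABLED, Entry and FunctionDef are illustrative stand-ins, not the Elasticsearch APIs.

import java.util.ArrayList;
import java.util.List;

class FeatureGatedRegistration {

    // Stand-in for a snapshot-only capability check such as
    // EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled().
    static final boolean FEATURE_ENABLED = Boolean.getBoolean("esql.feature.dense_vector_embedding");

    record Entry(String name) {}
    record FunctionDef(String name) {}

    // Serialization side: contribute the writeable entry only when the feature is on.
    static List<Entry> writeables() {
        List<Entry> entries = new ArrayList<>();
        entries.add(new Entry("knn"));
        if (FEATURE_ENABLED) {
            entries.add(new Entry("text_dense_vector_embedding"));
        }
        return entries;
    }

    // Registry side: expose the function name only when the feature is on,
    // so a query cannot reference a function whose node the cluster cannot serialize.
    static List<FunctionDef> snapshotFunctions() {
        List<FunctionDef> defs = new ArrayList<>();
        if (FEATURE_ENABLED) {
            defs.add(new FunctionDef("text_dense_vector_embedding"));
        }
        return defs;
    }
}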