|
97 | 97 | import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; |
98 | 98 | import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; |
99 | 99 | import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; |
| 100 | +import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding; |
100 | 101 | import org.elasticsearch.xpack.esql.plan.logical.join.Join; |
101 | 102 | import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; |
102 | 103 | import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; |
|
138 | 139 | import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; |
139 | 140 | import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; |
140 | 141 | import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; |
| 142 | +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; |
141 | 143 | import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; |
142 | 144 | import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; |
143 | 145 | import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; |
@@ -516,6 +518,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) { |
516 | 518 | return resolveEval(p, childrenOutput); |
517 | 519 | } |
518 | 520 |
|
| 521 | + if (plan instanceof DenseVectorEmbedding dve) { |
| 522 | + return resolveDenseVectorEmbedding(dve, childrenOutput); |
| 523 | + } |
| 524 | + |
519 | 525 | if (plan instanceof Enrich p) { |
520 | 526 | return resolveEnrich(p, childrenOutput); |
521 | 527 | } |
@@ -820,6 +826,34 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) { |
820 | 826 | return changed ? new Fork(fork.source(), newSubPlans, newOutput) : fork; |
821 | 827 | } |
822 | 828 |
|
| 829 | + private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List<Attribute> childrenOutput) { |
| 830 | + // Resolve the input expression |
| 831 | + Expression input = p.input(); |
| 832 | + if (input.resolved() == false) { |
| 833 | + input = input.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); |
| 834 | + } |
| 835 | + |
| 836 | + // Resolve the target field (similar to Completion) |
| 837 | + Attribute targetField = p.embeddingField(); |
| 838 | + if (targetField instanceof UnresolvedAttribute ua) { |
| 839 | + targetField = new ReferenceAttribute(ua.source(), ua.name(), DENSE_VECTOR); |
| 840 | + } |
| 841 | + |
| 842 | + // Create a new DenseVectorEmbedding with resolved expressions |
| 843 | + // Only create a new instance if something changed to avoid unnecessary object creation |
| 844 | + if (input != p.input() || targetField != p.embeddingField()) { |
| 845 | + return new DenseVectorEmbedding( |
| 846 | + p.source(), |
| 847 | + p.child(), |
| 848 | + p.inferenceId(), |
| 849 | + input, |
| 850 | + targetField |
| 851 | + ); |
| 852 | + } |
| 853 | + |
| 854 | + return p; |
| 855 | + } |
| 856 | + |
823 | 857 | private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) { |
824 | 858 | List<Alias> newFields = new ArrayList<>(); |
825 | 859 | boolean changed = false; |
|
0 commit comments