Skip to content

Commit db6fa4c

Browse files
committed
Introduce a new PRE_OPTIMIZED to the LogicalPlan
1 parent 51f865d commit db6fa4c

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ public PreOptimizer(TransportActionServices services, FoldContext foldContext) {
3636
}
3737

3838
public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
39-
inferencePreOptimizer.foldInferenceFunctions(plan, listener);
39+
if (plan.analyzed() == false) {
40+
throw new IllegalStateException("Expected analyzed plan");
41+
}
42+
43+
inferencePreOptimizer.foldInferenceFunctions(plan, listener.safeMap(p -> {
44+
p.setPreOptimized();
45+
return p;
46+
}));
4047
}
4148

4249
private static class InferencePreOptimizer {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ public boolean optimized() {
6565
return stage.ordinal() >= Stage.OPTIMIZED.ordinal();
6666
}
6767

68+
public void setPreOptimized() {
69+
stage = Stage.PRE_OPTIMIZED;
70+
}
71+
72+
public boolean preOptimized() {
73+
return stage.ordinal() >= Stage.PRE_OPTIMIZED.ordinal();
74+
}
75+
6876
public void setOptimized() {
6977
stage = Stage.OPTIMIZED;
7078
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.plan.logical.Eval;
2828
import org.elasticsearch.xpack.esql.plan.logical.Filter;
2929

30+
import java.util.ArrayList;
3031
import java.util.List;
3132

3233
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
@@ -76,18 +77,20 @@ private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel
7677
relation,
7778
List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId))))
7879
);
80+
eval.setAnalyzed();
7981

8082
SetOnce<Object> preOptimizedPlanHolder = new SetOnce<>();
8183
preOptimizer.preOptimize(eval, ActionListener.wrap(preOptimizedPlanHolder::set, ESTestCase::fail));
8284

8385
assertBusy(() -> {
8486
assertThat(preOptimizedPlanHolder.get(), notNullValue());
8587
Eval preOptimizedEval = as(preOptimizedPlanHolder.get(), Eval.class);
88+
assertThat(preOptimizedEval.preOptimized(), equalTo(true));
8689
assertThat(preOptimizedEval.fields(), hasSize(1));
8790
assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName));
8891
Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class);
8992
assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR));
90-
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query)));
93+
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query)));
9194
});
9295
}
9396

@@ -102,6 +105,7 @@ private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel)
102105
relation,
103106
new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null)
104107
);
108+
filter.setAnalyzed();
105109
Knn knn = as(filter.condition(), Knn.class);
106110

107111
SetOnce<Object> preOptimizedHolder = new SetOnce<>();
@@ -110,14 +114,15 @@ private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel)
110114
assertBusy(() -> {
111115
assertThat(preOptimizedHolder.get(), notNullValue());
112116
Filter preOptimizedFilter = as(preOptimizedHolder.get(), Filter.class);
117+
assertThat(preOptimizedFilter.preOptimized(), equalTo(true));
113118
Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class);
114119
assertThat(preOptimizedKnn.field(), equalTo(knn.field()));
115120
assertThat(preOptimizedKnn.k(), equalTo(knn.k()));
116121
assertThat(preOptimizedKnn.options(), equalTo(knn.options()));
117122

118123
Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class);
119124
assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR));
120-
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query)));
125+
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query)));
121126
});
122127
}
123128

@@ -141,6 +146,15 @@ private interface TextEmbeddingModelMock {
141146
TextEmbeddingResults<?> embeddingResults(String input);
142147

143148
float[] embedding(String input);
149+
150+
default List<Float> embeddingList(String input) {
151+
float[] embedding = embedding(input);
152+
List<Float> embeddingList = new ArrayList<>(embedding.length);
153+
for (float value : embedding) {
154+
embeddingList.add(value);
155+
}
156+
return embeddingList;
157+
}
144158
}
145159

146160
private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() {

0 commit comments

Comments
 (0)