Skip to content

Commit 0ae1e0e

Browse files
committed
Init inference function evaluator.
1 parent 9298450 commit 0ae1e0e

File tree

6 files changed

+97
-21
lines changed

6 files changed

+97
-21
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public void esql(
8888
indexResolver,
8989
enrichPolicyResolver,
9090
preAnalyzer,
91-
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
91+
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext, services.inferenceService())),
9292
functionRegistry,
9393
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
9494
mapper,
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.common.lucene.BytesRefs;
12+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
13+
import org.elasticsearch.compute.operator.Operator;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
16+
import org.elasticsearch.xpack.esql.core.expression.Literal;
17+
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
18+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
19+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
20+
import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;
21+
22+
public class InferenceFunctionEvaluator {
23+
24+
private final FoldContext foldContext;
25+
private final InferenceService inferenceService;
26+
27+
public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) {
28+
this.foldContext = foldContext;
29+
this.inferenceService = inferenceService;
30+
}
31+
32+
public void fold(InferenceFunction<?> f, ActionListener<Object> listener) {
33+
assert f.foldable() : "Inference function must be foldable";
34+
35+
36+
}
37+
38+
private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunction<?> f) {
39+
return switch (f) {
40+
case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory(
41+
inferenceService,
42+
inferenceId(f),
43+
expressionEvaluatorFactory(textEmbedding.inputText())
44+
);
45+
default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName());
46+
};
47+
}
48+
49+
private String inferenceId(InferenceFunction<?> f) {
50+
return BytesRefs.toString(f.inferenceId().fold(foldContext));
51+
}
52+
53+
private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) {
54+
assert e.foldable() : "Input expression must be foldable";
55+
return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null);
56+
}
57+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
package org.elasticsearch.xpack.esql.optimizer;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.SubscribableListener;
12+
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
1113
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1214

15+
import java.util.List;
16+
1317
/**
1418
* The class is responsible for invoking any steps that need to be applied to the logical plan,
1519
* before this is being optimized.
@@ -25,6 +29,8 @@ public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) {
2529
this.preOptimizerContext = preOptimizerContext;
2630
}
2731

32+
private static final List<Rule> RULES = List.of();
33+
2834
/**
2935
* Pre-optimize a logical plan.
3036
*
@@ -44,7 +50,27 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
4450
}
4551

4652
private void doPreOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
47-
// this is where we will be executing async tasks
48-
listener.onResponse(plan);
53+
SubscribableListener<LogicalPlan> ruleChainListener = SubscribableListener.newSucceeded(plan);
54+
for (Rule rule : RULES) {
55+
ruleChainListener = ruleChainListener.andThen((l, p) -> rule.apply(p, l));
56+
}
57+
ruleChainListener.addListener(listener);
58+
}
59+
60+
public interface Rule {
61+
void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener);
62+
}
63+
64+
private static class FoldInferenceFunction implements Rule {
65+
private final InferenceFunctionEvaluator inferenceEvaluator;
66+
67+
private FoldInferenceFunction(LogicalPreOptimizerContext preOptimizerContext) {
68+
this.inferenceEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService());
69+
}
70+
71+
@Override
72+
public void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
73+
74+
}
4975
}
5076
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,29 @@
88
package org.elasticsearch.xpack.esql.optimizer;
99

1010
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
11-
12-
import java.util.Objects;
11+
import org.elasticsearch.xpack.esql.inference.InferenceService;
1312

1413
public class LogicalPreOptimizerContext {
1514

1615
private final FoldContext foldCtx;
1716

18-
public LogicalPreOptimizerContext(FoldContext foldCtx) {
17+
private final InferenceService inferenceService;
18+
19+
public LogicalPreOptimizerContext(FoldContext foldCtx, InferenceService inferenceService) {
1920
this.foldCtx = foldCtx;
21+
this.inferenceService = inferenceService;
2022
}
2123

2224
public FoldContext foldCtx() {
2325
return foldCtx;
2426
}
2527

26-
@Override
27-
public boolean equals(Object obj) {
28-
if (obj == this) return true;
29-
if (obj == null || obj.getClass() != this.getClass()) return false;
30-
var that = (LogicalPreOptimizerContext) obj;
31-
return this.foldCtx.equals(that.foldCtx);
32-
}
33-
34-
@Override
35-
public int hashCode() {
36-
return Objects.hash(foldCtx);
37-
}
38-
3928
@Override
4029
public String toString() {
4130
return "LogicalPreOptimizerContext[foldCtx=" + foldCtx + ']';
4231
}
32+
33+
public InferenceService inferenceService() {
34+
return inferenceService;
35+
}
4336
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
588588
null,
589589
null,
590590
null,
591-
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)),
591+
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx, null)),
592592
functionRegistry,
593593
new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)),
594594
mapper,

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception {
7272
}
7373

7474
private LogicalPlanPreOptimizer preOptimizer() {
75-
LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small());
75+
LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small(), null);
7676
return new LogicalPlanPreOptimizer(preOptimizerContext);
7777
}
7878

0 commit comments

Comments
 (0)