|
88 | 88 | import org.elasticsearch.xpack.esql.inference.InferenceRunner; |
89 | 89 | import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; |
90 | 90 | import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; |
| 91 | +import org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingOperator; |
91 | 92 | import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; |
92 | 93 | import org.elasticsearch.xpack.esql.plan.logical.Fork; |
93 | 94 | import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; |
|
119 | 120 | import org.elasticsearch.xpack.esql.plan.physical.TopNExec; |
120 | 121 | import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; |
121 | 122 | import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; |
| 123 | +import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec; |
122 | 124 | import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext; |
123 | 125 | import org.elasticsearch.xpack.esql.plugin.QueryPragmas; |
124 | 126 | import org.elasticsearch.xpack.esql.score.ScoreMapper; |
@@ -266,6 +268,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c |
266 | 268 | return planChangePoint(changePoint, context); |
267 | 269 | } else if (node instanceof CompletionExec completion) { |
268 | 270 | return planCompletion(completion, context); |
| 271 | + } else if (node instanceof DenseVectorEmbeddingExec embedding) { |
| 272 | + return planDenseVectorEmbedding(embedding, context); |
269 | 273 | } else if (node instanceof SampleExec Sample) { |
270 | 274 | return planSample(Sample, context); |
271 | 275 | } |
@@ -319,6 +323,31 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti |
319 | 323 | return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout); |
320 | 324 | } |
321 | 325 |
|
| 326 | + private PhysicalOperation planDenseVectorEmbedding(DenseVectorEmbeddingExec embedding, LocalExecutionPlannerContext context) { |
| 327 | + PhysicalOperation source = plan(embedding.child(), context); |
| 328 | + String inferenceId = BytesRefs.toString(embedding.inferenceId().fold(context.foldCtx())); |
| 329 | + |
| 330 | + int dimensions; |
| 331 | + if (embedding.dimensions() instanceof Literal literal) { |
| 332 | + Object val = literal.value() instanceof BytesRef br ? BytesRefs.toString(br) : literal.value(); |
| 333 | + dimensions = stringToInt(val.toString()); |
| 334 | + } else { |
| 335 | + throw new EsqlIllegalArgumentException("dimensions only supported with literal values"); |
| 336 | + } |
| 337 | + |
| 338 | + Layout outputLayout = source.layout.builder().append(embedding.targetField()).build(); |
| 339 | + EvalOperator.ExpressionEvaluator.Factory inputEvaluatorFactory = EvalMapper.toEvaluator( |
| 340 | + context.foldCtx(), |
| 341 | + embedding.input(), |
| 342 | + source.layout |
| 343 | + ); |
| 344 | + |
| 345 | + return source.with( |
| 346 | + new DenseEmbeddingOperator.Factory(inferenceRunner, inferenceId, dimensions, inputEvaluatorFactory), |
| 347 | + outputLayout |
| 348 | + ); |
| 349 | + } |
| 350 | + |
322 | 351 | private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) { |
323 | 352 | PhysicalOperation source = plan(rrf.child(), context); |
324 | 353 |
|
|
0 commit comments