Skip to content

Commit 244a0c9

Browse files
committed
Logical & physical plan implementation.
1 parent 94fb918 commit 244a0c9

File tree

5 files changed

+402
-0
lines changed

5 files changed

+402
-0
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xpack.esql.plan.logical.TopN;
2727
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
2828
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
29+
import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
2930
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
3031
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
3132
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@@ -55,6 +56,7 @@
5556
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
5657
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
5758
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
59+
import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec;
5860

5961
import java.util.ArrayList;
6062
import java.util.List;
@@ -72,6 +74,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
7274
return List.of(
7375
Aggregate.ENTRY,
7476
Completion.ENTRY,
77+
DenseVectorEmbedding.ENTRY,
7578
Dissect.ENTRY,
7679
Enrich.ENTRY,
7780
EsRelation.ENTRY,
@@ -99,6 +102,7 @@ public static List<NamedWriteableRegistry.Entry> physical() {
99102
return List.of(
100103
AggregateExec.ENTRY,
101104
CompletionExec.ENTRY,
105+
DenseVectorEmbeddingExec.ENTRY,
102106
DissectExec.ENTRY,
103107
EnrichExec.ENTRY,
104108
EsQueryExec.ENTRY,
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference.embedding;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
15+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
16+
import org.elasticsearch.xpack.esql.core.expression.Expression;
17+
import org.elasticsearch.xpack.esql.core.expression.NameId;
18+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
21+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
22+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
23+
24+
import java.io.IOException;
25+
import java.util.List;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
29+
30+
public class DenseVectorEmbedding extends InferencePlan<DenseVectorEmbedding> implements TelemetryAware {
31+
32+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
33+
LogicalPlan.class,
34+
"DenseVectorEmbedding",
35+
DenseVectorEmbedding::new
36+
);
37+
38+
private final Expression input;
39+
private final Attribute targetField;
40+
private List<Attribute> lazyOutput;
41+
42+
public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
43+
super(source, child, inferenceId);
44+
this.input = input;
45+
this.targetField = targetField;
46+
}
47+
48+
public DenseVectorEmbedding(StreamInput in) throws IOException {
49+
this(
50+
Source.readFrom((PlanStreamInput) in),
51+
in.readNamedWriteable(LogicalPlan.class),
52+
in.readNamedWriteable(Expression.class),
53+
in.readNamedWriteable(Expression.class),
54+
in.readNamedWriteable(Attribute.class)
55+
);
56+
}
57+
58+
@Override
59+
public void writeTo(StreamOutput out) throws IOException {
60+
source().writeTo(out);
61+
out.writeNamedWriteable(child());
62+
out.writeNamedWriteable(inferenceId());
63+
out.writeNamedWriteable(input);
64+
out.writeNamedWriteable(targetField);
65+
}
66+
67+
public Expression input() {
68+
return input;
69+
}
70+
71+
public Attribute embeddingField() {
72+
return targetField;
73+
}
74+
75+
@Override
76+
public TaskType taskType() {
77+
return TaskType.TEXT_EMBEDDING;
78+
}
79+
80+
@Override
81+
public String getWriteableName() {
82+
return ENTRY.name;
83+
}
84+
85+
@Override
86+
public List<Attribute> output() {
87+
if (lazyOutput == null) {
88+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
89+
}
90+
return lazyOutput;
91+
}
92+
93+
@Override
94+
public List<Attribute> generatedAttributes() {
95+
return List.of(targetField);
96+
}
97+
98+
@Override
99+
public DenseVectorEmbedding withGeneratedNames(List<String> newNames) {
100+
checkNumberOfNewNames(newNames);
101+
return new DenseVectorEmbedding(source(), child(), inferenceId(), input, this.renameTargetField(newNames.get(0)));
102+
}
103+
104+
private Attribute renameTargetField(String newName) {
105+
if (newName.equals(targetField.name())) {
106+
return targetField;
107+
}
108+
109+
return targetField.withName(newName).withId(new NameId());
110+
}
111+
112+
113+
@Override
114+
public boolean expressionsResolved() {
115+
return super.expressionsResolved() && input.resolved() && targetField.resolved();
116+
}
117+
118+
@Override
119+
public DenseVectorEmbedding withInferenceId(Expression newInferenceId) {
120+
return new DenseVectorEmbedding(source(), child(), newInferenceId, input, targetField);
121+
}
122+
123+
@Override
124+
public DenseVectorEmbedding replaceChild(LogicalPlan newChild) {
125+
return new DenseVectorEmbedding(source(), newChild, inferenceId(), input, targetField);
126+
}
127+
128+
@Override
129+
protected NodeInfo<? extends LogicalPlan> info() {
130+
return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), input, targetField);
131+
}
132+
133+
@Override
134+
public boolean equals(Object o) {
135+
if (this == o) return true;
136+
if (o == null || getClass() != o.getClass()) return false;
137+
if (super.equals(o) == false) return false;
138+
DenseVectorEmbedding that = (DenseVectorEmbedding) o;
139+
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
140+
}
141+
142+
@Override
143+
public int hashCode() {
144+
return Objects.hash(super.hashCode(), input, targetField);
145+
}
146+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference.embedding;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
14+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
15+
import org.elasticsearch.xpack.esql.core.expression.Expression;
16+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
17+
import org.elasticsearch.xpack.esql.core.tree.Source;
18+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
19+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
20+
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
21+
import org.elasticsearch.xpack.esql.plan.physical.inference.InferenceExec;
22+
23+
import java.io.IOException;
24+
import java.util.List;
25+
import java.util.Objects;
26+
27+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
28+
29+
public class DenseVectorEmbeddingExec extends InferenceExec {
30+
31+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
32+
PhysicalPlan.class,
33+
"DenseVectorEmbeddingExec",
34+
DenseVectorEmbeddingExec::new
35+
);
36+
37+
private final Expression input;
38+
private final Attribute targetField;
39+
private List<Attribute> lazyOutput;
40+
41+
public DenseVectorEmbeddingExec(
42+
Source source,
43+
PhysicalPlan child,
44+
Expression inferenceId,
45+
Expression input,
46+
Attribute targetField
47+
) {
48+
super(source, child, inferenceId);
49+
this.input = input;
50+
this.targetField = targetField;
51+
}
52+
53+
public DenseVectorEmbeddingExec(StreamInput in) throws IOException {
54+
this(
55+
Source.readFrom((PlanStreamInput) in),
56+
in.readNamedWriteable(PhysicalPlan.class),
57+
in.readNamedWriteable(Expression.class),
58+
in.readNamedWriteable(Expression.class),
59+
in.readNamedWriteable(Attribute.class)
60+
);
61+
}
62+
63+
public Expression input() {
64+
return input;
65+
}
66+
67+
public Attribute targetField() {
68+
return targetField;
69+
}
70+
71+
@Override
72+
public String getWriteableName() {
73+
return ENTRY.name;
74+
}
75+
76+
@Override
77+
public void writeTo(StreamOutput out) throws IOException {
78+
super.writeTo(out);
79+
out.writeNamedWriteable(input);
80+
out.writeNamedWriteable(targetField);
81+
}
82+
83+
@Override
84+
protected NodeInfo<? extends PhysicalPlan> info() {
85+
return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, targetField);
86+
}
87+
88+
@Override
89+
public UnaryExec replaceChild(PhysicalPlan newChild) {
90+
return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, targetField);
91+
}
92+
93+
@Override
94+
public List<Attribute> output() {
95+
if (lazyOutput == null) {
96+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
97+
}
98+
return lazyOutput;
99+
}
100+
101+
@Override
102+
protected AttributeSet computeReferences() {
103+
return input.references();
104+
}
105+
106+
@Override
107+
public boolean equals(Object o) {
108+
if (this == o) return true;
109+
if (o == null || getClass() != o.getClass()) return false;
110+
if (super.equals(o) == false) return false;
111+
DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o;
112+
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
113+
}
114+
115+
@Override
116+
public int hashCode() {
117+
return Objects.hash(super.hashCode(), input, targetField);
118+
}
119+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference.embedding;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.Literal;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
15+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
16+
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
17+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
18+
19+
import java.io.IOException;
20+
21+
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
22+
23+
public class DenseVectorEmbeddingSerializationTests extends AbstractLogicalPlanSerializationTests<DenseVectorEmbedding> {
24+
25+
@Override
26+
protected DenseVectorEmbedding createTestInstance() {
27+
return new DenseVectorEmbedding(
28+
randomSource(),
29+
randomChild(0),
30+
randomInferenceId(),
31+
randomInput(),
32+
randomTargetField()
33+
);
34+
}
35+
36+
@Override
37+
protected DenseVectorEmbedding mutateInstance(DenseVectorEmbedding instance) throws IOException {
38+
LogicalPlan child = instance.child();
39+
Expression inferenceId = instance.inferenceId();
40+
Expression input = instance.input();
41+
Attribute targetField = instance.embeddingField();
42+
43+
switch (between(0, 3)) {
44+
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
45+
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
46+
case 2 -> input = randomValueOtherThan(input, this::randomInput);
47+
case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField);
48+
}
49+
return new DenseVectorEmbedding(instance.source(), child, inferenceId, input, targetField);
50+
}
51+
52+
private Literal randomInferenceId() {
53+
return Literal.keyword(EMPTY, randomIdentifier());
54+
}
55+
56+
private Expression randomInput() {
57+
return randomBoolean() ? Literal.keyword(EMPTY, randomIdentifier()) : randomAttribute();
58+
}
59+
60+
private Attribute randomTargetField() {
61+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
62+
}
63+
64+
private Attribute randomAttribute() {
65+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
66+
}
67+
}

0 commit comments

Comments
 (0)