Skip to content

Commit ae208a6

Browse files
committed
Logical & physical plan implementation.
1 parent 94fb918 commit ae208a6

File tree

5 files changed

+380
-0
lines changed

5 files changed

+380
-0
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xpack.esql.plan.logical.TopN;
2727
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
2828
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
29+
import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
2930
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
3031
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
3132
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@@ -55,6 +56,7 @@
5556
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
5657
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
5758
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
59+
import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec;
5860

5961
import java.util.ArrayList;
6062
import java.util.List;
@@ -72,6 +74,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
7274
return List.of(
7375
Aggregate.ENTRY,
7476
Completion.ENTRY,
77+
DenseVectorEmbedding.ENTRY,
7578
Dissect.ENTRY,
7679
Enrich.ENTRY,
7780
EsRelation.ENTRY,
@@ -99,6 +102,7 @@ public static List<NamedWriteableRegistry.Entry> physical() {
99102
return List.of(
100103
AggregateExec.ENTRY,
101104
CompletionExec.ENTRY,
105+
DenseVectorEmbeddingExec.ENTRY,
102106
DissectExec.ENTRY,
103107
EnrichExec.ENTRY,
104108
EsQueryExec.ENTRY,
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference.embedding;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
15+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
16+
import org.elasticsearch.xpack.esql.core.expression.Expression;
17+
import org.elasticsearch.xpack.esql.core.expression.NameId;
18+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
21+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
22+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
23+
24+
import java.io.IOException;
25+
import java.util.List;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
29+
30+
public class DenseVectorEmbedding extends InferencePlan<DenseVectorEmbedding> implements TelemetryAware {
31+
32+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
33+
LogicalPlan.class,
34+
"DenseVectorEmbedding",
35+
DenseVectorEmbedding::new
36+
);
37+
38+
private final Expression input;
39+
private final Attribute targetField;
40+
private List<Attribute> lazyOutput;
41+
42+
public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
43+
super(source, child, inferenceId);
44+
this.input = input;
45+
this.targetField = targetField;
46+
}
47+
48+
public DenseVectorEmbedding(StreamInput in) throws IOException {
49+
this(
50+
Source.readFrom((PlanStreamInput) in),
51+
in.readNamedWriteable(LogicalPlan.class),
52+
in.readNamedWriteable(Expression.class),
53+
in.readNamedWriteable(Expression.class),
54+
in.readNamedWriteable(Attribute.class)
55+
);
56+
}
57+
58+
@Override
59+
public void writeTo(StreamOutput out) throws IOException {
60+
source().writeTo(out);
61+
out.writeNamedWriteable(child());
62+
out.writeNamedWriteable(inferenceId());
63+
out.writeNamedWriteable(input);
64+
out.writeNamedWriteable(targetField);
65+
}
66+
67+
public Expression input() {
68+
return input;
69+
}
70+
71+
public Attribute embeddingField() {
72+
return targetField;
73+
}
74+
75+
@Override
76+
public TaskType taskType() {
77+
return TaskType.TEXT_EMBEDDING;
78+
}
79+
80+
@Override
81+
public String getWriteableName() {
82+
return ENTRY.name;
83+
}
84+
85+
@Override
86+
public List<Attribute> output() {
87+
if (lazyOutput == null) {
88+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
89+
}
90+
return lazyOutput;
91+
}
92+
93+
@Override
94+
public List<Attribute> generatedAttributes() {
95+
return List.of(targetField);
96+
}
97+
98+
@Override
99+
public DenseVectorEmbedding withGeneratedNames(List<String> newNames) {
100+
checkNumberOfNewNames(newNames);
101+
return new DenseVectorEmbedding(source(), child(), inferenceId(), input, this.renameTargetField(newNames.get(0)));
102+
}
103+
104+
private Attribute renameTargetField(String newName) {
105+
if (newName.equals(targetField.name())) {
106+
return targetField;
107+
}
108+
109+
return targetField.withName(newName).withId(new NameId());
110+
}
111+
112+
@Override
113+
public boolean expressionsResolved() {
114+
return super.expressionsResolved() && input.resolved() && targetField.resolved();
115+
}
116+
117+
@Override
118+
public DenseVectorEmbedding withInferenceId(Expression newInferenceId) {
119+
return new DenseVectorEmbedding(source(), child(), newInferenceId, input, targetField);
120+
}
121+
122+
@Override
123+
public DenseVectorEmbedding replaceChild(LogicalPlan newChild) {
124+
return new DenseVectorEmbedding(source(), newChild, inferenceId(), input, targetField);
125+
}
126+
127+
@Override
128+
protected NodeInfo<? extends LogicalPlan> info() {
129+
return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), input, targetField);
130+
}
131+
132+
@Override
133+
public boolean equals(Object o) {
134+
if (this == o) return true;
135+
if (o == null || getClass() != o.getClass()) return false;
136+
if (super.equals(o) == false) return false;
137+
DenseVectorEmbedding that = (DenseVectorEmbedding) o;
138+
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
139+
}
140+
141+
@Override
142+
public int hashCode() {
143+
return Objects.hash(super.hashCode(), input, targetField);
144+
}
145+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference.embedding;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
14+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
15+
import org.elasticsearch.xpack.esql.core.expression.Expression;
16+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
17+
import org.elasticsearch.xpack.esql.core.tree.Source;
18+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
19+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
20+
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
21+
import org.elasticsearch.xpack.esql.plan.physical.inference.InferenceExec;
22+
23+
import java.io.IOException;
24+
import java.util.List;
25+
import java.util.Objects;
26+
27+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
28+
29+
public class DenseVectorEmbeddingExec extends InferenceExec {
30+
31+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
32+
PhysicalPlan.class,
33+
"DenseVectorEmbeddingExec",
34+
DenseVectorEmbeddingExec::new
35+
);
36+
37+
private final Expression input;
38+
private final Attribute targetField;
39+
private List<Attribute> lazyOutput;
40+
41+
public DenseVectorEmbeddingExec(Source source, PhysicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
42+
super(source, child, inferenceId);
43+
this.input = input;
44+
this.targetField = targetField;
45+
}
46+
47+
public DenseVectorEmbeddingExec(StreamInput in) throws IOException {
48+
this(
49+
Source.readFrom((PlanStreamInput) in),
50+
in.readNamedWriteable(PhysicalPlan.class),
51+
in.readNamedWriteable(Expression.class),
52+
in.readNamedWriteable(Expression.class),
53+
in.readNamedWriteable(Attribute.class)
54+
);
55+
}
56+
57+
public Expression input() {
58+
return input;
59+
}
60+
61+
public Attribute targetField() {
62+
return targetField;
63+
}
64+
65+
@Override
66+
public String getWriteableName() {
67+
return ENTRY.name;
68+
}
69+
70+
@Override
71+
public void writeTo(StreamOutput out) throws IOException {
72+
super.writeTo(out);
73+
out.writeNamedWriteable(input);
74+
out.writeNamedWriteable(targetField);
75+
}
76+
77+
@Override
78+
protected NodeInfo<? extends PhysicalPlan> info() {
79+
return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, targetField);
80+
}
81+
82+
@Override
83+
public UnaryExec replaceChild(PhysicalPlan newChild) {
84+
return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, targetField);
85+
}
86+
87+
@Override
88+
public List<Attribute> output() {
89+
if (lazyOutput == null) {
90+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
91+
}
92+
return lazyOutput;
93+
}
94+
95+
@Override
96+
protected AttributeSet computeReferences() {
97+
return input.references();
98+
}
99+
100+
@Override
101+
public boolean equals(Object o) {
102+
if (this == o) return true;
103+
if (o == null || getClass() != o.getClass()) return false;
104+
if (super.equals(o) == false) return false;
105+
DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o;
106+
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
107+
}
108+
109+
@Override
110+
public int hashCode() {
111+
return Objects.hash(super.hashCode(), input, targetField);
112+
}
113+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference.embedding;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.Literal;
13+
import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests;
14+
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
15+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
16+
17+
import java.io.IOException;
18+
19+
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
20+
21+
public class DenseVectorEmbeddingSerializationTests extends AbstractLogicalPlanSerializationTests<DenseVectorEmbedding> {
22+
23+
@Override
24+
protected DenseVectorEmbedding createTestInstance() {
25+
return new DenseVectorEmbedding(randomSource(), randomChild(0), randomInferenceId(), randomInput(), randomTargetField());
26+
}
27+
28+
@Override
29+
protected DenseVectorEmbedding mutateInstance(DenseVectorEmbedding instance) throws IOException {
30+
LogicalPlan child = instance.child();
31+
Expression inferenceId = instance.inferenceId();
32+
Expression input = instance.input();
33+
Attribute targetField = instance.embeddingField();
34+
35+
switch (between(0, 3)) {
36+
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
37+
case 1 -> inferenceId = randomValueOtherThan(inferenceId, this::randomInferenceId);
38+
case 2 -> input = randomValueOtherThan(input, this::randomInput);
39+
case 3 -> targetField = randomValueOtherThan(targetField, this::randomTargetField);
40+
}
41+
return new DenseVectorEmbedding(instance.source(), child, inferenceId, input, targetField);
42+
}
43+
44+
private Literal randomInferenceId() {
45+
return Literal.keyword(EMPTY, randomIdentifier());
46+
}
47+
48+
private Expression randomInput() {
49+
return randomBoolean() ? Literal.keyword(EMPTY, randomIdentifier()) : randomAttribute();
50+
}
51+
52+
private Attribute randomTargetField() {
53+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
54+
}
55+
56+
private Attribute randomAttribute() {
57+
return ReferenceAttributeTests.randomReferenceAttribute(randomBoolean());
58+
}
59+
}

0 commit comments

Comments
 (0)