Skip to content

Commit 7322c0c

Browse files
committed
Resolve dimensions of the embedding from the model config.
1 parent 4408f90 commit 7322c0c

File tree

11 files changed

+183
-91
lines changed

11 files changed

+183
-91
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
408408
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
409409

410410
if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
411-
return plan;
411+
return plan.withModelConfigurations(resolvedInference.modelConfigurations());
412412
} else if (resolvedInference != null) {
413413
String error = "cannot use inference endpoint ["
414414
+ inferenceId
@@ -842,13 +842,7 @@ private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List<Att
842842
// Create a new DenseVectorEmbedding with resolved expressions
843843
// Only create a new instance if something changed to avoid unnecessary object creation
844844
if (input != p.input() || targetField != p.embeddingField()) {
845-
return new DenseVectorEmbedding(
846-
p.source(),
847-
p.child(),
848-
p.inferenceId(),
849-
input,
850-
targetField
851-
);
845+
return p.withTargetField(targetField);
852846
}
853847

854848
return p;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<Infere
6363
GetInferenceModelAction.INSTANCE,
6464
new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
6565
ActionListener.wrap(r -> {
66-
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
66+
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst());
6767
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
6868
countdownListener.onResponse(null);
6969
}, e -> {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,12 @@
77

88
package org.elasticsearch.xpack.esql.inference;
99

10-
import org.elasticsearch.common.io.stream.StreamInput;
11-
import org.elasticsearch.common.io.stream.StreamOutput;
12-
import org.elasticsearch.common.io.stream.Writeable;
10+
import org.elasticsearch.inference.ModelConfigurations;
1311
import org.elasticsearch.inference.TaskType;
1412

15-
import java.io.IOException;
13+
public record ResolvedInference(String inferenceId, ModelConfigurations modelConfigurations) {
1614

17-
public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable {
18-
19-
public ResolvedInference(StreamInput in) throws IOException {
20-
this(in.readString(), TaskType.valueOf(in.readString()));
21-
}
22-
23-
@Override
24-
public void writeTo(StreamOutput out) throws IOException {
25-
out.writeString(inferenceId);
26-
out.writeString(taskType.name());
15+
public TaskType taskType() {
16+
return modelConfigurations.getTaskType();
2717
}
2818
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.plan.logical.inference;
99

1010
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.inference.ModelConfigurations;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.xpack.esql.core.expression.Expression;
1314
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
@@ -69,4 +70,9 @@ public int hashCode() {
6970
public PlanType withInferenceResolutionError(String inferenceId, String error) {
7071
return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
7172
}
73+
74+
@SuppressWarnings("unchecked")
75+
public PlanType withModelConfigurations(ModelConfigurations modelConfig) {
76+
return (PlanType) this;
77+
}
7278
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,19 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.common.lucene.BytesRefs;
14+
import org.elasticsearch.inference.ModelConfigurations;
1315
import org.elasticsearch.inference.TaskType;
1416
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
17+
import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable;
1518
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1619
import org.elasticsearch.xpack.esql.core.expression.Expression;
20+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
21+
import org.elasticsearch.xpack.esql.core.expression.Literal;
1722
import org.elasticsearch.xpack.esql.core.expression.NameId;
1823
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1924
import org.elasticsearch.xpack.esql.core.tree.Source;
25+
import org.elasticsearch.xpack.esql.core.type.DataType;
2026
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2127
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2228
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
@@ -36,13 +42,26 @@ public class DenseVectorEmbedding extends InferencePlan<DenseVectorEmbedding> im
3642
);
3743

3844
private final Expression input;
45+
private final Expression dimensions;
3946
private final Attribute targetField;
4047
private List<Attribute> lazyOutput;
4148

4249
public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
50+
this(source, child, inferenceId, new UnresolvedDimensions(inferenceId), input, targetField);
51+
}
52+
53+
DenseVectorEmbedding(
54+
Source source,
55+
LogicalPlan child,
56+
Expression inferenceId,
57+
Expression dimensions,
58+
Expression input,
59+
Attribute targetField
60+
) {
4361
super(source, child, inferenceId);
4462
this.input = input;
4563
this.targetField = targetField;
64+
this.dimensions = dimensions;
4665
}
4766

4867
public DenseVectorEmbedding(StreamInput in) throws IOException {
@@ -51,6 +70,7 @@ public DenseVectorEmbedding(StreamInput in) throws IOException {
5170
in.readNamedWriteable(LogicalPlan.class),
5271
in.readNamedWriteable(Expression.class),
5372
in.readNamedWriteable(Expression.class),
73+
in.readNamedWriteable(Expression.class),
5474
in.readNamedWriteable(Attribute.class)
5575
);
5676
}
@@ -60,6 +80,7 @@ public void writeTo(StreamOutput out) throws IOException {
6080
source().writeTo(out);
6181
out.writeNamedWriteable(child());
6282
out.writeNamedWriteable(inferenceId());
83+
out.writeNamedWriteable(dimensions);
6384
out.writeNamedWriteable(input);
6485
out.writeNamedWriteable(targetField);
6586
}
@@ -77,6 +98,10 @@ public TaskType taskType() {
7798
return TaskType.TEXT_EMBEDDING;
7899
}
79100

101+
public Expression dimensions() {
102+
return dimensions;
103+
}
104+
80105
@Override
81106
public String getWriteableName() {
82107
return ENTRY.name;
@@ -98,7 +123,7 @@ public List<Attribute> generatedAttributes() {
98123
@Override
99124
public DenseVectorEmbedding withGeneratedNames(List<String> newNames) {
100125
checkNumberOfNewNames(newNames);
101-
return new DenseVectorEmbedding(source(), child(), inferenceId(), input, this.renameTargetField(newNames.get(0)));
126+
return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, this.renameTargetField(newNames.get(0)));
102127
}
103128

104129
private Attribute renameTargetField(String newName) {
@@ -111,22 +136,45 @@ private Attribute renameTargetField(String newName) {
111136

112137
@Override
113138
public boolean expressionsResolved() {
114-
return super.expressionsResolved() && input.resolved() && targetField.resolved();
139+
return super.expressionsResolved() && input.resolved() && targetField.resolved() && dimensions.resolved();
115140
}
116141

117142
@Override
118143
public DenseVectorEmbedding withInferenceId(Expression newInferenceId) {
119-
return new DenseVectorEmbedding(source(), child(), newInferenceId, input, targetField);
144+
return new DenseVectorEmbedding(source(), child(), newInferenceId, dimensions, input, targetField);
145+
}
146+
147+
public DenseVectorEmbedding withDimensions(Expression newDimensions) {
148+
return new DenseVectorEmbedding(source(), child(), inferenceId(), newDimensions, input, targetField);
149+
}
150+
151+
public DenseVectorEmbedding withTargetField(Attribute targetField) {
152+
return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, targetField);
153+
}
154+
155+
@Override
156+
public DenseVectorEmbedding withModelConfigurations(ModelConfigurations modelConfig) {
157+
boolean hasChanged = false;
158+
Expression newDimensions = dimensions;
159+
160+
if (dimensions.resolved() == false
161+
&& modelConfig.getServiceSettings() != null
162+
&& modelConfig.getServiceSettings().dimensions() > 0) {
163+
hasChanged = true;
164+
newDimensions = new Literal(Source.EMPTY, modelConfig.getServiceSettings().dimensions(), DataType.INTEGER);
165+
}
166+
167+
return hasChanged ? withDimensions(newDimensions) : this;
120168
}
121169

122170
@Override
123171
public DenseVectorEmbedding replaceChild(LogicalPlan newChild) {
124-
return new DenseVectorEmbedding(source(), newChild, inferenceId(), input, targetField);
172+
return new DenseVectorEmbedding(source(), newChild, inferenceId(), dimensions, input, targetField);
125173
}
126174

127175
@Override
128176
protected NodeInfo<? extends LogicalPlan> info() {
129-
return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), input, targetField);
177+
return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), dimensions, input, targetField);
130178
}
131179

132180
@Override
@@ -135,11 +183,28 @@ public boolean equals(Object o) {
135183
if (o == null || getClass() != o.getClass()) return false;
136184
if (super.equals(o) == false) return false;
137185
DenseVectorEmbedding that = (DenseVectorEmbedding) o;
138-
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
186+
return Objects.equals(input, that.input)
187+
&& Objects.equals(dimensions, that.dimensions)
188+
&& Objects.equals(targetField, that.targetField);
139189
}
140190

141191
@Override
142192
public int hashCode() {
143-
return Objects.hash(super.hashCode(), input, targetField);
193+
return Objects.hash(super.hashCode(), input, targetField, dimensions);
194+
}
195+
196+
private static class UnresolvedDimensions extends Literal implements Unresolvable {
197+
198+
private final String inferenceId;
199+
200+
private UnresolvedDimensions(Expression inferenceId) {
201+
super(Source.EMPTY, null, DataType.NULL);
202+
this.inferenceId = BytesRefs.toString(inferenceId.fold(FoldContext.small()));
203+
}
204+
205+
@Override
206+
public String unresolvedMessage() {
207+
return "Dimensions cannot be resolved for inference endpoint[" + inferenceId + "]";
208+
}
144209
}
145210
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,21 @@ public class DenseVectorEmbeddingExec extends InferenceExec {
3535
);
3636

3737
private final Expression input;
38+
private final Expression dimensions;
3839
private final Attribute targetField;
3940
private List<Attribute> lazyOutput;
4041

41-
public DenseVectorEmbeddingExec(Source source, PhysicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
42+
public DenseVectorEmbeddingExec(
43+
Source source,
44+
PhysicalPlan child,
45+
Expression inferenceId,
46+
Expression dimensions,
47+
Expression input,
48+
Attribute targetField
49+
) {
4250
super(source, child, inferenceId);
4351
this.input = input;
52+
this.dimensions = dimensions;
4453
this.targetField = targetField;
4554
}
4655

@@ -50,6 +59,7 @@ public DenseVectorEmbeddingExec(StreamInput in) throws IOException {
5059
in.readNamedWriteable(PhysicalPlan.class),
5160
in.readNamedWriteable(Expression.class),
5261
in.readNamedWriteable(Expression.class),
62+
in.readNamedWriteable(Expression.class),
5363
in.readNamedWriteable(Attribute.class)
5464
);
5565
}
@@ -70,18 +80,23 @@ public String getWriteableName() {
7080
@Override
7181
public void writeTo(StreamOutput out) throws IOException {
7282
super.writeTo(out);
83+
out.writeNamedWriteable(dimensions);
7384
out.writeNamedWriteable(input);
7485
out.writeNamedWriteable(targetField);
7586
}
7687

88+
public Expression dimensions() {
89+
return dimensions;
90+
}
91+
7792
@Override
7893
protected NodeInfo<? extends PhysicalPlan> info() {
79-
return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, targetField);
94+
return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), input, dimensions, targetField);
8095
}
8196

8297
@Override
8398
public UnaryExec replaceChild(PhysicalPlan newChild) {
84-
return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, targetField);
99+
return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), input, dimensions, targetField);
85100
}
86101

87102
@Override
@@ -103,11 +118,13 @@ public boolean equals(Object o) {
103118
if (o == null || getClass() != o.getClass()) return false;
104119
if (super.equals(o) == false) return false;
105120
DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o;
106-
return Objects.equals(input, that.input) && Objects.equals(targetField, that.targetField);
121+
return Objects.equals(input, that.input)
122+
&& Objects.equals(dimensions, that.dimensions)
123+
&& Objects.equals(targetField, that.targetField);
107124
}
108125

109126
@Override
110127
public int hashCode() {
111-
return Objects.hash(super.hashCode(), input, targetField);
128+
return Objects.hash(super.hashCode(), input, dimensions, targetField);
112129
}
113130
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
4040
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
4141
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
42+
import static org.mockito.Mockito.mock;
43+
import static org.mockito.Mockito.when;
4244

4345
public final class AnalyzerTestUtils {
4446

@@ -189,12 +191,19 @@ public static EnrichResolution defaultEnrichResolution() {
189191

190192
public static InferenceResolution defaultInferenceResolution() {
191193
return InferenceResolution.builder()
192-
.withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
193-
.withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
194+
.withResolvedInference(mockedResolvedInference("reranking-inference-id", TaskType.RERANK))
195+
.withResolvedInference(mockedResolvedInference("completion-inference-id", TaskType.COMPLETION))
194196
.withError("error-inference-id", "error with inference resolution")
195197
.build();
196198
}
197199

200+
private static ResolvedInference mockedResolvedInference(String id, TaskType taskType) {
201+
ResolvedInference resolvedInference = mock(ResolvedInference.class);
202+
when(resolvedInference.inferenceId()).thenReturn(id);
203+
when(resolvedInference.taskType()).thenReturn(taskType);
204+
return resolvedInference;
205+
}
206+
198207
public static void loadEnrichPolicyResolution(
199208
EnrichResolution enrich,
200209
String policyType,

0 commit comments

Comments
 (0)