Skip to content

Commit b2c0e11

Browse files
committed
Adding logical plan for the rerank command.
1 parent 612dbcc commit b2c0e11

File tree

4 files changed

+222
-0
lines changed

4 files changed

+222
-0
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
2323
import org.elasticsearch.xpack.esql.plan.logical.Project;
2424
import org.elasticsearch.xpack.esql.plan.logical.TopN;
25+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
2526
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
2627
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
2728
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@@ -79,6 +80,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
7980
MvExpand.ENTRY,
8081
OrderBy.ENTRY,
8182
Project.ENTRY,
83+
Rerank.ENTRY,
8284
TopN.ENTRY
8385
);
8486
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.xpack.esql.core.tree.Source;
13+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
14+
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
15+
16+
import java.io.IOException;
17+
import java.util.Objects;
18+
19+
public abstract class InferencePlan extends UnaryPlan {
20+
21+
private final String inferenceId;
22+
23+
protected InferencePlan(Source source, LogicalPlan child, String inferenceId) {
24+
super(source, child);
25+
this.inferenceId = inferenceId;
26+
}
27+
28+
@Override
29+
public void writeTo(StreamOutput out) throws IOException {
30+
Source.EMPTY.writeTo(out);
31+
out.writeNamedWriteable(child());
32+
out.writeString(inferenceId());
33+
}
34+
35+
public String inferenceId() {
36+
return inferenceId;
37+
}
38+
39+
@Override
40+
public boolean equals(Object o) {
41+
if (this == o) return true;
42+
if (o == null || getClass() != o.getClass()) return false;
43+
if (super.equals(o) == false) return false;
44+
InferencePlan other = (InferencePlan) o;
45+
return Objects.equals(inferenceId(), other.inferenceId());
46+
}
47+
48+
@Override
49+
public int hashCode() {
50+
return Objects.hash(super.hashCode(), inferenceId());
51+
}
52+
53+
public abstract TaskType taskType();
54+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
15+
import org.elasticsearch.xpack.esql.core.expression.Alias;
16+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
17+
import org.elasticsearch.xpack.esql.core.expression.Expressions;
18+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
21+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
22+
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
23+
24+
import java.io.IOException;
25+
import java.util.List;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
29+
30+
public class Rerank extends InferencePlan {
31+
32+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
33+
private final String queryText;
34+
private final List<Alias> rerankFields;
35+
36+
public Rerank(Source source, LogicalPlan child, String inferenceId, String queryText, List<Alias> rerankFields) {
37+
super(source, child, inferenceId);
38+
this.queryText = queryText;
39+
this.rerankFields = rerankFields;
40+
}
41+
42+
public Rerank(StreamInput in) throws IOException {
43+
this(
44+
Source.readFrom((PlanStreamInput) in),
45+
in.readNamedWriteable(LogicalPlan.class),
46+
in.readString(),
47+
in.readString(),
48+
in.readCollectionAsList(Alias::new)
49+
);
50+
}
51+
52+
@Override
53+
public void writeTo(StreamOutput out) throws IOException {
54+
super.writeTo(out);
55+
out.writeString(queryText);
56+
out.writeCollection(rerankFields());
57+
}
58+
59+
public String queryText() {
60+
return queryText;
61+
}
62+
63+
public List<Alias> rerankFields() {
64+
return rerankFields;
65+
}
66+
67+
@Override
68+
public TaskType taskType() {
69+
return TaskType.RERANK;
70+
}
71+
72+
@Override
73+
public String getWriteableName() {
74+
return ENTRY.name;
75+
}
76+
77+
@Override
78+
public UnaryPlan replaceChild(LogicalPlan newChild) {
79+
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields);
80+
}
81+
82+
@Override
83+
protected AttributeSet computeReferences() {
84+
return computeReferences(rerankFields);
85+
}
86+
87+
public static AttributeSet computeReferences(List<Alias> fields) {
88+
AttributeSet generated = new AttributeSet(asAttributes(fields));
89+
return Expressions.references(fields).subtract(generated);
90+
}
91+
92+
@Override
93+
public boolean expressionsResolved() {
94+
return Resolvables.resolved(rerankFields);
95+
}
96+
97+
@Override
98+
protected NodeInfo<? extends LogicalPlan> info() {
99+
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields);
100+
}
101+
102+
@Override
103+
public boolean equals(Object o) {
104+
if (this == o) return true;
105+
if (o == null || getClass() != o.getClass()) return false;
106+
if (super.equals(o) == false) return false;
107+
Rerank rerank = (Rerank) o;
108+
return Objects.equals(queryText, rerank.queryText) && Objects.equals(rerankFields, rerank.rerankFields);
109+
}
110+
111+
@Override
112+
public int hashCode() {
113+
return Objects.hash(super.hashCode(), queryText, rerankFields);
114+
}
115+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Alias;
11+
import org.elasticsearch.xpack.esql.core.tree.Source;
12+
import org.elasticsearch.xpack.esql.expression.AliasTests;
13+
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
14+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
15+
16+
import java.io.IOException;
17+
import java.util.List;
18+
19+
public class RerankSerializationTests extends AbstractLogicalPlanSerializationTests<Rerank> {
20+
@Override
21+
protected Rerank createTestInstance() {
22+
Source source = randomSource();
23+
LogicalPlan child = randomChild(0);
24+
return new Rerank(source, child, randomIdentifier(), randomIdentifier(), randomFields());
25+
}
26+
27+
@Override
28+
protected Rerank mutateInstance(Rerank instance) throws IOException {
29+
LogicalPlan child = instance.child();
30+
String inferenceId = instance.inferenceId();
31+
String queryText = instance.queryText();
32+
List<Alias> fields = instance.rerankFields();
33+
34+
switch (between(0, 3)) {
35+
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
36+
case 1 -> inferenceId = randomValueOtherThan(inferenceId, RerankSerializationTests::randomIdentifier);
37+
case 2 -> queryText = randomValueOtherThan(queryText, RerankSerializationTests::randomIdentifier);
38+
case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
39+
}
40+
return new Rerank(instance.source(), child, inferenceId, queryText, fields);
41+
}
42+
43+
@Override
44+
protected boolean alwaysEmptySource() {
45+
return true;
46+
}
47+
48+
private List<Alias> randomFields() {
49+
return randomList(0, 10, AliasTests::randomAlias);
50+
}
51+
}

0 commit comments

Comments
 (0)