Skip to content

Commit 3cced94

Browse files
committed
Ensure proper support of non-text fields as rerank fields.
1 parent e89a74b commit 3cced94

File tree

5 files changed

+96
-19
lines changed

5 files changed

+96
-19
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,25 @@ book_no:keyword | title:text | author
6161
5327 | War and Peace | Leo Tolstoy | 0.08
6262
;
6363

64+
reranker using a non text fields
65+
required_capability: rerank
66+
required_capability: match_operator_colon
67+
68+
FROM books METADATA _score
69+
| WHERE title:"war and peace" AND author:"Tolstoy"
70+
| RERANK "war and peace" ON ratings WITH { "inference_id" : "test_reranker" }
71+
| EVAL _score=ROUND(_score, 2)
72+
| SORT _score DESC, book_no ASC
73+
| KEEP book_no, title, ratings, _score
74+
;
75+
76+
book_no:keyword | title:text | ratings:double | _score:double
77+
2776 | The Devil and Other Stories (Oxford World's Classics) | 5.0 | 0.33
78+
4536 | War and Peace (Signet Classics) | 4.75 | 0.25
79+
5327 | War and Peace | 3.84 | 0.06
80+
9032 | War and Peace: A Novel (6 Volumes) | 3.81 | 0.06
81+
;
82+
6483

6584
reranker using multiple fields
6685
required_capability: rerank
@@ -82,6 +101,26 @@ book_no:keyword | title:text | author
82101
;
83102

84103

104+
reranker using multiple fields with some non text fields
105+
required_capability: rerank
106+
required_capability: match_operator_colon
107+
108+
FROM books METADATA _score
109+
| WHERE title:"war and peace" AND author:"Tolstoy"
110+
| RERANK "war and peace" ON title, ratings WITH { "inference_id" : "test_reranker" }
111+
| EVAL _score=ROUND(_score, 2)
112+
| SORT _score DESC, book_no ASC
113+
| KEEP book_no, title, ratings, _score
114+
;
115+
116+
book_no:keyword | title:text | ratings:double | _score:double
117+
4536 | War and Peace (Signet Classics) | 4.75 | 0.02
118+
5327 | War and Peace | 3.84 | 0.02
119+
2776 | The Devil and Other Stories (Oxford World's Classics) | 5.0 | 0.01
120+
9032 | War and Peace: A Novel (6 Volumes) | 3.81 | 0.01
121+
;
122+
123+
85124
reranker after a limit
86125
required_capability: rerank
87126
required_capability: match_operator_colon

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
8080
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
8181
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
82+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
8283
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
8384
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
8485
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
@@ -813,11 +814,23 @@ private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput)
813814
List<Alias> newFields = new ArrayList<>();
814815
boolean changed = false;
815816

817+
// Do not need to cast as string if there are multiple rerank fields since it will be converted to YAML.
818+
boolean castRerankFieldsAsString = rerank.rerankFields().size() < 2;
819+
816820
// First resolving fields used in expression
817821
for (Alias field : rerank.rerankFields()) {
818-
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childrenOutput));
819-
newFields.add(result);
820-
changed |= result != field;
822+
Alias resolved = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childrenOutput));
823+
824+
if (resolved.resolved() != false) {
825+
if (castRerankFieldsAsString
826+
&& rerank.isValidRerankField(resolved)
827+
&& DataType.isString(resolved.dataType()) == false) {
828+
resolved = resolved.replaceChild(new ToString(resolved.child().source(), resolved.child()));
829+
}
830+
}
831+
832+
newFields.add(resolved);
833+
changed |= resolved != field;
821834
}
822835

823836
if (changed) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ public static AttributeSet computeReferences(List<Alias> fields) {
162162
return Eval.computeReferences(fields);
163163
}
164164

165+
public boolean isValidRerankField(Alias rerankField) {
166+
// Only supporting the following datatypes for now: text, numeric and boolean
167+
return DataType.isString(rerankField.dataType())
168+
|| rerankField.dataType() == DataType.BOOLEAN
169+
|| rerankField.dataType().isNumeric();
170+
}
171+
165172
@Override
166173
public boolean expressionsResolved() {
167174
return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
@@ -182,7 +189,7 @@ public void postAnalysisVerification(Failures failures) {
182189
// When using multiple fields the content is transformed into YAML before it is reranked
183190
// We can use any of string, numeric or boolean field.
184191
rerankFields.stream()
185-
.filter(Predicate.not(f -> DataType.isString(f.dataType()) || f.dataType() == DataType.BOOLEAN || f.dataType().isNumeric()))
192+
.filter(Predicate.not(this::isValidRerankField))
186193
.forEach(
187194
rerankField -> failures.add(
188195
fail(

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
5656
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
5757
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
58+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
5859
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
5960
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
6061
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
@@ -3798,7 +3799,15 @@ public void testRerankFieldValidTypes() {
37983799
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
37993800
EsRelation relation = as(rerank.child(), EsRelation.class);
38003801
Attribute fieldAttribute = getAttributeByName(relation.output(), fieldName);
3801-
assertThat(rerank.rerankFields(), equalTo(List.of(alias(fieldName, fieldAttribute))));
3802+
if (DataType.isString(fieldAttribute.dataType())) {
3803+
assertThat(rerank.rerankFields(), equalTo(List.of(alias(fieldName, fieldAttribute))));
3804+
3805+
} else {
3806+
assertThat(
3807+
rerank.rerankFields(),
3808+
equalTo(List.of(alias(fieldName, new ToString(fieldAttribute.source(), fieldAttribute))))
3809+
);
3810+
}
38023811
}
38033812
}
38043813

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,30 +161,39 @@ public void chunkedInfer(
161161
}
162162

163163
private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
164-
int totalResults = input.size();
164+
if (taskSettings.useTextLength) {
165+
return makeResultFromTextInput(input, taskSettings);
166+
}
167+
165168
try {
169+
int totalResults = input.size();
166170
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
167171
for (int i = 0; i < totalResults; i++) {
168172
results.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(input.get(i)), input.get(i)));
169173
}
170174
return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
171175
} catch (NumberFormatException ex) {
172-
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
176+
return makeResultFromTextInput(input, taskSettings);
177+
}
178+
}
179+
180+
private RankedDocsResults makeResultFromTextInput(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
181+
int totalResults = input.size();
173182

174-
float minScore = taskSettings.minScore();
175-
float resultDiff = taskSettings.resultDiff();
176-
for (int i = 0; i < input.size(); i++) {
177-
float relevanceScore = minScore + resultDiff * (totalResults - i);
178-
String inputText = input.get(totalResults - 1 - i);
179-
if (taskSettings.useTextLength()) {
180-
relevanceScore = 1f / inputText.length();
181-
}
182-
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
183+
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
184+
float minScore = taskSettings.minScore();
185+
float resultDiff = taskSettings.resultDiff();
186+
for (int i = 0; i < input.size(); i++) {
187+
float relevanceScore = minScore + resultDiff * (totalResults - i);
188+
String inputText = input.get(totalResults - 1 - i);
189+
if (taskSettings.useTextLength()) {
190+
relevanceScore = 1f / inputText.length();
183191
}
184-
// Ensure results are sorted by descending score
185-
results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
186-
return new RankedDocsResults(results);
192+
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
187193
}
194+
// Ensure results are sorted by descending score
195+
results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
196+
return new RankedDocsResults(results);
188197
}
189198

190199
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {

0 commit comments

Comments
 (0)