1111import org .elasticsearch .common .io .stream .StreamInput ;
1212import org .elasticsearch .common .io .stream .StreamOutput ;
1313import org .elasticsearch .inference .TaskType ;
14+ import org .elasticsearch .xpack .esql .capabilities .PostAnalysisVerificationAware ;
1415import org .elasticsearch .xpack .esql .capabilities .TelemetryAware ;
16+ import org .elasticsearch .xpack .esql .common .Failures ;
1517import org .elasticsearch .xpack .esql .core .capabilities .Resolvables ;
1618import org .elasticsearch .xpack .esql .core .expression .Alias ;
1719import org .elasticsearch .xpack .esql .core .expression .Attribute ;
1820import org .elasticsearch .xpack .esql .core .expression .AttributeSet ;
1921import org .elasticsearch .xpack .esql .core .expression .Expression ;
20- import org .elasticsearch .xpack .esql .core .expression .Expressions ;
2122import org .elasticsearch .xpack .esql .core .expression .Literal ;
2223import org .elasticsearch .xpack .esql .core .expression .NameId ;
2324import org .elasticsearch .xpack .esql .core .tree .NodeInfo ;
2425import org .elasticsearch .xpack .esql .core .tree .Source ;
26+ import org .elasticsearch .xpack .esql .core .type .DataType ;
2527import org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
28+ import org .elasticsearch .xpack .esql .plan .logical .Eval ;
2629import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
2730import org .elasticsearch .xpack .esql .plan .logical .UnaryPlan ;
2831
2932import java .io .IOException ;
3033import java .util .List ;
3134import java .util .Objects ;
35+ import java .util .function .Predicate ;
3236
33- import static org .elasticsearch .xpack .esql .core . expression . Expressions . asAttributes ;
37+ import static org .elasticsearch .xpack .esql .common . Failure . fail ;
3438import static org .elasticsearch .xpack .esql .expression .NamedExpressions .mergeOutputAttributes ;
3539
36- public class Rerank extends InferencePlan <Rerank > implements TelemetryAware {
40+ public class Rerank extends InferencePlan <Rerank > implements PostAnalysisVerificationAware , TelemetryAware {
3741
3842 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (LogicalPlan .class , "Rerank" , Rerank ::new );
3943 public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch" ;
@@ -155,8 +159,7 @@ private Attribute renameScoreAttribute(String newName) {
155159 }
156160
157161 public static AttributeSet computeReferences (List <Alias > fields ) {
158- AttributeSet rerankFields = AttributeSet .of (asAttributes (fields ));
159- return Expressions .references (fields ).subtract (rerankFields );
162+ return Eval .computeReferences (fields );
160163 }
161164
162165 @ Override
@@ -169,6 +172,28 @@ protected NodeInfo<? extends LogicalPlan> info() {
169172 return NodeInfo .create (this , Rerank ::new , child (), inferenceId (), queryText , rerankFields , scoreAttribute );
170173 }
171174
175+ @ Override
176+ public void postAnalysisVerification (Failures failures ) {
177+ if (queryText .resolved () && (DataType .isString (queryText .dataType ()) == false || queryText .foldable () == false )) {
178+ // Rerank only supports string as query
179+ failures .add (fail (queryText , "query must be a valid string in RERANK, found [{}]" , queryText .source ().text ()));
180+ }
181+
182+ // When using multiple fields the content is transformed into YAML before it is reranked
183+ // We can use any of string, numeric or boolean field.
184+ rerankFields .stream ()
185+ .filter (Predicate .not (f -> DataType .isString (f .dataType ()) || f .dataType () == DataType .BOOLEAN || f .dataType ().isNumeric ()))
186+ .forEach (
187+ rerankField -> failures .add (
188+ fail (
189+ rerankField ,
190+ "rerank field must be a valid string, numeric or boolean expression, found [{}]" ,
191+ rerankField .source ().text ()
192+ )
193+ )
194+ );
195+ }
196+
172197 @ Override
173198 public boolean equals (Object o ) {
174199 if (this == o ) return true ;
0 commit comments