|
64 | 64 | import org.elasticsearch.xpack.esql.plan.logical.Rename; |
65 | 65 | import org.elasticsearch.xpack.esql.plan.logical.Row; |
66 | 66 | import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; |
| 67 | +import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; |
67 | 68 | import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; |
68 | 69 | import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; |
69 | 70 |
|
@@ -3088,6 +3089,112 @@ public void testInvalidFork() { |
3088 | 3089 | expectError("FROM foo* | FORK ( LIMIT 10 ) ( y+2 )", "line 1:33: mismatched input 'y' expecting {'limit', 'sort', 'where'}"); |
3089 | 3090 | } |
3090 | 3091 |
|
| 3092 | + public void testRerankSingleField() { |
| 3093 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3094 | + |
| 3095 | + var plan = processingCommand("RERANK \"query text\" ON title WITH \"inferenceID\""); |
| 3096 | + var rerank = as(plan, Rerank.class); |
| 3097 | + |
| 3098 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3099 | + assertThat(rerank.inferenceId(), equalTo("inferenceID")); |
| 3100 | + assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title"))))); |
| 3101 | + } |
| 3102 | + |
| 3103 | + public void testRerankMultipleFields() { |
| 3104 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3105 | + |
| 3106 | + var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH \"inferenceID\""); |
| 3107 | + var rerank = as(plan, Rerank.class); |
| 3108 | + |
| 3109 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3110 | + assertThat(rerank.inferenceId(), equalTo("inferenceID")); |
| 3111 | + assertThat( |
| 3112 | + rerank.rerankFields(), |
| 3113 | + equalTo( |
| 3114 | + List.of( |
| 3115 | + alias("title", attribute("title")), |
| 3116 | + alias("description", attribute("description")), |
| 3117 | + alias("authors_renamed", attribute("authors")) |
| 3118 | + ) |
| 3119 | + ) |
| 3120 | + ); |
| 3121 | + } |
| 3122 | + |
| 3123 | + public void testRerankComputedFields() { |
| 3124 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3125 | + |
| 3126 | + var plan = processingCommand( |
| 3127 | + "RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH \"inferenceID\"" |
| 3128 | + ); |
| 3129 | + var rerank = as(plan, Rerank.class); |
| 3130 | + |
| 3131 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3132 | + assertThat(rerank.inferenceId(), equalTo("inferenceID")); |
| 3133 | + assertThat( |
| 3134 | + rerank.rerankFields(), |
| 3135 | + equalTo( |
| 3136 | + List.of( |
| 3137 | + alias("title", attribute("title")), |
| 3138 | + alias( |
| 3139 | + "short_description", |
| 3140 | + function( |
| 3141 | + "SUBSTRING", |
| 3142 | + List.of(attribute("description"), new Literal(EMPTY, 0, INTEGER), new Literal(EMPTY, 100, INTEGER)) |
| 3143 | + ) |
| 3144 | + ) |
| 3145 | + ) |
| 3146 | + ) |
| 3147 | + ); |
| 3148 | + } |
| 3149 | + |
| 3150 | + public void testRerankNoFieldSpecified() { |
| 3151 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3152 | + |
| 3153 | + var plan = parser.createStatement("row a = 1, b = \"foo\", c=true | RERANK \"query text\" WITH \"inferenceID\""); |
| 3154 | + var rerank = as(plan, Rerank.class); |
| 3155 | + |
| 3156 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3157 | + assertThat(rerank.inferenceId(), equalTo("inferenceID")); |
| 3158 | + |
| 3159 | + // When no field are specified, all the fields are used. |
| 3160 | + assertThat(rerank.rerankFields(), contains(alias("a", attribute("a")), alias("b", attribute("b")), alias("c", attribute("c")))); |
| 3161 | + } |
| 3162 | + |
| 3163 | + public void testRerankWithPositionalParameters() { |
| 3164 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3165 | + |
| 3166 | + var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker"))); |
| 3167 | + var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class); |
| 3168 | + |
| 3169 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3170 | + assertThat(rerank.inferenceId(), equalTo("reranker")); |
| 3171 | + assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title"))))); |
| 3172 | + } |
| 3173 | + |
| 3174 | + public void testRerankWithNamedParameters() { |
| 3175 | + assumeTrue("FORK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3176 | + |
| 3177 | + var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker"))); |
| 3178 | + var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class); |
| 3179 | + |
| 3180 | + assertThat(rerank.queryText(), equalTo("query text")); |
| 3181 | + assertThat(rerank.inferenceId(), equalTo("reranker")); |
| 3182 | + assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title"))))); |
| 3183 | + } |
| 3184 | + |
| 3185 | + public void testInvalidRerank() { |
| 3186 | + assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled()); |
| 3187 | + expectError( |
| 3188 | + "FROM foo* | RERANK ON title WITH inferenceId", |
| 3189 | + "line 1:20: mismatched input 'ON' expecting {QUOTED_STRING, '?', NAMED_OR_POSITIONAL_PARAM}" |
| 3190 | + ); |
| 3191 | + |
| 3192 | + expectError( |
| 3193 | + "FROM foo* | RERANK \"query text\" ON title", |
| 3194 | + "line 1:42: mismatched input '<EOF>' expecting {'with', 'and', '::', ',', '.', 'or', '+', '-', '*', '/', '%'}" |
| 3195 | + ); |
| 3196 | + } |
| 3197 | + |
3091 | 3198 | static Alias alias(String name, Expression value) { |
3092 | 3199 | return new Alias(EMPTY, name, value); |
3093 | 3200 | } |
|
0 commit comments