2727import org .elasticsearch .inference .ModelSecrets ;
2828import org .elasticsearch .inference .ServiceSettings ;
2929import org .elasticsearch .inference .SettingsConfiguration ;
30+ import org .elasticsearch .inference .TaskSettings ;
3031import org .elasticsearch .inference .TaskType ;
3132import org .elasticsearch .inference .UnifiedCompletionRequest ;
3233import org .elasticsearch .inference .configuration .SettingsConfigurationFieldType ;
4344import java .util .List ;
4445import java .util .Map ;
4546
47+ import static org .elasticsearch .xpack .inference .mock .AbstractTestInferenceService .random ;
48+
4649public class TestRerankingServiceExtension implements InferenceServiceExtension {
4750
4851 @ Override
@@ -84,11 +87,16 @@ public void parseRequestConfig(
8487 var secretSettings = TestSecretSettings .fromMap (serviceSettingsMap );
8588
8689 var taskSettingsMap = getTaskSettingsMap (config );
87- var taskSettings = TestTaskSettings .fromMap (taskSettingsMap );
90+ var taskSettings = TestRerankingServiceExtension . TestTaskSettings .fromMap (taskSettingsMap );
8891
8992 parsedModelListener .onResponse (new TestServiceModel (modelId , taskType , name (), serviceSettings , taskSettings , secretSettings ));
9093 }
9194
95+ @ Override
96+ protected TaskSettings getTasksSettingsFromMap (Map <String , Object > taskSettingsMap ) {
97+ return TestRerankingServiceExtension .TestTaskSettings .fromMap (taskSettingsMap );
98+ }
99+
92100 @ Override
93101 public InferenceServiceConfiguration getConfiguration () {
94102 return Configuration .get ();
@@ -107,13 +115,15 @@ public void infer(
107115 @ Nullable Integer topN ,
108116 List <String > input ,
109117 boolean stream ,
110- Map <String , Object > taskSettings ,
118+ Map <String , Object > taskSettingsMap ,
111119 InputType inputType ,
112120 TimeValue timeout ,
113121 ActionListener <InferenceServiceResults > listener
114122 ) {
123+ TaskSettings taskSettings = model .getTaskSettings ().updatedTaskSettings (taskSettingsMap );
124+
115125 switch (model .getConfigurations ().getTaskType ()) {
116- case ANY , RERANK -> listener .onResponse (makeResults (input ));
126+ case ANY , RERANK -> listener .onResponse (makeResults (input , ( TestRerankingServiceExtension . TestTaskSettings ) taskSettings ));
117127 default -> listener .onFailure (
118128 new ElasticsearchStatusException (
119129 TaskType .unsupportedTaskTypeErrorMsg (model .getConfigurations ().getTaskType (), name ()),
@@ -151,7 +161,7 @@ public void chunkedInfer(
151161 );
152162 }
153163
154- private RankedDocsResults makeResults (List <String > input ) {
164+ private RankedDocsResults makeResults (List <String > input , TestRerankingServiceExtension . TestTaskSettings taskSettings ) {
155165 int totalResults = input .size ();
156166 try {
157167 List <RankedDocsResults .RankedDoc > results = new ArrayList <>();
@@ -161,17 +171,19 @@ private RankedDocsResults makeResults(List<String> input) {
161171 return new RankedDocsResults (results .stream ().sorted (Comparator .reverseOrder ()).toList ());
162172 } catch (NumberFormatException ex ) {
163173 List <RankedDocsResults .RankedDoc > results = new ArrayList <>();
164- float minScore = random .nextFloat (-1f , 1f );
165- float resultDiff = 0.2f ;
174+
175+ float minScore = taskSettings .minScore ();
176+ float resultDiff = taskSettings .resultDiff ();
166177 for (int i = 0 ; i < input .size (); i ++) {
167- results .add (
168- new RankedDocsResults .RankedDoc (
169- totalResults - 1 - i ,
170- minScore + resultDiff * (totalResults - i ),
171- input .get (totalResults - 1 - i )
172- )
173- );
178+ float relevanceScore = minScore + resultDiff * (totalResults - i );
179+ String inputText = input .get (totalResults - 1 - i );
180+ if (taskSettings .useTextLength ()) {
181+ relevanceScore = 1f / inputText .length ();
182+ }
183+ results .add (new RankedDocsResults .RankedDoc (totalResults - 1 - i , relevanceScore , inputText ));
174184 }
185+ // Ensure result are sorted by descending score
186+ results .sort ((a , b ) -> -Float .compare (a .relevanceScore (), b .relevanceScore ()));
175187 return new RankedDocsResults (results );
176188 }
177189 }
@@ -208,6 +220,77 @@ public static InferenceServiceConfiguration get() {
208220 }
209221 }
210222
223+ public record TestTaskSettings (boolean useTextLength , float minScore , float resultDiff ) implements TaskSettings {
224+
225+ static final String NAME = "test_reranking_task_settings" ;
226+
227+ public static TestTaskSettings fromMap (Map <String , Object > map ) {
228+ boolean useTextLength = false ;
229+ float minScore = random .nextFloat (-1f , 1f );
230+ float resultDiff = 0.2f ;
231+
232+ if (map .containsKey ("use_text_length" )) {
233+ useTextLength = Boolean .parseBoolean (map .remove ("use_text_length" ).toString ());
234+ }
235+
236+ if (map .containsKey ("min_score" )) {
237+ minScore = Float .parseFloat (map .remove ("min_score" ).toString ());
238+ }
239+
240+ if (map .containsKey ("result_diff" )) {
241+ resultDiff = Float .parseFloat (map .remove ("result_diff" ).toString ());
242+ }
243+
244+ return new TestTaskSettings (useTextLength , minScore , resultDiff );
245+ }
246+
247+ public TestTaskSettings (StreamInput in ) throws IOException {
248+ this (in .readBoolean (), in .readFloat (), in .readFloat ());
249+ }
250+
251+ @ Override
252+ public boolean isEmpty () {
253+ return false ;
254+ }
255+
256+ @ Override
257+ public void writeTo (StreamOutput out ) throws IOException {
258+ out .writeBoolean (useTextLength );
259+ out .writeFloat (minScore );
260+ out .writeFloat (resultDiff );
261+ }
262+
263+ @ Override
264+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
265+ builder .startObject ();
266+ builder .field ("use_text_length" , useTextLength );
267+ builder .field ("min_score" , minScore );
268+ builder .field ("result_diff" , resultDiff );
269+ builder .endObject ();
270+ return builder ;
271+ }
272+
273+ @ Override
274+ public String getWriteableName () {
275+ return NAME ;
276+ }
277+
278+ @ Override
279+ public TransportVersion getMinimalSupportedVersion () {
280+ return TransportVersion .current (); // fine for these tests but will not work for cluster upgrade tests
281+ }
282+
283+ @ Override
284+ public TaskSettings updatedTaskSettings (Map <String , Object > newSettingsMap ) {
285+ TestTaskSettings newSettingsObject = fromMap (Map .copyOf (newSettingsMap ));
286+ return new TestTaskSettings (
287+ newSettingsMap .containsKey ("use_text_length" ) ? newSettingsObject .useTextLength () : useTextLength ,
288+ newSettingsMap .containsKey ("min_score" ) ? newSettingsObject .minScore () : minScore ,
289+ newSettingsMap .containsKey ("result_diff" ) ? newSettingsObject .resultDiff () : resultDiff
290+ );
291+ }
292+ }
293+
211294 public record TestServiceSettings (String modelId ) implements ServiceSettings {
212295
213296 static final String NAME = "test_reranking_service_settings" ;
0 commit comments