diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java index 80477af94ec12..7aabb4f58d04c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java @@ -15,7 +15,6 @@ public class MockRerankInferenceServiceIT extends InferenceBaseRestTest { - @SuppressWarnings("unchecked") public void testMockService() throws IOException { String inferenceEntityId = "test-mock"; var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK); @@ -30,8 +29,7 @@ public void testMockService() throws IOException { List input = List.of(randomAlphaOfLength(10)); var inference = infer(inferenceEntityId, input); assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK); - // TODO: investigate score calculation inconsistency affecting this assertion. Uncomment when fixed - // assertEquals(inference, infer(inferenceEntityId, input)); + assertEquals(inference, infer(inferenceEntityId, input)); assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 7d4a120668a8b..34e2af8034527 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -86,7 +86,7 @@ public TestServiceModel parsePersistedConfigWithSecrets( var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + var taskSettings = getTasksSettingsFromMap(taskSettingsMap); return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); } @@ -99,11 +99,15 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map taskSettingsMap) { + return TestTaskSettings.fromMap(taskSettingsMap); + } + protected abstract ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap); @Override @@ -149,15 +153,15 @@ public TestServiceModel( TaskType taskType, String service, ServiceSettings serviceSettings, - TestTaskSettings taskSettings, + TaskSettings taskSettings, TestSecretSettings secretSettings ) { super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); } @Override - public TestTaskSettings getTaskSettings() { - return (TestTaskSettings) super.getTaskSettings(); + public TaskSettings getTaskSettings() { + return super.getTaskSettings(); } @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java index eef0da909f529..1d04aab022f91 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java @@ -45,6 +45,11 @@ public List getNamedWriteables() { TestRerankingServiceExtension.TestServiceSettings.NAME, TestRerankingServiceExtension.TestServiceSettings::new ), + new NamedWriteableRegistry.Entry( + TaskSettings.class, + TestRerankingServiceExtension.TestTaskSettings.NAME, + TestRerankingServiceExtension.TestTaskSettings::new + ), new NamedWriteableRegistry.Entry( ServiceSettings.class, TestStreamingCompletionServiceExtension.TestServiceSettings.NAME, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 989726443ecf4..7575cf5197dff 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; @@ -43,6 +44,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random; + public class TestRerankingServiceExtension implements InferenceServiceExtension { @Override @@ -84,11 +87,16 @@ public void parseRequestConfig( var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap); parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); } + @Override + protected TaskSettings getTasksSettingsFromMap(Map taskSettingsMap) { + return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap); + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -107,13 +115,15 @@ public void infer( @Nullable Integer topN, List input, boolean stream, - Map taskSettings, + Map taskSettingsMap, InputType inputType, TimeValue timeout, ActionListener listener ) { + TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap); + switch (model.getConfigurations().getTaskType()) { - case ANY, RERANK -> listener.onResponse(makeResults(input)); + case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -151,7 +161,7 @@ public void chunkedInfer( ); } - private RankedDocsResults makeResults(List input) { + private RankedDocsResults makeResults(List input, TestRerankingServiceExtension.TestTaskSettings taskSettings) { int totalResults = input.size(); try { List results = new ArrayList<>(); @@ -161,17 +171,19 @@ private RankedDocsResults makeResults(List input) { return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList()); } catch (NumberFormatException ex) { List results = new ArrayList<>(); - float minScore = random.nextFloat(-1f, 1f); - float resultDiff = 0.2f; + + float minScore = taskSettings.minScore(); + float resultDiff = taskSettings.resultDiff(); for (int i = 0; i < input.size(); i++) { - results.add( - new RankedDocsResults.RankedDoc( - totalResults - 1 - i, - minScore + resultDiff * (totalResults - i), - input.get(totalResults - 1 - i) - ) - ); + float relevanceScore = minScore + resultDiff * (totalResults - i); + String inputText = input.get(totalResults - 1 - i); + if (taskSettings.useTextLength()) { + relevanceScore = 1f / inputText.length(); + } + results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText)); } + // Ensure result are sorted by descending score + results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore())); return new RankedDocsResults(results); } } @@ -208,6 +220,77 @@ public static InferenceServiceConfiguration get() { } } + public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings { + + static final String NAME = "test_reranking_task_settings"; + + public static TestTaskSettings fromMap(Map map) { + boolean useTextLength = false; + float minScore = random.nextFloat(-1f, 1f); + float resultDiff = 0.2f; + + if (map.containsKey("use_text_length")) { + useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString()); + } + + if (map.containsKey("min_score")) { + minScore = Float.parseFloat(map.remove("min_score").toString()); + } + + if (map.containsKey("result_diff")) { + resultDiff = Float.parseFloat(map.remove("result_diff").toString()); + } + + return new TestTaskSettings(useTextLength, minScore, resultDiff); + } + + public TestTaskSettings(StreamInput in) throws IOException { + this(in.readBoolean(), in.readFloat(), in.readFloat()); + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(useTextLength); + out.writeFloat(minScore); + out.writeFloat(resultDiff); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("use_text_length", useTextLength); + builder.field("min_score", minScore); + builder.field("result_diff", resultDiff); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettingsMap) { + TestTaskSettings newSettingsObject = fromMap(Map.copyOf(newSettingsMap)); + return new TestTaskSettings( + newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength, + newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore, + newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff + ); + } + } + public record TestServiceSettings(String modelId) implements ServiceSettings { static final String NAME = "test_reranking_service_settings";