Skip to content

Commit cd1ad0b

Browse files
Deterministic rerank test code (elastic#128527)
1 parent ddd4225 commit cd1ad0b

File tree

4 files changed

+111
-21
lines changed

4 files changed

+111
-21
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockRerankInferenceServiceIT.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
1717

18-
@SuppressWarnings("unchecked")
1918
public void testMockService() throws IOException {
2019
String inferenceEntityId = "test-mock";
2120
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
@@ -30,8 +29,7 @@ public void testMockService() throws IOException {
3029
List<String> input = List.of(randomAlphaOfLength(10));
3130
var inference = infer(inferenceEntityId, input);
3231
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
33-
// TODO: investigate score calculation inconsistency affecting this assertion. Uncomment when fixed
34-
// assertEquals(inference, infer(inferenceEntityId, input));
32+
assertEquals(inference, infer(inferenceEntityId, input));
3533
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
3634
}
3735

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public TestServiceModel parsePersistedConfigWithSecrets(
8686
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
8787

8888
var taskSettingsMap = getTaskSettingsMap(config);
89-
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
89+
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);
9090

9191
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
9292
}
@@ -99,11 +99,15 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
9999
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);
100100

101101
var taskSettingsMap = getTaskSettingsMap(config);
102-
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
102+
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);
103103

104104
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
105105
}
106106

107+
protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
108+
return TestTaskSettings.fromMap(taskSettingsMap);
109+
}
110+
107111
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
108112

109113
@Override
@@ -149,15 +153,15 @@ public TestServiceModel(
149153
TaskType taskType,
150154
String service,
151155
ServiceSettings serviceSettings,
152-
TestTaskSettings taskSettings,
156+
TaskSettings taskSettings,
153157
TestSecretSettings secretSettings
154158
) {
155159
super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
156160
}
157161

158162
@Override
159-
public TestTaskSettings getTaskSettings() {
160-
return (TestTaskSettings) super.getTaskSettings();
163+
public TaskSettings getTaskSettings() {
164+
return super.getTaskSettings();
161165
}
162166

163167
@Override

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
4545
TestRerankingServiceExtension.TestServiceSettings.NAME,
4646
TestRerankingServiceExtension.TestServiceSettings::new
4747
),
48+
new NamedWriteableRegistry.Entry(
49+
TaskSettings.class,
50+
TestRerankingServiceExtension.TestTaskSettings.NAME,
51+
TestRerankingServiceExtension.TestTaskSettings::new
52+
),
4853
new NamedWriteableRegistry.Entry(
4954
ServiceSettings.class,
5055
TestStreamingCompletionServiceExtension.TestServiceSettings.NAME,

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 96 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.inference.ModelSecrets;
2828
import org.elasticsearch.inference.ServiceSettings;
2929
import org.elasticsearch.inference.SettingsConfiguration;
30+
import org.elasticsearch.inference.TaskSettings;
3031
import org.elasticsearch.inference.TaskType;
3132
import org.elasticsearch.inference.UnifiedCompletionRequest;
3233
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
@@ -43,6 +44,8 @@
4344
import java.util.List;
4445
import java.util.Map;
4546

47+
import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random;
48+
4649
public class TestRerankingServiceExtension implements InferenceServiceExtension {
4750

4851
@Override
@@ -84,11 +87,16 @@ public void parseRequestConfig(
8487
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
8588

8689
var taskSettingsMap = getTaskSettingsMap(config);
87-
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
90+
var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
8891

8992
parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
9093
}
9194

95+
@Override
96+
protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
97+
return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
98+
}
99+
92100
@Override
93101
public InferenceServiceConfiguration getConfiguration() {
94102
return Configuration.get();
@@ -107,13 +115,15 @@ public void infer(
107115
@Nullable Integer topN,
108116
List<String> input,
109117
boolean stream,
110-
Map<String, Object> taskSettings,
118+
Map<String, Object> taskSettingsMap,
111119
InputType inputType,
112120
TimeValue timeout,
113121
ActionListener<InferenceServiceResults> listener
114122
) {
123+
TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap);
124+
115125
switch (model.getConfigurations().getTaskType()) {
116-
case ANY, RERANK -> listener.onResponse(makeResults(input));
126+
case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings));
117127
default -> listener.onFailure(
118128
new ElasticsearchStatusException(
119129
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -151,7 +161,7 @@ public void chunkedInfer(
151161
);
152162
}
153163

154-
private RankedDocsResults makeResults(List<String> input) {
164+
private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
155165
int totalResults = input.size();
156166
try {
157167
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
@@ -161,17 +171,19 @@ private RankedDocsResults makeResults(List<String> input) {
161171
return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
162172
} catch (NumberFormatException ex) {
163173
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
164-
float minScore = random.nextFloat(-1f, 1f);
165-
float resultDiff = 0.2f;
174+
175+
float minScore = taskSettings.minScore();
176+
float resultDiff = taskSettings.resultDiff();
166177
for (int i = 0; i < input.size(); i++) {
167-
results.add(
168-
new RankedDocsResults.RankedDoc(
169-
totalResults - 1 - i,
170-
minScore + resultDiff * (totalResults - i),
171-
input.get(totalResults - 1 - i)
172-
)
173-
);
178+
float relevanceScore = minScore + resultDiff * (totalResults - i);
179+
String inputText = input.get(totalResults - 1 - i);
180+
if (taskSettings.useTextLength()) {
181+
relevanceScore = 1f / inputText.length();
182+
}
183+
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
174184
}
185+
// Ensure result are sorted by descending score
186+
results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
175187
return new RankedDocsResults(results);
176188
}
177189
}
@@ -208,6 +220,77 @@ public static InferenceServiceConfiguration get() {
208220
}
209221
}
210222

223+
public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings {
224+
225+
static final String NAME = "test_reranking_task_settings";
226+
227+
public static TestTaskSettings fromMap(Map<String, Object> map) {
228+
boolean useTextLength = false;
229+
float minScore = random.nextFloat(-1f, 1f);
230+
float resultDiff = 0.2f;
231+
232+
if (map.containsKey("use_text_length")) {
233+
useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString());
234+
}
235+
236+
if (map.containsKey("min_score")) {
237+
minScore = Float.parseFloat(map.remove("min_score").toString());
238+
}
239+
240+
if (map.containsKey("result_diff")) {
241+
resultDiff = Float.parseFloat(map.remove("result_diff").toString());
242+
}
243+
244+
return new TestTaskSettings(useTextLength, minScore, resultDiff);
245+
}
246+
247+
public TestTaskSettings(StreamInput in) throws IOException {
248+
this(in.readBoolean(), in.readFloat(), in.readFloat());
249+
}
250+
251+
@Override
252+
public boolean isEmpty() {
253+
return false;
254+
}
255+
256+
@Override
257+
public void writeTo(StreamOutput out) throws IOException {
258+
out.writeBoolean(useTextLength);
259+
out.writeFloat(minScore);
260+
out.writeFloat(resultDiff);
261+
}
262+
263+
@Override
264+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
265+
builder.startObject();
266+
builder.field("use_text_length", useTextLength);
267+
builder.field("min_score", minScore);
268+
builder.field("result_diff", resultDiff);
269+
builder.endObject();
270+
return builder;
271+
}
272+
273+
@Override
274+
public String getWriteableName() {
275+
return NAME;
276+
}
277+
278+
@Override
279+
public TransportVersion getMinimalSupportedVersion() {
280+
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
281+
}
282+
283+
@Override
284+
public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) {
285+
TestTaskSettings newSettingsObject = fromMap(Map.copyOf(newSettingsMap));
286+
return new TestTaskSettings(
287+
newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength,
288+
newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore,
289+
newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff
290+
);
291+
}
292+
}
293+
211294
public record TestServiceSettings(String modelId) implements ServiceSettings {
212295

213296
static final String NAME = "test_reranking_service_settings";

0 commit comments

Comments
 (0)