Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {

@SuppressWarnings("unchecked")
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
Expand All @@ -30,8 +29,7 @@ public void testMockService() throws IOException {
List<String> input = List.of(randomAlphaOfLength(10));
var inference = infer(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
// TODO: investigate score calculation inconsistency affecting this assertion. Uncomment when fixed
// assertEquals(inference, infer(inferenceEntityId, input));
assertEquals(inference, infer(inferenceEntityId, input));
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public TestServiceModel parsePersistedConfigWithSecrets(
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}
Expand All @@ -99,11 +99,15 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
}

protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
return TestTaskSettings.fromMap(taskSettingsMap);
}

protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);

@Override
Expand Down Expand Up @@ -149,15 +153,15 @@ public TestServiceModel(
TaskType taskType,
String service,
ServiceSettings serviceSettings,
TestTaskSettings taskSettings,
TaskSettings taskSettings,
TestSecretSettings secretSettings
) {
super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
}

@Override
public TestTaskSettings getTaskSettings() {
return (TestTaskSettings) super.getTaskSettings();
public TaskSettings getTaskSettings() {
return super.getTaskSettings();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
TestRerankingServiceExtension.TestServiceSettings.NAME,
TestRerankingServiceExtension.TestServiceSettings::new
),
new NamedWriteableRegistry.Entry(
TaskSettings.class,
TestRerankingServiceExtension.TestTaskSettings.NAME,
TestRerankingServiceExtension.TestTaskSettings::new
),
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
TestStreamingCompletionServiceExtension.TestServiceSettings.NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
Expand All @@ -43,6 +44,8 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random;

public class TestRerankingServiceExtension implements InferenceServiceExtension {

@Override
Expand Down Expand Up @@ -84,11 +87,16 @@ public void parseRequestConfig(
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);

parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
}

@Override
protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
}

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
Expand All @@ -107,13 +115,15 @@ public void infer(
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
Map<String, Object> taskSettingsMap,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap);

switch (model.getConfigurations().getTaskType()) {
case ANY, RERANK -> listener.onResponse(makeResults(input));
case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings));
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
Expand Down Expand Up @@ -151,7 +161,7 @@ public void chunkedInfer(
);
}

private RankedDocsResults makeResults(List<String> input) {
private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
int totalResults = input.size();
try {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
Expand All @@ -161,17 +171,19 @@ private RankedDocsResults makeResults(List<String> input) {
return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
} catch (NumberFormatException ex) {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;

float minScore = taskSettings.minScore();
float resultDiff = taskSettings.resultDiff();
for (int i = 0; i < input.size(); i++) {
results.add(
new RankedDocsResults.RankedDoc(
totalResults - 1 - i,
minScore + resultDiff * (totalResults - i),
input.get(totalResults - 1 - i)
)
);
float relevanceScore = minScore + resultDiff * (totalResults - i);
String inputText = input.get(totalResults - 1 - i);
if (taskSettings.useTextLength()) {
relevanceScore = 1f / inputText.length();
}
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
}
// Ensure result are sorted by descending score
results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
return new RankedDocsResults(results);
}
}
Expand Down Expand Up @@ -208,6 +220,77 @@ public static InferenceServiceConfiguration get() {
}
}

public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings {

static final String NAME = "test_reranking_task_settings";

public static TestTaskSettings fromMap(Map<String, Object> map) {
boolean useTextLength = false;
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;

if (map.containsKey("use_text_length")) {
useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString());
}

if (map.containsKey("min_score")) {
minScore = Float.parseFloat(map.remove("min_score").toString());
}

if (map.containsKey("result_diff")) {
resultDiff = Float.parseFloat(map.remove("result_diff").toString());
}

return new TestTaskSettings(useTextLength, minScore, resultDiff);
}

public TestTaskSettings(StreamInput in) throws IOException {
this(in.readBoolean(), in.readFloat(), in.readFloat());
}

@Override
public boolean isEmpty() {
return false;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(useTextLength);
out.writeFloat(minScore);
out.writeFloat(resultDiff);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("use_text_length", useTextLength);
builder.field("min_score", minScore);
builder.field("result_diff", resultDiff);
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) {
TestTaskSettings newSettingsObject = fromMap(Map.copyOf(newSettingsMap));
return new TestTaskSettings(
newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength,
newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore,
newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff
);
}
}

public record TestServiceSettings(String modelId) implements ServiceSettings {

static final String NAME = "test_reranking_service_settings";
Expand Down