Skip to content

Commit d5ebe19

Browse files
committed
Add rerank to the unit tests
# Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
1 parent 3004d7c commit d5ebe19

File tree

28 files changed

+269
-42
lines changed

28 files changed

+269
-42
lines changed

server/src/main/java/org/elasticsearch/inference/RerankingInferenceService.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@
1111

1212
public interface RerankingInferenceService {
1313

14+
/**
15+
* The default window size for small reranking models.
16+
*/
17+
int CONSERVATIVE_DEFAULT_WINDOW_SIZE = 250;
18+
int LARGE_WINDOW_SIZE = 500;
19+
1420
/**
1521
* The reranking model's max window, or an approximation of it,
1622
* measured in the number of words.
23+
* @param modelId The model ID
1724
* @return Window size in words
1825
*/
19-
int rerankerWindowSize();
26+
int rerankerWindowSize(String modelId);
2027
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.inference.Model;
2727
import org.elasticsearch.inference.ModelConfigurations;
2828
import org.elasticsearch.inference.ModelSecrets;
29+
import org.elasticsearch.inference.RerankingInferenceService;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.SimilarityMeasure;
3132
import org.elasticsearch.inference.TaskType;
@@ -72,7 +73,7 @@
7273
import static org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS;
7374
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
7475

75-
public class AzureAiStudioService extends SenderService {
76+
public class AzureAiStudioService extends SenderService implements RerankingInferenceService {
7677

7778
public static final String NAME = "azureaistudio";
7879

@@ -400,6 +401,11 @@ private static void checkProviderAndEndpointTypeForTask(
400401
}
401402
}
402403

404+
@Override
405+
public int rerankerWindowSize(String modelId) {
406+
return RerankingInferenceService.LARGE_WINDOW_SIZE;
407+
}
408+
403409
public static class Configuration {
404410
public static InferenceServiceConfiguration get() {
405411
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.inference.Model;
2727
import org.elasticsearch.inference.ModelConfigurations;
2828
import org.elasticsearch.inference.ModelSecrets;
29+
import org.elasticsearch.inference.RerankingInferenceService;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.SimilarityMeasure;
3132
import org.elasticsearch.inference.TaskType;
@@ -66,7 +67,7 @@
6667
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6768
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE;
6869

69-
public class CohereService extends SenderService {
70+
public class CohereService extends SenderService implements RerankingInferenceService {
7071
public static final String NAME = "cohere";
7172

7273
private static final String SERVICE_NAME = "Cohere";
@@ -361,6 +362,12 @@ public Set<TaskType> supportedStreamingTasks() {
361362
return COMPLETION_ONLY;
362363
}
363364

365+
@Override
366+
public int rerankerWindowSize(String modelId) {
367+
// Cohere rerank model truncates at 4096 tokens https://docs.cohere.com/reference/rerank
368+
return RerankingInferenceService.LARGE_WINDOW_SIZE;
369+
}
370+
364371
public static class Configuration {
365372
public static InferenceServiceConfiguration get() {
366373
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.inference.MinimalServiceSettings;
3232
import org.elasticsearch.inference.Model;
3333
import org.elasticsearch.inference.ModelConfigurations;
34+
import org.elasticsearch.inference.RerankingInferenceService;
3435
import org.elasticsearch.inference.SettingsConfiguration;
3536
import org.elasticsearch.inference.TaskType;
3637
import org.elasticsearch.inference.UnifiedCompletionRequest;
@@ -84,7 +85,7 @@
8485
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL;
8586
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86;
8687

87-
public class ElasticsearchInternalService extends BaseElasticsearchInternalService {
88+
public class ElasticsearchInternalService extends BaseElasticsearchInternalService implements RerankingInferenceService {
8889

8990
public static final String NAME = "elasticsearch";
9091
public static final String OLD_ELSER_SERVICE_NAME = "elser";
@@ -1060,6 +1061,12 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
10601061
}
10611062
}
10621063

1064+
@Override
1065+
public int rerankerWindowSize(String modelId) {
1066+
// TODO rerank chunking should use the same value
1067+
return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE;
1068+
}
1069+
10631070
/**
10641071
* Iterates over the batch executing a limited number of requests at a time to avoid
10651072
* filling the ML node inference queue.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.ModelConfigurations;
2727
import org.elasticsearch.inference.ModelSecrets;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.SettingsConfiguration;
2930
import org.elasticsearch.inference.TaskType;
3031
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
@@ -69,7 +70,7 @@
6970
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
7071
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;
7172

72-
public class GoogleVertexAiService extends SenderService {
73+
public class GoogleVertexAiService extends SenderService implements RerankingInferenceService {
7374

7475
public static final String NAME = "googlevertexai";
7576

@@ -383,6 +384,11 @@ private static GoogleVertexAiModel createModel(
383384
};
384385
}
385386

387+
@Override
388+
public int rerankerWindowSize(String modelId) {
389+
return 0; // TODO
390+
}
391+
386392
public static class Configuration {
387393
public static InferenceServiceConfiguration get() {
388394
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.ModelConfigurations;
2727
import org.elasticsearch.inference.ModelSecrets;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.SettingsConfiguration;
2930
import org.elasticsearch.inference.SimilarityMeasure;
3031
import org.elasticsearch.inference.TaskType;
@@ -61,7 +62,7 @@
6162
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
6263
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6364

64-
public class VoyageAIService extends SenderService {
65+
public class VoyageAIService extends SenderService implements RerankingInferenceService {
6566
public static final String NAME = "voyageai";
6667

6768
private static final String SERVICE_NAME = "Voyage AI";
@@ -369,6 +370,12 @@ public TransportVersion getMinimalSupportedVersion() {
369370
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
370371
}
371372

373+
@Override
374+
public int rerankerWindowSize(String modelId) {
375+
// https://docs.voyageai.com/reference/reranker-api
376+
return RerankingInferenceService.LARGE_WINDOW_SIZE;
377+
}
378+
372379
public static class Configuration {
373380
public static InferenceServiceConfiguration get() {
374381
return configuration.getOrCompute();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.inference.Model;
1818
import org.elasticsearch.inference.SimilarityMeasure;
1919
import org.elasticsearch.inference.TaskType;
20-
import org.elasticsearch.test.ESTestCase;
2120
import org.elasticsearch.test.http.MockWebServer;
2221
import org.elasticsearch.threadpool.ThreadPool;
2322
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -52,7 +51,7 @@
5251
* To use this class, extend it and pass the constructor a configuration.
5352
* </p>
5453
*/
55-
public abstract class AbstractInferenceServiceTests extends ESTestCase {
54+
public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase {
5655

5756
protected final MockWebServer webServer = new MockWebServer();
5857
protected ThreadPool threadPool;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/InferenceServiceTestCase.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,47 @@
88
package org.elasticsearch.xpack.inference.services;
99

1010
import org.elasticsearch.inference.InferenceService;
11+
import org.elasticsearch.inference.RerankingInferenceService;
12+
import org.elasticsearch.inference.TaskType;
1113
import org.elasticsearch.test.ESTestCase;
1214

15+
import java.io.IOException;
16+
1317
public abstract class InferenceServiceTestCase extends ESTestCase {
1418

1519
public abstract InferenceService createInferenceService();
20+
21+
public void testRerankersImplementRerankInterface() throws IOException {
22+
try (InferenceService inferenceService = createInferenceService()) {
23+
boolean implementsReranking = inferenceService instanceof RerankingInferenceService;
24+
boolean hasRerankTaskType = inferenceService.supportedTaskTypes().contains(TaskType.RERANK);
25+
if (implementsReranking != hasRerankTaskType) {
26+
fail(
27+
"Reranking inference services should implement RerankingInferenceService and support the RERANK task type. "
28+
+ "Service ["
29+
+ inferenceService.name()
30+
+ "] supports task type: ["
31+
+ hasRerankTaskType
32+
+ "] and implements"
33+
+ " RerankingInferenceService: ["
34+
+ implementsReranking
35+
+ "]"
36+
);
37+
}
38+
}
39+
}
40+
41+
public void testRerankersHaveWindowSize() throws IOException {
42+
try (InferenceService inferenceService = createInferenceService()) {
43+
if (inferenceService instanceof RerankingInferenceService rerankingInferenceService) {
44+
assertRerankerWindowSize(rerankingInferenceService);
45+
}
46+
}
47+
}
48+
49+
protected void assertRerankerWindowSize(
50+
RerankingInferenceService rerankingInferenceService
51+
) {
52+
fail("Reranking services should override this test method to verify window size");
53+
}
1654
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
import org.elasticsearch.inference.ChunkInferenceInput;
2020
import org.elasticsearch.inference.ChunkedInference;
2121
import org.elasticsearch.inference.ChunkingSettings;
22+
import org.elasticsearch.inference.InferenceService;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
2324
import org.elasticsearch.inference.InferenceServiceResults;
2425
import org.elasticsearch.inference.InputType;
2526
import org.elasticsearch.inference.Model;
2627
import org.elasticsearch.inference.ModelConfigurations;
2728
import org.elasticsearch.inference.SimilarityMeasure;
2829
import org.elasticsearch.inference.TaskType;
29-
import org.elasticsearch.test.ESTestCase;
3030
import org.elasticsearch.threadpool.ThreadPool;
3131
import org.elasticsearch.xcontent.ToXContent;
3232
import org.elasticsearch.xcontent.XContentType;
@@ -42,6 +42,7 @@
4242
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
4343
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
4444
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
45+
import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
4546
import org.elasticsearch.xpack.inference.services.ServiceFields;
4647
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.action.AlibabaCloudSearchActionVisitor;
4748
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
@@ -73,7 +74,7 @@
7374
import static org.hamcrest.Matchers.instanceOf;
7475
import static org.mockito.Mockito.mock;
7576

76-
public class AlibabaCloudSearchServiceTests extends ESTestCase {
77+
public class AlibabaCloudSearchServiceTests extends InferenceServiceTestCase {
7778
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
7879
private ThreadPool threadPool;
7980
private HttpClientManager clientManager;
@@ -710,4 +711,9 @@ private Map<String, Object> getRequestConfigMap(
710711
Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
711712
);
712713
}
714+
715+
@Override
716+
public InferenceService createInferenceService() {
717+
return new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
718+
}
713719
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.inference.ChunkInferenceInput;
2424
import org.elasticsearch.inference.ChunkedInference;
2525
import org.elasticsearch.inference.ChunkingSettings;
26+
import org.elasticsearch.inference.InferenceService;
2627
import org.elasticsearch.inference.InferenceServiceConfiguration;
2728
import org.elasticsearch.inference.InferenceServiceResults;
2829
import org.elasticsearch.inference.InputType;
@@ -31,7 +32,6 @@
3132
import org.elasticsearch.inference.ModelSecrets;
3233
import org.elasticsearch.inference.SimilarityMeasure;
3334
import org.elasticsearch.inference.TaskType;
34-
import org.elasticsearch.test.ESTestCase;
3535
import org.elasticsearch.threadpool.ThreadPool;
3636
import org.elasticsearch.xcontent.ToXContent;
3737
import org.elasticsearch.xcontent.XContentType;
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
4444
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
4545
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
46+
import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
4647
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
4748
import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockMockRequestSender;
4849
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
@@ -92,7 +93,7 @@
9293
import static org.mockito.Mockito.verifyNoMoreInteractions;
9394
import static org.mockito.Mockito.when;
9495

95-
public class AmazonBedrockServiceTests extends ESTestCase {
96+
public class AmazonBedrockServiceTests extends InferenceServiceTestCase {
9697
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
9798
private ThreadPool threadPool;
9899

@@ -1405,6 +1406,11 @@ private AmazonBedrockService createAmazonBedrockService() {
14051406
);
14061407
}
14071408

1409+
@Override
1410+
public InferenceService createInferenceService() {
1411+
return createAmazonBedrockService();
1412+
}
1413+
14081414
private Map<String, Object> getRequestConfigMap(
14091415
Map<String, Object> serviceSettings,
14101416
Map<String, Object> taskSettings,

0 commit comments

Comments
 (0)