Skip to content

Commit 086a85e

Browse files
authored
Merge branch 'main' into esql_less_data_for_lookup
2 parents b7bb7e6 + 1a641e5 commit 086a85e

File tree

34 files changed

+1576
-92
lines changed

34 files changed

+1576
-92
lines changed

docs/changelog/127966.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127966
2+
summary: "[ML] Add Rerank support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ static TransportVersion def(int id) {
179179
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
180180
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
181181
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
182+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
182183
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
183184
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
184185
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -261,7 +262,7 @@ static TransportVersion def(int id) {
261262
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
262263
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
263264
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
264-
265+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
265266
/*
266267
* STOP! READ THIS FIRST! No, really,
267268
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

server/src/main/java/org/elasticsearch/index/engine/ThreadPoolMergeScheduler.java

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ public class ThreadPoolMergeScheduler extends MergeScheduler implements Elastics
6767
private volatile boolean closed = false;
6868
private final MergeMemoryEstimateProvider mergeMemoryEstimateProvider;
6969

70+
/**
71+
* Creates a thread-pool-based merge scheduler that runs merges in a thread pool.
72+
*
73+
* @param shardId the shard id associated with this merge scheduler
74+
* @param indexSettings used to obtain the {@link MergeSchedulerConfig}
75+
* @param threadPoolMergeExecutorService the executor service used to execute merge tasks from this scheduler
76+
* @param mergeMemoryEstimateProvider provides an estimate for how much memory a merge will take
77+
*/
7078
public ThreadPoolMergeScheduler(
7179
ShardId shardId,
7280
IndexSettings indexSettings,
@@ -146,6 +154,16 @@ protected void beforeMerge(OnGoingMerge merge) {}
146154
*/
147155
protected void afterMerge(OnGoingMerge merge) {}
148156

157+
/**
158+
* A callback allowing for custom logic when a merge is queued.
* <p>
* The default implementation does nothing. {@code submitNewMergeTask} invokes this hook with the
* {@link OnGoingMerge} of the newly created merge task, immediately before the task is submitted
* to the merge executor service.
*
* @param merge the on-going merge belonging to the task that is about to be submitted
159+
*/
160+
protected void mergeQueued(OnGoingMerge merge) {}
161+
162+
/**
163+
* A callback allowing for custom logic after a merge is executed or aborted.
* <p>
* The default implementation does nothing. {@code mergeTaskDone} invokes this hook after it has
* incremented the done-task counter and before it re-checks merge task throttling; that method is
* reached both when a merge task finishes running and when a merge task is aborted.
*
* @param merge the on-going merge belonging to the task that has just been executed or aborted
164+
*/
165+
protected void mergeExecutedOrAborted(OnGoingMerge merge) {}
166+
149167
/**
150168
* A callback that's invoked when indexing should be throttled down in order to let merging catch up.
151169
*/
@@ -157,6 +175,34 @@ protected void enableIndexingThrottling(int numRunningMerges, int numQueuedMerge
157175
*/
158176
protected void disableIndexingThrottling(int numRunningMerges, int numQueuedMerges, int configuredMaxMergeCount) {}
159177

178+
/**
179+
* Returns true if scheduled merges should be skipped (aborted)
* <p>
* The default implementation never skips. {@code schedule} consults this hook once it has
* established that the scheduler is not closed; when the hook returns {@code true} the task is
* answered with {@code Schedule.ABORT} before the check for a free merge thread is made.
*
* @return {@code true} to have merge tasks aborted at scheduling time, {@code false} otherwise
180+
*/
181+
protected boolean shouldSkipMerge() {
182+
return false;
183+
}
184+
185+
/**
186+
* Returns true if IO-throttling is enabled
* <p>
* The default implementation delegates to the scheduler's {@code config}. {@code newMergeTask}
* combines the result (logical AND) with the {@code isAutoThrottle} value it already has in scope
* when building a merge task, so an override can switch the behaviour without touching the config.
*
* @return the value reported by {@code config.isAutoThrottle()}
187+
*/
188+
protected boolean isAutoThrottle() {
189+
return config.isAutoThrottle();
190+
}
191+
192+
/**
193+
* Returns the maximum number of active merges before being throttled
* <p>
* The default implementation delegates to the scheduler's {@code config}.
* {@code checkMergeTaskThrottling} compares this limit against the number of active merges, which
* it computes as submitted minus done merge tasks (i.e. both running and enqueued tasks).
*
* @return the value reported by {@code config.getMaxMergeCount()}
194+
*/
195+
protected int getMaxMergeCount() {
196+
return config.getMaxMergeCount();
197+
}
198+
199+
/**
200+
* Returns the maximum number of threads running merges before being throttled
* <p>
* The default implementation delegates to the scheduler's {@code config}. {@code schedule} lets a
* task run only while fewer than this many merge tasks are running, and
* {@code enqueueBackloggedTasks} uses the difference between this limit and the number of running
* tasks as the count of backlogged tasks to re-enqueue.
*
* @return the value reported by {@code config.getMaxThreadCount()}
201+
*/
202+
protected int getMaxThreadCount() {
203+
return config.getMaxThreadCount();
204+
}
205+
160206
/**
161207
* A callback for exceptions thrown while merging.
162208
*/
@@ -168,6 +214,7 @@ protected void handleMergeException(Throwable t) {
168214
boolean submitNewMergeTask(MergeSource mergeSource, MergePolicy.OneMerge merge, MergeTrigger mergeTrigger) {
169215
try {
170216
MergeTask mergeTask = newMergeTask(mergeSource, merge, mergeTrigger);
217+
mergeQueued(mergeTask.onGoingMerge);
171218
return threadPoolMergeExecutorService.submitMergeTask(mergeTask);
172219
} finally {
173220
checkMergeTaskThrottling();
@@ -183,7 +230,7 @@ MergeTask newMergeTask(MergeSource mergeSource, MergePolicy.OneMerge merge, Merg
183230
return new MergeTask(
184231
mergeSource,
185232
merge,
186-
isAutoThrottle && config.isAutoThrottle(),
233+
isAutoThrottle && isAutoThrottle(),
187234
"Lucene Merge Task #" + submittedMergeTaskCount.incrementAndGet() + " for shard " + shardId,
188235
estimateMergeMemoryBytes
189236
);
@@ -193,7 +240,7 @@ private void checkMergeTaskThrottling() {
193240
long submittedMergesCount = submittedMergeTaskCount.get();
194241
long doneMergesCount = doneMergeTaskCount.get();
195242
int runningMergesCount = runningMergeTasks.size();
196-
int configuredMaxMergeCount = config.getMaxMergeCount();
243+
int configuredMaxMergeCount = getMaxMergeCount();
197244
// both currently running and enqueued merge tasks are considered "active" for throttling purposes
198245
int activeMerges = (int) (submittedMergesCount - doneMergesCount);
199246
if (activeMerges > configuredMaxMergeCount
@@ -223,7 +270,12 @@ synchronized Schedule schedule(MergeTask mergeTask) {
223270
if (closed) {
224271
// do not run or backlog tasks when closing the merge scheduler, instead abort them
225272
return Schedule.ABORT;
226-
} else if (runningMergeTasks.size() < config.getMaxThreadCount()) {
273+
} else if (shouldSkipMerge()) {
274+
if (verbose()) {
275+
message(String.format(Locale.ROOT, "skipping merge task %s", mergeTask));
276+
}
277+
return Schedule.ABORT;
278+
} else if (runningMergeTasks.size() < getMaxThreadCount()) {
227279
boolean added = runningMergeTasks.put(mergeTask.onGoingMerge.getMerge(), mergeTask) == null;
228280
assert added : "starting merge task [" + mergeTask + "] registered as already running";
229281
return Schedule.RUN;
@@ -243,8 +295,9 @@ synchronized void mergeTaskFinishedRunning(MergeTask mergeTask) {
243295
maybeSignalAllMergesDoneAfterClose();
244296
}
245297

246-
private void mergeTaskDone() {
298+
private void mergeTaskDone(OnGoingMerge merge) {
247299
doneMergeTaskCount.incrementAndGet();
300+
mergeExecutedOrAborted(merge);
248301
checkMergeTaskThrottling();
249302
}
250303

@@ -255,7 +308,7 @@ private synchronized void maybeSignalAllMergesDoneAfterClose() {
255308
}
256309

257310
private synchronized void enqueueBackloggedTasks() {
258-
int maxBackloggedTasksToEnqueue = config.getMaxThreadCount() - runningMergeTasks.size();
311+
int maxBackloggedTasksToEnqueue = getMaxThreadCount() - runningMergeTasks.size();
259312
// enqueue all backlogged tasks when closing, as the queue expects all backlogged tasks to always be enqueued back
260313
while (closed || maxBackloggedTasksToEnqueue-- > 0) {
261314
MergeTask backloggedMergeTask = backloggedMergeTasks.poll();
@@ -408,7 +461,7 @@ public void run() {
408461
try {
409462
mergeTaskFinishedRunning(this);
410463
} finally {
411-
mergeTaskDone();
464+
mergeTaskDone(onGoingMerge);
412465
}
413466
try {
414467
// kick-off any follow-up merge
@@ -452,7 +505,7 @@ void abort() {
452505
if (verbose()) {
453506
message(String.format(Locale.ROOT, "merge task %s end abort", this));
454507
}
455-
mergeTaskDone();
508+
mergeTaskDone(onGoingMerge);
456509
}
457510
}
458511

server/src/test/java/org/elasticsearch/index/engine/ThreadPoolMergeSchedulerTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,31 @@ public void testAutoIOThrottleForMergeTasks() throws Exception {
662662
}
663663
}
664664

665+
/**
 * Verifies that {@code schedule} answers {@code Schedule.ABORT} for a merge task when the scheduler's
 * {@code shouldSkipMerge()} hook returns {@code true}. The executor service, merge source and merge
 * are all mocks; the test only inspects the value returned by {@code schedule}.
 */
public void testMergeSchedulerAbortsMergeWhenShouldSkipMergeIsTrue() {
666+
ThreadPoolMergeExecutorService threadPoolMergeExecutorService = mock(ThreadPoolMergeExecutorService.class);
667+
// build a scheduler that always returns true for shouldSkipMerge
668+
ThreadPoolMergeScheduler threadPoolMergeScheduler = new ThreadPoolMergeScheduler(
669+
new ShardId("index", "_na_", 1),
670+
IndexSettingsModule.newIndexSettings("index", Settings.builder().build()),
671+
threadPoolMergeExecutorService,
672+
// merge memory estimate provider: report 0 for every merge
merge -> 0
673+
) {
674+
@Override
675+
protected boolean shouldSkipMerge() {
676+
return true;
677+
}
678+
};
679+
MergeSource mergeSource = mock(MergeSource.class);
680+
OneMerge oneMerge = mock(OneMerge.class);
681+
// stub the merge accessors with real values; presumably these are what building a merge task reads
when(oneMerge.getStoreMergeInfo()).thenReturn(getNewMergeInfo(randomLongBetween(1L, 10L)));
682+
when(oneMerge.getMergeProgress()).thenReturn(new MergePolicy.OneMergeProgress());
683+
// the merge source hands out the merge once and then reports that no further merge is available
when(mergeSource.getNextMerge()).thenReturn(oneMerge, (OneMerge) null);
684+
MergeTask mergeTask = threadPoolMergeScheduler.newMergeTask(mergeSource, oneMerge, randomFrom(MergeTrigger.values()));
685+
// verify that calling schedule on the merge task indicates the merge should be aborted
686+
Schedule schedule = threadPoolMergeScheduler.schedule(mergeTask);
687+
assertThat(schedule, is(Schedule.ABORT));
688+
}
689+
665690
/**
 * Convenience overload: builds a {@code MergeInfo} for the given estimated merge size by delegating
 * to the two-argument {@code getNewMergeInfo}, passing either {@code -1} or a random non-negative
 * int (chosen at random between the two) as the second argument.
 */
private static MergeInfo getNewMergeInfo(long estimatedMergeBytes) {
666691
return getNewMergeInfo(estimatedMergeBytes, randomFrom(-1, randomNonNegativeInt()));
667692
}

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
public class SettingsConfigurationTestUtils {
2121

2222
public static SettingsConfiguration getRandomSettingsConfigurationField() {
23-
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24-
randomAlphaOfLength(10)
25-
)
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
24+
.setDefaultValue(randomAlphaOfLength(10))
2625
.setDescription(randomAlphaOfLength(10))
2726
.setLabel(randomAlphaOfLength(10))
2827
.setRequired(randomBoolean())

test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/restspec/ClientYamlSuiteRestApiParser.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ public class ClientYamlSuiteRestApiParser {
3434
public ClientYamlSuiteRestApi parse(String location, XContentParser parser) throws IOException {
3535

3636
while (parser.nextToken() != XContentParser.Token.FIELD_NAME) {
37+
if (parser.currentToken() == null) {
38+
throw new ParsingException(
39+
parser.getTokenLocation(),
40+
"Invalid rest spec file found at [" + location + "]. No API name found in file"
41+
);
42+
}
3743
// move to first field name
3844
}
3945

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,13 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
8282
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
8383
return RankedDocsResults.createParser(true).apply(parser, null);
8484
}
85+
86+
/** Holds the expected field values of a single ranked document; consumed by {@code buildExpectationRerank}. */
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
87+
88+
/**
 * Builds the expected map form of a rerank result.
 * <p>
 * The returned map has a single entry keyed by {@code RankedDocsResults.RERANK}. Its value is a list
 * with one element per expectation, in input order; each element is a single-entry map from
 * {@code RankedDocsResults.RankedDoc.NAME} to that expectation's {@code rankedDocFields}.
 * Both {@code Map.of} and {@code Stream.toList} produce unmodifiable collections.
 *
 * @param rerank the expectations, one per ranked document
 * @return the nested map described above
 */
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
89+
return Map.of(
90+
RankedDocsResults.RERANK,
91+
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
92+
);
93+
}
8594
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
/**
 * Returns the JSON request body used to create an inference endpoint backed by the
 * {@code test_reranking_service} service: service settings carry a {@code model_id} and an
 * {@code api_key}, and the task settings object is empty.
 */
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"service": "test_reranking_service",
178+
"service_settings": {
179+
"model_id": "my_model",
180+
"api_key": "abc64"
181+
},
182+
"task_settings": {
183+
}
184+
}
185+
""";
186+
}
187+
174188
static void deleteModel(String modelId) throws IOException {
175189
var request = new Request("DELETE", "_inference/" + modelId);
176190
var response = client().performRequest(request);
@@ -484,6 +498,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
484498
@SuppressWarnings("unchecked")
485499
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
486500
switch (taskType) {
501+
case RERANK -> {
502+
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
503+
assertThat(results, hasSize(expectedNumberOfResults));
504+
}
487505
case SPARSE_EMBEDDING -> {
488506
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
489507
assertThat(results, hasSize(expectedNumberOfResults));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
5353
for (int i = 0; i < 4; i++) {
5454
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
5555
}
56+
for (int i = 0; i < 3; i++) {
57+
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
58+
}
5659

5760
var getAllModels = getAllModels();
58-
int numModels = 12;
61+
int numModels = 15;
5962
assertThat(getAllModels, hasSize(numModels));
6063

6164
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
7174
for (var denseModel : getDenseModels) {
7275
assertEquals("text_embedding", denseModel.get("task_type"));
7376
}
77+
78+
var getRerankModels = getModels("_all", TaskType.RERANK);
79+
int numRerankModels = 4;
80+
assertThat(getRerankModels, hasSize(numRerankModels));
81+
for (var denseModel : getRerankModels) {
82+
assertEquals("rerank", denseModel.get("task_type"));
83+
}
7484
String oldApiKey;
7585
{
7686
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
100110
for (int i = 0; i < 4; i++) {
101111
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
102112
}
113+
for (int i = 0; i < 3; i++) {
114+
deleteModel("re-model-" + i, TaskType.RERANK);
115+
}
103116
}
104117

105118
public void testGetModelWithWrongTaskType() throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101101

102102
public void testGetServicesWithRerankTaskType() throws IOException {
103103
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(7));
104+
assertThat(services.size(), equalTo(8));
105105

106106
var providers = providers(services);
107107

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
115115
"googlevertexai",
116116
"jinaai",
117117
"test_reranking_service",
118-
"voyageai"
118+
"voyageai",
119+
"hugging_face"
119120
).toArray()
120121
)
121122
);

0 commit comments

Comments
 (0)