Skip to content

Commit ca477ba

Browse files
committed
Add Retry Logic for ElasticsearchInternalService#chunkedInfer
This PR introduces basic retry functionality for the internal inference service (`ElasticsearchInternalService`), which runs on ML nodes. We already use an exponential backoff strategy for retrying failed inference requests to external services. This change extends the same retry mechanism to the internal service, allowing it to automatically retry transient failures. To maintain consistency and reduce complexity, this implementation reuses the existing retry configuration settings under `xpack.inference.http.retry.*` (currently the initial delay and the timeout). **Note**: This PR is still a draft. Additional tests are needed, but I’d like to gather feedback on the approach before proceeding further.
1 parent 31efea7 commit ca477ba

File tree

4 files changed

+263
-65
lines changed

4 files changed

+263
-65
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@
7777
public class ModelRegistryIT extends ESSingleNodeTestCase {
7878
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
7979

80+
private ClusterService clusterService;
8081
private ModelRegistry modelRegistry;
8182

8283
@Before
8384
public void createComponents() {
85+
clusterService = node().injector().getInstance(ClusterService.class);
8486
modelRegistry = node().injector().getInstance(ModelRegistry.class);
8587
modelRegistry.clearDefaultIds();
8688
}
@@ -123,12 +125,11 @@ public void testGetModel() throws Exception {
123125
assertThat(modelHolder.get(), not(nullValue()));
124126

125127
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
126-
127128
var elserService = new ElasticsearchInternalService(
128129
new InferenceServiceExtension.InferenceServiceFactoryContext(
129130
mock(Client.class),
130131
mock(ThreadPool.class),
131-
mock(ClusterService.class),
132+
clusterService,
132133
Settings.EMPTY
133134
)
134135
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
public class RetrySettings {
1818

19-
static final Setting<TimeValue> RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting(
19+
public static final Setting<TimeValue> RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting(
2020
"xpack.inference.http.retry.initial_delay",
2121
TimeValue.timeValueSeconds(1),
2222
Setting.Property.NodeScope,
@@ -30,7 +30,7 @@ public class RetrySettings {
3030
Setting.Property.Dynamic
3131
);
3232

33-
static final Setting<TimeValue> RETRY_TIMEOUT_SETTING = Setting.timeSetting(
33+
public static final Setting<TimeValue> RETRY_TIMEOUT_SETTING = Setting.timeSetting(
3434
"xpack.inference.http.retry.timeout",
3535
TimeValue.timeValueSeconds(30),
3636
Setting.Property.NodeScope,
@@ -106,15 +106,15 @@ public static List<Setting<?>> getSettingsDefinitions() {
106106
);
107107
}
108108

109-
TimeValue getInitialDelay() {
109+
public TimeValue getInitialDelay() {
110110
return initialDelay;
111111
}
112112

113113
TimeValue getMaxDelayBound() {
114114
return maxDelayBound;
115115
}
116116

117-
TimeValue getTimeout() {
117+
public TimeValue getTimeout() {
118118
return timeout;
119119
}
120120

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.apache.lucene.internal.hppc.IntIntHashMap;
1213
import org.elasticsearch.ElasticsearchStatusException;
14+
import org.elasticsearch.ExceptionsHelper;
1315
import org.elasticsearch.TransportVersion;
1416
import org.elasticsearch.TransportVersions;
1517
import org.elasticsearch.action.ActionListener;
18+
import org.elasticsearch.action.support.RetryableAction;
1619
import org.elasticsearch.common.logging.DeprecationCategory;
1720
import org.elasticsearch.common.logging.DeprecationLogger;
1821
import org.elasticsearch.common.settings.Settings;
1922
import org.elasticsearch.common.util.LazyInitializable;
23+
import org.elasticsearch.common.util.concurrent.AtomicArray;
2024
import org.elasticsearch.core.Nullable;
2125
import org.elasticsearch.core.Strings;
2226
import org.elasticsearch.core.TimeValue;
@@ -36,6 +40,7 @@
3640
import org.elasticsearch.inference.UnifiedCompletionRequest;
3741
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3842
import org.elasticsearch.rest.RestStatus;
43+
import org.elasticsearch.threadpool.ThreadPool;
3944
import org.elasticsearch.xpack.core.XPackSettings;
4045
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
4146
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -57,6 +62,7 @@
5762
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
5863
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
5964
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
65+
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
6066
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
6167
import org.elasticsearch.xpack.inference.services.ServiceUtils;
6268

@@ -68,9 +74,12 @@
6874
import java.util.Map;
6975
import java.util.Optional;
7076
import java.util.Set;
77+
import java.util.concurrent.Executor;
78+
import java.util.concurrent.atomic.AtomicBoolean;
7179
import java.util.concurrent.atomic.AtomicInteger;
7280
import java.util.function.Consumer;
7381
import java.util.function.Function;
82+
import java.util.function.IntUnaryOperator;
7483

7584
import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
7685
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
@@ -121,10 +130,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
121130
private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";
122131

123132
private final Settings settings;
133+
private final ThreadPool threadPool;
134+
private final RetrySettings retrySettings;
124135

125136
public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
126137
super(context);
127138
this.settings = context.settings();
139+
this.threadPool = context.threadPool();
140+
this.retrySettings = new RetrySettings(context.settings(), context.clusterService());
128141
}
129142

130143
// for testing
@@ -134,6 +147,8 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa
134147
) {
135148
super(context, platformArch);
136149
this.settings = context.settings();
150+
this.threadPool = context.threadPool();
151+
this.retrySettings = new RetrySettings(context.settings(), context.clusterService());
137152
}
138153

139154
@Override
@@ -1126,10 +1141,150 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft
11261141
if (maybeDeploy) {
11271142
listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
11281143
}
1129-
client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
1144+
1145+
new BatchExecutor(retrySettings.getInitialDelay(), retrySettings.getTimeout(), inferenceRequest, listener, inferenceExecutor)
1146+
.run();
1147+
}
1148+
}
1149+
1150+
/**
 * REST statuses that are treated as retryable for internal inference requests.
 * The set is consulted both for whole-request failures (in {@code BatchExecutor#shouldRetry})
 * and for individual {@code ErrorInferenceResults} entries inside an otherwise successful
 * response (in {@code RetryState#consumeResponse}).
 */
private static final Set<RestStatus> RETRYABLE_STATUS = Set.of(
1151+
RestStatus.INTERNAL_SERVER_ERROR,
1152+
RestStatus.TOO_MANY_REQUESTS,
1153+
RestStatus.REQUEST_TIMEOUT
1154+
);
1155+
1156+
private class BatchExecutor extends RetryableAction<InferModelAction.Response> {
1157+
private final RetryState state;
1158+
1159+
BatchExecutor(
1160+
TimeValue initialDelay,
1161+
TimeValue timeoutValue,
1162+
InferModelAction.Request request,
1163+
ActionListener<InferModelAction.Response> listener,
1164+
Executor executor
1165+
) {
1166+
this(initialDelay, timeoutValue, new RetryState(request), listener, executor);
1167+
}
1168+
1169+
private BatchExecutor(
1170+
TimeValue initialDelay,
1171+
TimeValue timeoutValue,
1172+
RetryState state,
1173+
ActionListener<InferModelAction.Response> listener,
1174+
Executor executor
1175+
) {
1176+
super(logger, threadPool, initialDelay, timeoutValue, new ActionListener<>() {
1177+
@Override
1178+
public void onResponse(InferModelAction.Response response) {
1179+
listener.onResponse(state.getAccumulatedResponse(null));
1180+
}
1181+
1182+
@Override
1183+
public void onFailure(Exception exc) {
1184+
if (state.hasPartialResponse()) {
1185+
listener.onResponse(state.getAccumulatedResponse(exc instanceof RetryableException ? null : exc));
1186+
} else {
1187+
listener.onFailure(exc);
1188+
}
1189+
}
1190+
}, executor);
1191+
this.state = state;
1192+
}
1193+
1194+
@Override
1195+
public void tryAction(ActionListener<InferModelAction.Response> listener) {
1196+
client.execute(InferModelAction.INSTANCE, state.getCurrentRequest(), new ActionListener<>() {
1197+
@Override
1198+
public void onResponse(InferModelAction.Response response) {
1199+
if (state.consumeResponse(response)) {
1200+
listener.onResponse(response);
1201+
} else {
1202+
listener.onFailure(new RetryableException());
1203+
}
1204+
}
1205+
1206+
@Override
1207+
public void onFailure(Exception exc) {
1208+
listener.onFailure(exc);
1209+
}
1210+
});
1211+
}
1212+
1213+
@Override
1214+
public boolean shouldRetry(Exception exc) {
1215+
return exc instanceof RetryableException
1216+
|| RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(exc)));
11301217
}
11311218
}
11321219

1220+
private static class RetryState {
1221+
private final InferModelAction.Request originalRequest;
1222+
private InferModelAction.Request currentRequest;
1223+
1224+
private IntUnaryOperator currentToOriginalIndex;
1225+
private final AtomicArray<InferenceResults> inferenceResults;
1226+
private final AtomicBoolean hasPartialResponse;
1227+
1228+
private RetryState(InferModelAction.Request originalRequest) {
1229+
this.originalRequest = originalRequest;
1230+
this.currentRequest = originalRequest;
1231+
this.currentToOriginalIndex = index -> index;
1232+
this.inferenceResults = new AtomicArray<>(originalRequest.getTextInput().size());
1233+
this.hasPartialResponse = new AtomicBoolean();
1234+
}
1235+
1236+
boolean hasPartialResponse() {
1237+
return hasPartialResponse.get();
1238+
}
1239+
1240+
InferModelAction.Request getCurrentRequest() {
1241+
return currentRequest;
1242+
}
1243+
1244+
InferModelAction.Response getAccumulatedResponse(@Nullable Exception exc) {
1245+
List<InferenceResults> finalResults = new ArrayList<>();
1246+
for (int i = 0; i < inferenceResults.length(); i++) {
1247+
var result = inferenceResults.get(i);
1248+
if (exc != null && result instanceof ErrorInferenceResults) {
1249+
finalResults.add(new ErrorInferenceResults(exc));
1250+
} else {
1251+
finalResults.add(result);
1252+
}
1253+
}
1254+
return new InferModelAction.Response(finalResults, originalRequest.getId(), originalRequest.isPreviouslyLicensed());
1255+
}
1256+
1257+
private boolean consumeResponse(InferModelAction.Response response) {
1258+
hasPartialResponse.set(true);
1259+
List<String> retryInputs = new ArrayList<>();
1260+
IntIntHashMap newIndexMap = new IntIntHashMap();
1261+
for (int i = 0; i < response.getInferenceResults().size(); i++) {
1262+
var result = response.getInferenceResults().get(i);
1263+
int index = currentToOriginalIndex.applyAsInt(i);
1264+
inferenceResults.set(index, result);
1265+
if (result instanceof ErrorInferenceResults error
1266+
&& RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(error.getException())))) {
1267+
newIndexMap.put(retryInputs.size(), index);
1268+
retryInputs.add(originalRequest.getTextInput().get(index));
1269+
}
1270+
}
1271+
if (retryInputs.isEmpty()) {
1272+
return true;
1273+
}
1274+
currentRequest = InferModelAction.Request.forTextInput(
1275+
originalRequest.getId(),
1276+
originalRequest.getUpdate(),
1277+
retryInputs,
1278+
originalRequest.isPreviouslyLicensed(),
1279+
originalRequest.getInferenceTimeout()
1280+
);
1281+
currentToOriginalIndex = newIndexMap::get;
1282+
return false;
1283+
}
1284+
}
1285+
1286+
/**
 * Internal marker raised when a response was received but some of its per-input results are
 * retryable errors. It is only used to drive the retry loop, so it carries a fixed message and
 * skips stack trace capture; suppression stays enabled so other exceptions can still be attached.
 */
private static final class RetryableException extends Exception {
    RetryableException() {
        // enableSuppression = true, writableStackTrace = false
        super("inference response contained retryable errors", null, true, false);
    }
}
1287+
11331288
public static class Configuration {
11341289
public static InferenceServiceConfiguration get() {
11351290
return configuration.getOrCompute();

0 commit comments

Comments
 (0)