99
1010import org .apache .logging .log4j .LogManager ;
1111import org .apache .logging .log4j .Logger ;
12+ import org .apache .lucene .internal .hppc .IntIntHashMap ;
1213import org .elasticsearch .ElasticsearchStatusException ;
14+ import org .elasticsearch .ExceptionsHelper ;
1315import org .elasticsearch .TransportVersion ;
1416import org .elasticsearch .TransportVersions ;
1517import org .elasticsearch .action .ActionListener ;
18+ import org .elasticsearch .action .support .RetryableAction ;
1619import org .elasticsearch .common .logging .DeprecationCategory ;
1720import org .elasticsearch .common .logging .DeprecationLogger ;
1821import org .elasticsearch .common .settings .Settings ;
1922import org .elasticsearch .common .util .LazyInitializable ;
23+ import org .elasticsearch .common .util .concurrent .AtomicArray ;
2024import org .elasticsearch .core .Nullable ;
2125import org .elasticsearch .core .Strings ;
2226import org .elasticsearch .core .TimeValue ;
3640import org .elasticsearch .inference .UnifiedCompletionRequest ;
3741import org .elasticsearch .inference .configuration .SettingsConfigurationFieldType ;
3842import org .elasticsearch .rest .RestStatus ;
43+ import org .elasticsearch .threadpool .ThreadPool ;
3944import org .elasticsearch .xpack .core .XPackSettings ;
4045import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
4146import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
5762import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TextSimilarityConfigUpdate ;
5863import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsBuilder ;
5964import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
65+ import org .elasticsearch .xpack .inference .external .http .retry .RetrySettings ;
6066import org .elasticsearch .xpack .inference .services .ConfigurationParseContext ;
6167import org .elasticsearch .xpack .inference .services .ServiceUtils ;
6268
6874import java .util .Map ;
6975import java .util .Optional ;
7076import java .util .Set ;
77+ import java .util .concurrent .Executor ;
78+ import java .util .concurrent .atomic .AtomicBoolean ;
7179import java .util .concurrent .atomic .AtomicInteger ;
7280import java .util .function .Consumer ;
7381import java .util .function .Function ;
82+ import java .util .function .IntUnaryOperator ;
7483
7584import static org .elasticsearch .xpack .core .inference .results .ResultUtils .createInvalidChunkedResultException ;
7685import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMap ;
@@ -121,10 +130,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
121130 private static final String OLD_MODEL_ID_FIELD_NAME = "model_version" ;
122131
123132 private final Settings settings ;
133+ private final ThreadPool threadPool ;
134+ private final RetrySettings retrySettings ;
124135
125136 public ElasticsearchInternalService (InferenceServiceExtension .InferenceServiceFactoryContext context ) {
126137 super (context );
127138 this .settings = context .settings ();
139+ this .threadPool = context .threadPool ();
140+ this .retrySettings = new RetrySettings (context .settings (), context .clusterService ());
128141 }
129142
130143 // for testing
@@ -134,6 +147,8 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa
134147 ) {
135148 super (context , platformArch );
136149 this .settings = context .settings ();
150+ this .threadPool = context .threadPool ();
151+ this .retrySettings = new RetrySettings (context .settings (), context .clusterService ());
137152 }
138153
139154 @ Override
@@ -1126,10 +1141,150 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft
11261141 if (maybeDeploy ) {
11271142 listener = listener .delegateResponse ((l , exception ) -> maybeStartDeployment (esModel , exception , inferenceRequest , l ));
11281143 }
1129- client .execute (InferModelAction .INSTANCE , inferenceRequest , listener );
1144+
1145+ new BatchExecutor (retrySettings .getInitialDelay (), retrySettings .getTimeout (), inferenceRequest , listener , inferenceExecutor )
1146+ .run ();
1147+ }
1148+ }
1149+
1150+ private static final Set <RestStatus > RETRYABLE_STATUS = Set .of (
1151+ RestStatus .INTERNAL_SERVER_ERROR ,
1152+ RestStatus .TOO_MANY_REQUESTS ,
1153+ RestStatus .REQUEST_TIMEOUT
1154+ );
1155+
1156+ private class BatchExecutor extends RetryableAction <InferModelAction .Response > {
1157+ private final RetryState state ;
1158+
1159+ BatchExecutor (
1160+ TimeValue initialDelay ,
1161+ TimeValue timeoutValue ,
1162+ InferModelAction .Request request ,
1163+ ActionListener <InferModelAction .Response > listener ,
1164+ Executor executor
1165+ ) {
1166+ this (initialDelay , timeoutValue , new RetryState (request ), listener , executor );
1167+ }
1168+
1169+ private BatchExecutor (
1170+ TimeValue initialDelay ,
1171+ TimeValue timeoutValue ,
1172+ RetryState state ,
1173+ ActionListener <InferModelAction .Response > listener ,
1174+ Executor executor
1175+ ) {
1176+ super (logger , threadPool , initialDelay , timeoutValue , new ActionListener <>() {
1177+ @ Override
1178+ public void onResponse (InferModelAction .Response response ) {
1179+ listener .onResponse (state .getAccumulatedResponse (null ));
1180+ }
1181+
1182+ @ Override
1183+ public void onFailure (Exception exc ) {
1184+ if (state .hasPartialResponse ()) {
1185+ listener .onResponse (state .getAccumulatedResponse (exc instanceof RetryableException ? null : exc ));
1186+ } else {
1187+ listener .onFailure (exc );
1188+ }
1189+ }
1190+ }, executor );
1191+ this .state = state ;
1192+ }
1193+
1194+ @ Override
1195+ public void tryAction (ActionListener <InferModelAction .Response > listener ) {
1196+ client .execute (InferModelAction .INSTANCE , state .getCurrentRequest (), new ActionListener <>() {
1197+ @ Override
1198+ public void onResponse (InferModelAction .Response response ) {
1199+ if (state .consumeResponse (response )) {
1200+ listener .onResponse (response );
1201+ } else {
1202+ listener .onFailure (new RetryableException ());
1203+ }
1204+ }
1205+
1206+ @ Override
1207+ public void onFailure (Exception exc ) {
1208+ listener .onFailure (exc );
1209+ }
1210+ });
1211+ }
1212+
1213+ @ Override
1214+ public boolean shouldRetry (Exception exc ) {
1215+ return exc instanceof RetryableException
1216+ || RETRYABLE_STATUS .contains (ExceptionsHelper .status (ExceptionsHelper .unwrapCause (exc )));
11301217 }
11311218 }
11321219
1220+ private static class RetryState {
1221+ private final InferModelAction .Request originalRequest ;
1222+ private InferModelAction .Request currentRequest ;
1223+
1224+ private IntUnaryOperator currentToOriginalIndex ;
1225+ private final AtomicArray <InferenceResults > inferenceResults ;
1226+ private final AtomicBoolean hasPartialResponse ;
1227+
1228+ private RetryState (InferModelAction .Request originalRequest ) {
1229+ this .originalRequest = originalRequest ;
1230+ this .currentRequest = originalRequest ;
1231+ this .currentToOriginalIndex = index -> index ;
1232+ this .inferenceResults = new AtomicArray <>(originalRequest .getTextInput ().size ());
1233+ this .hasPartialResponse = new AtomicBoolean ();
1234+ }
1235+
1236+ boolean hasPartialResponse () {
1237+ return hasPartialResponse .get ();
1238+ }
1239+
1240+ InferModelAction .Request getCurrentRequest () {
1241+ return currentRequest ;
1242+ }
1243+
1244+ InferModelAction .Response getAccumulatedResponse (@ Nullable Exception exc ) {
1245+ List <InferenceResults > finalResults = new ArrayList <>();
1246+ for (int i = 0 ; i < inferenceResults .length (); i ++) {
1247+ var result = inferenceResults .get (i );
1248+ if (exc != null && result instanceof ErrorInferenceResults ) {
1249+ finalResults .add (new ErrorInferenceResults (exc ));
1250+ } else {
1251+ finalResults .add (result );
1252+ }
1253+ }
1254+ return new InferModelAction .Response (finalResults , originalRequest .getId (), originalRequest .isPreviouslyLicensed ());
1255+ }
1256+
1257+ private boolean consumeResponse (InferModelAction .Response response ) {
1258+ hasPartialResponse .set (true );
1259+ List <String > retryInputs = new ArrayList <>();
1260+ IntIntHashMap newIndexMap = new IntIntHashMap ();
1261+ for (int i = 0 ; i < response .getInferenceResults ().size (); i ++) {
1262+ var result = response .getInferenceResults ().get (i );
1263+ int index = currentToOriginalIndex .applyAsInt (i );
1264+ inferenceResults .set (index , result );
1265+ if (result instanceof ErrorInferenceResults error
1266+ && RETRYABLE_STATUS .contains (ExceptionsHelper .status (ExceptionsHelper .unwrapCause (error .getException ())))) {
1267+ newIndexMap .put (retryInputs .size (), index );
1268+ retryInputs .add (originalRequest .getTextInput ().get (index ));
1269+ }
1270+ }
1271+ if (retryInputs .isEmpty ()) {
1272+ return true ;
1273+ }
1274+ currentRequest = InferModelAction .Request .forTextInput (
1275+ originalRequest .getId (),
1276+ originalRequest .getUpdate (),
1277+ retryInputs ,
1278+ originalRequest .isPreviouslyLicensed (),
1279+ originalRequest .getInferenceTimeout ()
1280+ );
1281+ currentToOriginalIndex = newIndexMap ::get ;
1282+ return false ;
1283+ }
1284+ }
1285+
1286+ private static class RetryableException extends Exception {}
1287+
11331288 public static class Configuration {
11341289 public static InferenceServiceConfiguration get () {
11351290 return configuration .getOrCompute ();
0 commit comments