@@ -107,7 +107,7 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
                 stats.timingStats().getAverage(),
                 stats.timingStatsExcludingCacheHits().getAverage(),
                 stats.lastUsed(),
-                processContext.executorService.queueSize() + stats.numberOfPendingResults(),
+                processContext.priorityProcessWorker.queueSize() + stats.numberOfPendingResults(),
                 stats.errorCount(),
                 stats.cacheHitCount(),
                 processContext.rejectedExecutionCount.intValue(),
@@ -130,7 +130,7 @@ ProcessContext addProcessContext(Long id, ProcessContext processContext) {
     private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
         logger.info("[{}] Starting model deployment", task.getModelId());
 
-        ProcessContext processContext = new ProcessContext(task, executorServiceForProcess);
+        ProcessContext processContext = new ProcessContext(task);
         if (addProcessContext(task.getId(), processContext) != null) {
             finalListener.onFailure(
                 ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getModelId())
@@ -232,7 +232,10 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
     private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
         try {
            processContext.startProcess();
-           processContext.loadModel(modelLocation, loadedListener);
+           processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
+               processContext.startPriorityProcessWorker();
+               loadedListener.onResponse(success);
+           }, loadedListener::onFailure));
         } catch (Exception e) {
             loadedListener.onFailure(e);
         }
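
The load callback is chained so that the priority worker only starts once the model state has finished streaming; a load failure bypasses the worker start and reaches the caller unchanged. Below is a minimal sketch of that listener-composition pattern, assuming a simplified `Listener` interface as a stand-in for Elasticsearch's `ActionListener` (all names here are illustrative, not the real API):

```java
import java.util.function.Consumer;

// Simplified model of ActionListener with a wrap() combinator: run a success
// handler, delegate failures unchanged.
interface Listener<T> {
    void onResponse(T result);
    void onFailure(Exception e);

    static <T> Listener<T> wrap(Consumer<T> onSuccess, Consumer<Exception> onFail) {
        return new Listener<T>() {
            @Override public void onResponse(T result) { onSuccess.accept(result); }
            @Override public void onFailure(Exception e) { onFail.accept(e); }
        };
    }
}

class StartAndLoadSketch {
    // Hypothetical stand-ins for the process context operations.
    static void loadModel(Listener<Boolean> listener) { listener.onResponse(true); }
    static void startPriorityProcessWorker() { System.out.println("worker started"); }

    public static void main(String[] args) {
        Listener<Boolean> loadedListener = Listener.wrap(
            ok -> System.out.println("deployment ready: " + ok),
            e -> System.err.println("deployment failed: " + e)
        );
        // Start the worker only after the model has actually loaded; on
        // failure the worker is never started and the error propagates.
        loadModel(Listener.wrap(success -> {
            startPriorityProcessWorker();
            loadedListener.onResponse(success);
        }, loadedListener::onFailure));
    }
}
```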
@@ -332,13 +335,13 @@ public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, Actio
         executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
     }
 
-    public void executePyTorchAction(
+    void executePyTorchAction(
         ProcessContext processContext,
         PriorityProcessWorkerExecutorService.RequestPriority priority,
         AbstractPyTorchAction<?> action
     ) {
         try {
-            processContext.getExecutorService().executeWithPriority(action, priority, action.getRequestId());
+            processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
         } catch (EsRejectedExecutionException e) {
             processContext.getRejectedExecutionCount().incrementAndGet();
             action.onFailure(e);
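
Submission now goes through the renamed `getPriorityProcessWorker()` accessor, and a full queue surfaces as `EsRejectedExecutionException`, which is counted and converted into an action failure rather than dropped. A rough sketch of that bounded-queue rejection pattern, using the JDK's `RejectedExecutionException` and a toy worker (`BoundedWorkerSketch` is hypothetical, not the real class):

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

class BoundedWorkerSketch {
    private final ArrayBlockingQueue<Runnable> queue = new ArrayBlockingQueue<>(2);
    final AtomicInteger rejectedExecutionCount = new AtomicInteger();

    void executeWithPriority(Runnable action) {
        // offer() fails instead of blocking when the queue is at capacity,
        // which is what surfaces as a rejection to the caller.
        if (queue.offer(action) == false) {
            throw new RejectedExecutionException("inference queue is full");
        }
    }

    public static void main(String[] args) {
        BoundedWorkerSketch worker = new BoundedWorkerSketch();
        for (int i = 0; i < 4; i++) {
            try {
                worker.executeWithPriority(() -> {});
            } catch (RejectedExecutionException e) {
                // Mirror of the catch block above: count the rejection and
                // fail the request visibly instead of losing it silently.
                worker.rejectedExecutionCount.incrementAndGet();
                System.err.println("rejected request " + i + ": " + e.getMessage());
            }
        }
        System.out.println("rejections: " + worker.rejectedExecutionCount.get());
    }
}
```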
@@ -376,21 +379,21 @@ class ProcessContext {
         private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
         private final PyTorchResultProcessor resultProcessor;
         private final PyTorchStateStreamer stateStreamer;
-        private final PriorityProcessWorkerExecutorService executorService;
+        private final PriorityProcessWorkerExecutorService priorityProcessWorker;
         private volatile Instant startTime;
         private volatile Integer numThreadsPerAllocation;
         private volatile Integer numAllocations;
         private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
         private final AtomicInteger timeoutCount = new AtomicInteger();
 
-        ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
+        ProcessContext(TrainedModelDeploymentTask task) {
             this.task = Objects.requireNonNull(task);
             resultProcessor = new PyTorchResultProcessor(task.getModelId(), threadSettings -> {
                 this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                 this.numAllocations = threadSettings.numAllocations();
             });
-            this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
-            this.executorService = new PriorityProcessWorkerExecutorService(
+            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
+            this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                 threadPool.getThreadContext(),
                 "inference process",
                 task.getParams().getQueueCapacity()
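
The renamed worker still orders requests by priority, bounded by `task.getParams().getQueueCapacity()`. A self-contained sketch of priority ordering with a FIFO tie-breaker on the request id, which is how a worker like this typically keeps equal-priority requests in arrival order (the comparator here is an assumption for illustration, not lifted from the real class):

```java
import java.util.concurrent.PriorityBlockingQueue;

class PrioritySketch {
    enum RequestPriority { HIGHEST, NORMAL }

    record Request(RequestPriority priority, long requestId, String label)
        implements Comparable<Request> {
        @Override public int compareTo(Request other) {
            int byPriority = priority.compareTo(other.priority); // HIGHEST sorts first
            return byPriority != 0 ? byPriority : Long.compare(requestId, other.requestId);
        }
    }

    public static void main(String[] args) {
        PriorityBlockingQueue<Request> queue = new PriorityBlockingQueue<>();
        queue.add(new Request(RequestPriority.NORMAL, 1, "inference"));
        queue.add(new Request(RequestPriority.NORMAL, 2, "inference"));
        queue.add(new Request(RequestPriority.HIGHEST, 3, "clear-cache control message"));
        // Control messages jump the queue; equal priorities stay FIFO.
        while (queue.isEmpty() == false) {
            System.out.println(queue.poll());
        }
    }
}
```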
@@ -404,12 +407,15 @@ PyTorchResultProcessor getResultProcessor() {
         synchronized void startProcess() {
             process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
             startTime = Instant.now();
-            executorServiceForProcess.submit(executorService::start);
+        }
+
+        void startPriorityProcessWorker() {
+            executorServiceForProcess.submit(priorityProcessWorker::start);
         }
 
         synchronized void stopProcess() {
             resultProcessor.stop();
-            executorService.shutdown();
+            priorityProcessWorker.shutdown();
             try {
                 if (process.get() == null) {
                     return;
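
Splitting `startPriorityProcessWorker()` out of `startProcess()` means requests can still be accepted onto the queue while the model is loading, but none execute until the worker is explicitly started after the load completes. A toy model of that deferred-start behavior, with made-up class and method names and a single-threaded drain in place of the real dedicated worker thread:

```java
import java.util.ArrayDeque;
import java.util.Queue;

class DeferredStartWorker {
    private final Queue<Runnable> queue = new ArrayDeque<>();
    private boolean started = false;

    synchronized void execute(Runnable task) {
        queue.add(task);          // accepted immediately, even before start()
        if (started) drain();
    }

    synchronized void start() {   // the real worker runs its loop on a pool thread
        started = true;
        drain();
    }

    private void drain() {
        Runnable task;
        while ((task = queue.poll()) != null) task.run();
    }

    public static void main(String[] args) {
        DeferredStartWorker worker = new DeferredStartWorker();
        worker.execute(() -> System.out.println("inference request (queued during load)"));
        // ... model state streams in here ...
        worker.start(); // only now does the queued request run
    }
}
```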
@@ -430,7 +436,7 @@ private Consumer<String> onProcessCrash() {
             return reason -> {
                 logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
                 resultProcessor.stop();
-                executorService.shutdownWithError(new IllegalStateException(reason));
+                priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
                 processContextByAllocation.remove(task.getId());
                 if (nlpTaskProcessor.get() != null) {
                     nlpTaskProcessor.get().close();
@@ -441,6 +447,7 @@ private Consumer<String> onProcessCrash() {
 
         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
             if (modelLocation instanceof IndexLocation indexLocation) {
+                logger.debug("[{}] loading model state", task.getModelId());
                 process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
             } else {
                 listener.onFailure(
@@ -455,8 +462,8 @@ AtomicInteger getTimeoutCount() {
         }
 
         // accessor used for mocking in tests
-        PriorityProcessWorkerExecutorService getExecutorService() {
-            return executorService;
+        PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
+            return priorityProcessWorker;
         }
 
         // accessor used for mocking in tests