Skip to content

Commit 78905e5

Browse files
committed
add circuit breaker check
1 parent a728ddf commit 78905e5

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import org.elasticsearch.action.support.RefCountingListener;
1515
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1616
import org.elasticsearch.client.internal.Client;
17+
import org.elasticsearch.common.breaker.CircuitBreaker;
1718
import org.elasticsearch.common.bytes.BytesArray;
1819
import org.elasticsearch.core.Nullable;
20+
import org.elasticsearch.indices.breaker.CircuitBreakerService;
1921
import org.elasticsearch.rest.RestStatus;
2022
import org.elasticsearch.tasks.TaskCancelledException;
2123
import org.elasticsearch.threadpool.ThreadPool;
@@ -61,9 +63,16 @@ public class ModelImporter {
6163
private final ExecutorService executorService;
6264
private final AtomicInteger progressCounter = new AtomicInteger();
6365
private final URI uri;
64-
65-
ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task, ThreadPool threadPool)
66-
throws URISyntaxException {
66+
private final CircuitBreakerService breakerService;
67+
68+
ModelImporter(
69+
Client client,
70+
String modelId,
71+
ModelPackageConfig packageConfig,
72+
ModelDownloadTask task,
73+
ThreadPool threadPool,
74+
CircuitBreakerService cbs
75+
) throws URISyntaxException {
6776
this.client = client;
6877
this.modelId = Objects.requireNonNull(modelId);
6978
this.config = Objects.requireNonNull(packageConfig);
@@ -73,6 +82,7 @@ public class ModelImporter {
7382
config.getModelRepository(),
7483
config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION
7584
);
85+
this.breakerService = cbs;
7686
}
7787

7888
public void doImport(ActionListener<AcknowledgedResponse> listener) {
@@ -99,12 +109,19 @@ private void doImportInternal(ActionListener<AcknowledgedResponse> finalListener
99109
int totalParts = (int) ((config.getSize() + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE);
100110

101111
if (ModelLoaderUtils.uriIsFile(uri) == false) {
112+
breakerService.getBreaker(CircuitBreaker.REQUEST)
113+
.addEstimateBytesAndMaybeBreak(DEFAULT_CHUNK_SIZE * NUMBER_OF_STREAMS, "model importer");
114+
var breakerFreeingListener = ActionListener.runAfter(
115+
finalListener,
116+
() -> breakerService.getBreaker(CircuitBreaker.REQUEST).addWithoutBreaking(-(DEFAULT_CHUNK_SIZE * NUMBER_OF_STREAMS))
117+
);
118+
102119
var ranges = ModelLoaderUtils.split(config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE);
103120
var downloaders = new ArrayList<ModelLoaderUtils.HttpStreamChunker>(ranges.size());
104121
for (var range : ranges) {
105122
downloaders.add(new ModelLoaderUtils.HttpStreamChunker(uri, range, DEFAULT_CHUNK_SIZE));
106123
}
107-
downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, finalListener);
124+
downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, breakerFreeingListener);
108125
} else {
109126
InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(uri);
110127
ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(
@@ -115,7 +132,6 @@ private void doImportInternal(ActionListener<AcknowledgedResponse> finalListener
115132
}
116133
} catch (Exception e) {
117134
finalListener.onFailure(e);
118-
return;
119135
}
120136
}
121137

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2424
import org.elasticsearch.cluster.service.ClusterService;
2525
import org.elasticsearch.common.util.concurrent.EsExecutors;
26+
import org.elasticsearch.indices.breaker.CircuitBreakerService;
2627
import org.elasticsearch.injection.guice.Inject;
2728
import org.elasticsearch.rest.RestStatus;
2829
import org.elasticsearch.tasks.Task;
@@ -55,6 +56,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
5556
private static final Logger logger = LogManager.getLogger(TransportLoadTrainedModelPackage.class);
5657

5758
private final Client client;
59+
private final CircuitBreakerService circuitBreakerService;
5860

5961
@Inject
6062
public TransportLoadTrainedModelPackage(
@@ -63,7 +65,8 @@ public TransportLoadTrainedModelPackage(
6365
ThreadPool threadPool,
6466
ActionFilters actionFilters,
6567
IndexNameExpressionResolver indexNameExpressionResolver,
66-
Client client
68+
Client client,
69+
CircuitBreakerService circuitBreakerService
6770
) {
6871
super(
6972
LoadTrainedModelPackageAction.NAME,
@@ -77,6 +80,7 @@ public TransportLoadTrainedModelPackage(
7780
EsExecutors.DIRECT_EXECUTOR_SERVICE
7881
);
7982
this.client = new OriginSettingClient(client, ML_ORIGIN);
83+
this.circuitBreakerService = circuitBreakerService;
8084
}
8185

8286
@Override
@@ -97,7 +101,8 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
97101
request.getModelId(),
98102
request.getModelPackageConfig(),
99103
downloadTask,
100-
threadPool
104+
threadPool,
105+
circuitBreakerService
101106
);
102107

103108
var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.<AcknowledgedResponse>noop();

x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import org.elasticsearch.action.support.ActionTestUtils;
1515
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1616
import org.elasticsearch.client.internal.Client;
17+
import org.elasticsearch.common.breaker.CircuitBreaker;
1718
import org.elasticsearch.common.hash.MessageDigests;
1819
import org.elasticsearch.common.settings.Settings;
20+
import org.elasticsearch.indices.breaker.CircuitBreakerService;
1921
import org.elasticsearch.rest.RestStatus;
2022
import org.elasticsearch.test.ESTestCase;
2123
import org.elasticsearch.threadpool.TestThreadPool;
@@ -63,6 +65,8 @@ public void testDownloadModelDefinition() throws InterruptedException, URISyntax
6365
var task = ModelDownloadTaskTests.testTask();
6466
var config = mockConfigWithRepoLinks();
6567
var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of());
68+
var cbs = mock(CircuitBreakerService.class);
69+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
6670

6771
int totalParts = 5;
6872
int chunkSize = 10;
@@ -74,7 +78,7 @@ public void testDownloadModelDefinition() throws InterruptedException, URISyntax
7478
when(config.getSha256()).thenReturn(digest);
7579
when(config.getSize()).thenReturn(size);
7680

77-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
81+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
7882

7983
var latch = new CountDownLatch(1);
8084
var latchedListener = new LatchedActionListener<AcknowledgedResponse>(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch);
@@ -91,6 +95,8 @@ public void testReadModelDefinitionFromFile() throws InterruptedException, URISy
9195
var task = ModelDownloadTaskTests.testTask();
9296
var config = mockConfigWithRepoLinks();
9397
var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of());
98+
var cbs = mock(CircuitBreakerService.class);
99+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
94100

95101
int totalParts = 3;
96102
int chunkSize = 10;
@@ -101,7 +107,7 @@ public void testReadModelDefinitionFromFile() throws InterruptedException, URISy
101107
when(config.getSha256()).thenReturn(digest);
102108
when(config.getSize()).thenReturn(size);
103109

104-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
110+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
105111
var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize);
106112

107113
var latch = new CountDownLatch(1);
@@ -118,6 +124,8 @@ public void testSizeMismatch() throws InterruptedException, URISyntaxException {
118124
var client = mockClient(false);
119125
var task = mock(ModelDownloadTask.class);
120126
var config = mockConfigWithRepoLinks();
127+
var cbs = mock(CircuitBreakerService.class);
128+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
121129

122130
int totalParts = 5;
123131
int chunkSize = 10;
@@ -137,7 +145,7 @@ public void testSizeMismatch() throws InterruptedException, URISyntaxException {
137145
latch
138146
);
139147

140-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
148+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
141149
importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener);
142150

143151
latch.await();
@@ -149,6 +157,8 @@ public void testDigestMismatch() throws InterruptedException, URISyntaxException
149157
var client = mockClient(false);
150158
var task = mock(ModelDownloadTask.class);
151159
var config = mockConfigWithRepoLinks();
160+
var cbs = mock(CircuitBreakerService.class);
161+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
152162

153163
int totalParts = 5;
154164
int chunkSize = 10;
@@ -166,7 +176,7 @@ public void testDigestMismatch() throws InterruptedException, URISyntaxException
166176
latch
167177
);
168178

169-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
179+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
170180
// Message digest can only be calculated for the file reader
171181
var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize);
172182
importer.readModelDefinitionFromFile(size, totalParts, streamChunker, null, latchedListener);
@@ -180,6 +190,8 @@ public void testPutFailure() throws InterruptedException, URISyntaxException {
180190
var client = mockClient(true); // client will fail put
181191
var task = mock(ModelDownloadTask.class);
182192
var config = mockConfigWithRepoLinks();
193+
var cbs = mock(CircuitBreakerService.class);
194+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
183195

184196
int totalParts = 4;
185197
int chunkSize = 10;
@@ -194,7 +206,7 @@ public void testPutFailure() throws InterruptedException, URISyntaxException {
194206
latch
195207
);
196208

197-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
209+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
198210
importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener);
199211

200212
latch.await();
@@ -206,6 +218,8 @@ public void testReadFailure() throws IOException, InterruptedException, URISynta
206218
var client = mockClient(true);
207219
var task = mock(ModelDownloadTask.class);
208220
var config = mockConfigWithRepoLinks();
221+
var cbs = mock(CircuitBreakerService.class);
222+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
209223

210224
int totalParts = 4;
211225
int chunkSize = 10;
@@ -222,7 +236,7 @@ public void testReadFailure() throws IOException, InterruptedException, URISynta
222236
latch
223237
);
224238

225-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
239+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
226240
importer.downloadModelDefinition(size, totalParts, null, List.of(streamer), latchedListener);
227241

228242
latch.await();
@@ -237,6 +251,8 @@ public void testUploadVocabFailure() throws InterruptedException, URISyntaxExcep
237251
listener.onFailure(new ElasticsearchStatusException("put vocab failed", RestStatus.BAD_REQUEST));
238252
return null;
239253
}).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any());
254+
var cbs = mock(CircuitBreakerService.class);
255+
when(cbs.getBreaker(eq(CircuitBreaker.REQUEST))).thenReturn(mock(CircuitBreaker.class));
240256

241257
var task = mock(ModelDownloadTask.class);
242258
var config = mockConfigWithRepoLinks();
@@ -250,7 +266,7 @@ public void testUploadVocabFailure() throws InterruptedException, URISyntaxExcep
250266
latch
251267
);
252268

253-
var importer = new ModelImporter(client, "foo", config, task, threadPool);
269+
var importer = new ModelImporter(client, "foo", config, task, threadPool, cbs);
254270
importer.downloadModelDefinition(100, 5, vocab, List.of(), latchedListener);
255271

256272
latch.await();

0 commit comments

Comments
 (0)