1414import  org .elasticsearch .action .support .ActionTestUtils ;
1515import  org .elasticsearch .action .support .master .AcknowledgedResponse ;
1616import  org .elasticsearch .client .internal .Client ;
17+ import  org .elasticsearch .common .breaker .CircuitBreaker ;
1718import  org .elasticsearch .common .hash .MessageDigests ;
1819import  org .elasticsearch .common .settings .Settings ;
20+ import  org .elasticsearch .indices .breaker .CircuitBreakerService ;
1921import  org .elasticsearch .rest .RestStatus ;
2022import  org .elasticsearch .test .ESTestCase ;
2123import  org .elasticsearch .threadpool .TestThreadPool ;
@@ -63,6 +65,8 @@ public void testDownloadModelDefinition() throws InterruptedException, URISyntax
6365        var  task  = ModelDownloadTaskTests .testTask ();
6466        var  config  = mockConfigWithRepoLinks ();
6567        var  vocab  = new  ModelLoaderUtils .VocabularyParts (List .of (), List .of (), List .of ());
68+         var  cbs  = mock (CircuitBreakerService .class );
69+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
6670
6771        int  totalParts  = 5 ;
6872        int  chunkSize  = 10 ;
@@ -74,7 +78,7 @@ public void testDownloadModelDefinition() throws InterruptedException, URISyntax
7478        when (config .getSha256 ()).thenReturn (digest );
7579        when (config .getSize ()).thenReturn (size );
7680
77-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
81+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
7882
7983        var  latch  = new  CountDownLatch (1 );
8084        var  latchedListener  = new  LatchedActionListener <AcknowledgedResponse >(ActionTestUtils .assertNoFailureListener (ignore  -> {}), latch );
@@ -91,6 +95,8 @@ public void testReadModelDefinitionFromFile() throws InterruptedException, URISy
9195        var  task  = ModelDownloadTaskTests .testTask ();
9296        var  config  = mockConfigWithRepoLinks ();
9397        var  vocab  = new  ModelLoaderUtils .VocabularyParts (List .of (), List .of (), List .of ());
98+         var  cbs  = mock (CircuitBreakerService .class );
99+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
94100
95101        int  totalParts  = 3 ;
96102        int  chunkSize  = 10 ;
@@ -101,7 +107,7 @@ public void testReadModelDefinitionFromFile() throws InterruptedException, URISy
101107        when (config .getSha256 ()).thenReturn (digest );
102108        when (config .getSize ()).thenReturn (size );
103109
104-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
110+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
105111        var  streamChunker  = new  ModelLoaderUtils .InputStreamChunker (new  ByteArrayInputStream (modelDef ), chunkSize );
106112
107113        var  latch  = new  CountDownLatch (1 );
@@ -118,6 +124,8 @@ public void testSizeMismatch() throws InterruptedException, URISyntaxException {
118124        var  client  = mockClient (false );
119125        var  task  = mock (ModelDownloadTask .class );
120126        var  config  = mockConfigWithRepoLinks ();
127+         var  cbs  = mock (CircuitBreakerService .class );
128+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
121129
122130        int  totalParts  = 5 ;
123131        int  chunkSize  = 10 ;
@@ -137,7 +145,7 @@ public void testSizeMismatch() throws InterruptedException, URISyntaxException {
137145            latch 
138146        );
139147
140-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
148+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
141149        importer .downloadModelDefinition (size , totalParts , null , streamers , latchedListener );
142150
143151        latch .await ();
@@ -149,6 +157,8 @@ public void testDigestMismatch() throws InterruptedException, URISyntaxException
149157        var  client  = mockClient (false );
150158        var  task  = mock (ModelDownloadTask .class );
151159        var  config  = mockConfigWithRepoLinks ();
160+         var  cbs  = mock (CircuitBreakerService .class );
161+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
152162
153163        int  totalParts  = 5 ;
154164        int  chunkSize  = 10 ;
@@ -166,7 +176,7 @@ public void testDigestMismatch() throws InterruptedException, URISyntaxException
166176            latch 
167177        );
168178
169-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
179+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
170180        // Message digest can only be calculated for the file reader 
171181        var  streamChunker  = new  ModelLoaderUtils .InputStreamChunker (new  ByteArrayInputStream (modelDef ), chunkSize );
172182        importer .readModelDefinitionFromFile (size , totalParts , streamChunker , null , latchedListener );
@@ -180,6 +190,8 @@ public void testPutFailure() throws InterruptedException, URISyntaxException {
180190        var  client  = mockClient (true );  // client will fail put 
181191        var  task  = mock (ModelDownloadTask .class );
182192        var  config  = mockConfigWithRepoLinks ();
193+         var  cbs  = mock (CircuitBreakerService .class );
194+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
183195
184196        int  totalParts  = 4 ;
185197        int  chunkSize  = 10 ;
@@ -194,7 +206,7 @@ public void testPutFailure() throws InterruptedException, URISyntaxException {
194206            latch 
195207        );
196208
197-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
209+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
198210        importer .downloadModelDefinition (size , totalParts , null , streamers , latchedListener );
199211
200212        latch .await ();
@@ -206,6 +218,8 @@ public void testReadFailure() throws IOException, InterruptedException, URISynta
206218        var  client  = mockClient (true );
207219        var  task  = mock (ModelDownloadTask .class );
208220        var  config  = mockConfigWithRepoLinks ();
221+         var  cbs  = mock (CircuitBreakerService .class );
222+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
209223
210224        int  totalParts  = 4 ;
211225        int  chunkSize  = 10 ;
@@ -222,7 +236,7 @@ public void testReadFailure() throws IOException, InterruptedException, URISynta
222236            latch 
223237        );
224238
225-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
239+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
226240        importer .downloadModelDefinition (size , totalParts , null , List .of (streamer ), latchedListener );
227241
228242        latch .await ();
@@ -237,6 +251,8 @@ public void testUploadVocabFailure() throws InterruptedException, URISyntaxExcep
237251            listener .onFailure (new  ElasticsearchStatusException ("put vocab failed" , RestStatus .BAD_REQUEST ));
238252            return  null ;
239253        }).when (client ).execute (eq (PutTrainedModelVocabularyAction .INSTANCE ), any (), any ());
254+         var  cbs  = mock (CircuitBreakerService .class );
255+         when (cbs .getBreaker (eq (CircuitBreaker .REQUEST ))).thenReturn (mock (CircuitBreaker .class ));
240256
241257        var  task  = mock (ModelDownloadTask .class );
242258        var  config  = mockConfigWithRepoLinks ();
@@ -250,7 +266,7 @@ public void testUploadVocabFailure() throws InterruptedException, URISyntaxExcep
250266            latch 
251267        );
252268
253-         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool );
269+         var  importer  = new  ModelImporter (client , "foo" , config , task , threadPool ,  cbs );
254270        importer .downloadModelDefinition (100 , 5 , vocab , List .of (), latchedListener );
255271
256272        latch .await ();
0 commit comments