|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.action.models; |
7 | 7 |
|
8 | | -import org.junit.Before; |
9 | | -import org.junit.Ignore; |
10 | | -import org.junit.Rule; |
11 | | -import org.junit.rules.ExpectedException; |
| 8 | +import org.opensearch.OpenSearchTimeoutException; |
12 | 9 | import org.opensearch.action.ActionRequestValidationException; |
13 | 10 | import org.opensearch.ml.action.MLCommonsIntegTestCase; |
14 | | -import org.opensearch.ml.common.MLModel; |
15 | 11 | import org.opensearch.ml.common.exception.MLResourceNotFoundException; |
16 | 12 | import org.opensearch.test.OpenSearchIntegTestCase; |
17 | 13 |
|
18 | 14 | @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2) |
19 | 15 | public class GetModelITTests extends MLCommonsIntegTestCase { |
20 | | - private String irisIndexName; |
21 | 16 |
|
22 | | - @Rule |
23 | | - public ExpectedException exceptionRule = ExpectedException.none(); |
| 17 | + private static final int MAX_RETRIES = 3; |
24 | 18 |
|
25 | | - @Before |
26 | | - public void setUp() throws Exception { |
27 | | - super.setUp(); |
28 | | - irisIndexName = "iris_data_for_model_it"; |
29 | | - loadIrisData(irisIndexName); |
30 | | - } |
31 | | - |
32 | | - @Ignore |
33 | 19 | public void testGetModel_IndexNotFound() { |
34 | | - exceptionRule.expect(MLResourceNotFoundException.class); |
35 | | - MLModel model = getModel("test_id"); |
| 20 | + testGetModelExceptionsWithRetry(MLResourceNotFoundException.class, "test_id"); |
36 | 21 | } |
37 | 22 |
|
38 | 23 | public void testGetModel_NullModelIdException() { |
39 | | - exceptionRule.expect(ActionRequestValidationException.class); |
40 | | - MLModel model = getModel(null); |
| 24 | + testGetModelExceptionsWithRetry(ActionRequestValidationException.class, null); |
| 25 | + } |
| 26 | + |
| 27 | + private void testGetModelExceptionsWithRetry(Class<? extends Exception> expectedException, String modelId) { |
| 28 | + assertThrows(expectedException, () -> { |
| 29 | + for (int retryAttempt = 1; retryAttempt <= MAX_RETRIES; retryAttempt++) { |
| 30 | + try { |
| 31 | + getModel(modelId); |
| 32 | + return; |
| 33 | + } catch (OpenSearchTimeoutException e) { |
| 34 | + logger.info("GetModelITTests attempt: {}", retryAttempt); |
| 35 | + |
| 36 | + if (retryAttempt == MAX_RETRIES) { |
| 37 | + logger.error("Failed to execute test GetModelITTests after {} retries due to timeout", MAX_RETRIES); |
| 38 | + throw e; |
| 39 | + } |
| 40 | + |
| 41 | + // adding small delay between retries |
| 42 | + try { |
| 43 | + Thread.sleep(1000); |
| 44 | + } catch (InterruptedException ie) { |
| 45 | + Thread.currentThread().interrupt(); |
| 46 | + throw new RuntimeException("Thread was interrupted during retry", ie); |
| 47 | + } |
| 48 | + } |
| 49 | + } |
| 50 | + }); |
41 | 51 | } |
42 | 52 | } |
0 commit comments