Skip to content

Commit 7ae4192

Browse files
authored
check state before deleting model or task (#725)
Signed-off-by: Bhavana Goud Ramaram <[email protected]>
1 parent 5904a41 commit 7ae4192

File tree

4 files changed

+282
-37
lines changed

4 files changed

+282
-37
lines changed

plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
package org.opensearch.ml.action.models;
77

8+
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
89
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
910
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
11+
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
12+
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;
1013

1114
import lombok.AccessLevel;
1215
import lombok.experimental.FieldDefaults;
@@ -18,18 +21,26 @@
1821
import org.opensearch.action.ActionRequest;
1922
import org.opensearch.action.delete.DeleteRequest;
2023
import org.opensearch.action.delete.DeleteResponse;
24+
import org.opensearch.action.get.GetRequest;
2125
import org.opensearch.action.support.ActionFilters;
2226
import org.opensearch.action.support.HandledTransportAction;
2327
import org.opensearch.client.Client;
2428
import org.opensearch.common.inject.Inject;
2529
import org.opensearch.common.util.concurrent.ThreadContext;
30+
import org.opensearch.common.xcontent.NamedXContentRegistry;
31+
import org.opensearch.common.xcontent.XContentParser;
2632
import org.opensearch.index.query.TermsQueryBuilder;
2733
import org.opensearch.index.reindex.BulkByScrollResponse;
2834
import org.opensearch.index.reindex.DeleteByQueryAction;
2935
import org.opensearch.index.reindex.DeleteByQueryRequest;
36+
import org.opensearch.ml.common.MLModel;
37+
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
38+
import org.opensearch.ml.common.model.MLModelState;
3039
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
3140
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
41+
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
3242
import org.opensearch.rest.RestStatus;
43+
import org.opensearch.search.fetch.subphase.FetchSourceContext;
3344
import org.opensearch.tasks.Task;
3445
import org.opensearch.transport.TransportService;
3546

@@ -44,36 +55,66 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
4455
static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
4556
static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks";
4657
Client client;
58+
NamedXContentRegistry xContentRegistry;
4759

4860
@Inject
49-
public DeleteModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
61+
public DeleteModelTransportAction(
62+
TransportService transportService,
63+
ActionFilters actionFilters,
64+
Client client,
65+
NamedXContentRegistry xContentRegistry
66+
) {
5067
super(MLModelDeleteAction.NAME, transportService, actionFilters, MLModelDeleteRequest::new);
5168
this.client = client;
69+
this.xContentRegistry = xContentRegistry;
5270
}
5371

5472
@Override
5573
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
5674
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request);
5775
String modelId = mlModelDeleteRequest.getModelId();
58-
59-
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
76+
MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false);
77+
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent());
78+
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext);
6079

6180
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
62-
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
63-
@Override
64-
public void onResponse(DeleteResponse deleteResponse) {
65-
deleteModelChunks(modelId, deleteResponse, actionListener);
66-
}
81+
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
82+
if (r != null && r.isExists()) {
83+
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
84+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
85+
MLModel mlModel = MLModel.parse(parser);
86+
MLModelState mlModelState = mlModel.getModelState();
87+
if (mlModelState.equals(MLModelState.LOADED)
88+
|| mlModelState.equals(MLModelState.LOADING)
89+
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)) {
90+
actionListener
91+
.onFailure(
92+
new Exception("Model cannot be deleted in loading or loaded state. Try unloading first and then delete")
93+
);
94+
} else {
95+
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
96+
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
97+
@Override
98+
public void onResponse(DeleteResponse deleteResponse) {
99+
deleteModelChunks(modelId, deleteResponse, actionListener);
100+
}
67101

68-
@Override
69-
public void onFailure(Exception e) {
70-
log.error("Failed to delete model meta data for model: " + modelId, e);
71-
if (e instanceof ResourceNotFoundException) {
72-
deleteModelChunks(modelId, null, actionListener);
102+
@Override
103+
public void onFailure(Exception e) {
104+
log.error("Failed to delete model meta data for model: " + modelId, e);
105+
if (e instanceof ResourceNotFoundException) {
106+
deleteModelChunks(modelId, null, actionListener);
107+
}
108+
actionListener.onFailure(e);
109+
}
110+
});
111+
}
112+
} catch (Exception e) {
113+
log.error("Failed to parse ml model" + r.getId(), e);
114+
actionListener.onFailure(e);
73115
}
74-
actionListener.onFailure(e);
75116
}
76-
});
117+
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find model")); }), () -> context.restore()));
77118
} catch (Exception e) {
78119
log.error("Failed to delete ML model " + modelId, e);
79120
actionListener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,27 @@
55

66
package org.opensearch.ml.action.tasks;
77

8+
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
89
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
10+
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
911

1012
import lombok.extern.log4j.Log4j2;
1113

1214
import org.opensearch.action.ActionListener;
1315
import org.opensearch.action.ActionRequest;
1416
import org.opensearch.action.delete.DeleteRequest;
1517
import org.opensearch.action.delete.DeleteResponse;
18+
import org.opensearch.action.get.GetRequest;
1619
import org.opensearch.action.support.ActionFilters;
1720
import org.opensearch.action.support.HandledTransportAction;
1821
import org.opensearch.client.Client;
1922
import org.opensearch.common.inject.Inject;
2023
import org.opensearch.common.util.concurrent.ThreadContext;
24+
import org.opensearch.common.xcontent.NamedXContentRegistry;
25+
import org.opensearch.common.xcontent.XContentParser;
26+
import org.opensearch.ml.common.MLTask;
27+
import org.opensearch.ml.common.MLTaskState;
28+
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
2129
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
2230
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
2331
import org.opensearch.tasks.Task;
@@ -28,35 +36,62 @@ public class DeleteTaskTransportAction extends HandledTransportAction<ActionRequ
2836

2937
Client client;
3038

39+
NamedXContentRegistry xContentRegistry;
40+
3141
@Inject
32-
public DeleteTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
42+
public DeleteTaskTransportAction(
43+
TransportService transportService,
44+
ActionFilters actionFilters,
45+
Client client,
46+
NamedXContentRegistry xContentRegistry
47+
) {
3348
super(MLTaskDeleteAction.NAME, transportService, actionFilters, MLTaskDeleteRequest::new);
3449
this.client = client;
50+
this.xContentRegistry = xContentRegistry;
3551
}
3652

3753
@Override
3854
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
3955
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.fromActionRequest(request);
4056
String taskId = mlTaskDeleteRequest.getTaskId();
41-
42-
DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
57+
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);
4358

4459
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
45-
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
46-
@Override
47-
public void onResponse(DeleteResponse deleteResponse) {
48-
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
49-
actionListener.onResponse(deleteResponse);
50-
}
60+
client.get(getRequest, ActionListener.wrap(r -> {
61+
62+
if (r != null && r.isExists()) {
63+
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
64+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
65+
MLTask mlTask = MLTask.parse(parser);
66+
MLTaskState mlTaskState = mlTask.getState();
67+
if (mlTaskState.equals(MLTaskState.RUNNING)) {
68+
actionListener.onFailure(new Exception("Task cannot be deleted in running state. Try after sometime"));
69+
} else {
70+
DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
71+
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
72+
@Override
73+
public void onResponse(DeleteResponse deleteResponse) {
74+
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
75+
actionListener.onResponse(deleteResponse);
76+
}
5177

52-
@Override
53-
public void onFailure(Exception e) {
54-
log.error("Failed to delete ML Task " + taskId, e);
55-
actionListener.onFailure(e);
78+
@Override
79+
public void onFailure(Exception e) {
80+
log.error("Failed to delete ML Task " + taskId, e);
81+
actionListener.onFailure(e);
82+
}
83+
});
84+
}
85+
} catch (Exception e) {
86+
log.error("Failed to parse ML task " + taskId, e);
87+
actionListener.onFailure(e);
88+
}
89+
} else {
90+
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
5691
}
57-
});
92+
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
5893
} catch (Exception e) {
59-
log.error("Failed to delete ML task " + taskId, e);
94+
log.error("Failed to delete ml task " + taskId, e);
6095
actionListener.onFailure(e);
6196
}
6297
}

plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,25 @@
2626
import org.mockito.ArgumentCaptor;
2727
import org.mockito.Mock;
2828
import org.mockito.MockitoAnnotations;
29+
import org.opensearch.ResourceNotFoundException;
2930
import org.opensearch.action.ActionListener;
3031
import org.opensearch.action.bulk.BulkItemResponse;
3132
import org.opensearch.action.delete.DeleteResponse;
33+
import org.opensearch.action.get.GetResponse;
3234
import org.opensearch.action.support.ActionFilters;
3335
import org.opensearch.client.Client;
36+
import org.opensearch.common.bytes.BytesReference;
3437
import org.opensearch.common.settings.Settings;
3538
import org.opensearch.common.util.concurrent.ThreadContext;
39+
import org.opensearch.common.xcontent.NamedXContentRegistry;
40+
import org.opensearch.common.xcontent.ToXContent;
41+
import org.opensearch.common.xcontent.XContentBuilder;
42+
import org.opensearch.common.xcontent.XContentFactory;
43+
import org.opensearch.index.get.GetResult;
3644
import org.opensearch.index.reindex.BulkByScrollResponse;
3745
import org.opensearch.index.reindex.ScrollableHitSource;
46+
import org.opensearch.ml.common.MLModel;
47+
import org.opensearch.ml.common.model.MLModelState;
3848
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
3949
import org.opensearch.test.OpenSearchTestCase;
4050
import org.opensearch.threadpool.ThreadPool;
@@ -62,27 +72,31 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
6272
@Mock
6373
BulkByScrollResponse bulkByScrollResponse;
6474

75+
@Mock
76+
NamedXContentRegistry xContentRegistry;
77+
6578
@Rule
6679
public ExpectedException exceptionRule = ExpectedException.none();
6780

6881
DeleteModelTransportAction deleteModelTransportAction;
6982
MLModelDeleteRequest mlModelDeleteRequest;
7083
ThreadContext threadContext;
84+
MLModel model;
7185

7286
@Before
7387
public void setup() throws IOException {
7488
MockitoAnnotations.openMocks(this);
7589

7690
mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build();
77-
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client));
91+
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client, xContentRegistry));
7892

7993
Settings settings = Settings.builder().build();
8094
threadContext = new ThreadContext(settings);
8195
when(client.threadPool()).thenReturn(threadPool);
8296
when(threadPool.getThreadContext()).thenReturn(threadContext);
8397
}
8498

85-
public void testDeleteModel_Success() {
99+
public void testDeleteModel_Success() throws IOException {
86100
doAnswer(invocation -> {
87101
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
88102
listener.onResponse(deleteResponse);
@@ -96,10 +110,74 @@ public void testDeleteModel_Success() {
96110
return null;
97111
}).when(client).execute(any(), any(), any());
98112

113+
GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
114+
doAnswer(invocation -> {
115+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
116+
actionListener.onResponse(getResponse);
117+
return null;
118+
}).when(client).get(any(), any());
119+
99120
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
100121
verify(actionListener).onResponse(deleteResponse);
101122
}
102123

124+
public void testDeleteModel_CheckModelState() throws IOException {
125+
GetResponse getResponse = prepareMLModel(MLModelState.LOADING);
126+
doAnswer(invocation -> {
127+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
128+
actionListener.onResponse(getResponse);
129+
return null;
130+
}).when(client).get(any(), any());
131+
132+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
133+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
134+
verify(actionListener).onFailure(argumentCaptor.capture());
135+
assertEquals(
136+
"Model cannot be deleted in loading or loaded state. Try unloading first and then delete",
137+
argumentCaptor.getValue().getMessage()
138+
);
139+
}
140+
141+
public void testDeleteModel_ModelNotFoundException() throws IOException {
142+
doAnswer(invocation -> {
143+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
144+
actionListener.onFailure(new Exception());
145+
return null;
146+
}).when(client).get(any(), any());
147+
148+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
149+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
150+
verify(actionListener).onFailure(argumentCaptor.capture());
151+
assertEquals("Fail to find model", argumentCaptor.getValue().getMessage());
152+
}
153+
154+
public void testDeleteModel_ResourceNotFoundException() throws IOException {
155+
doAnswer(invocation -> {
156+
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
157+
listener.onFailure(new ResourceNotFoundException("errorMessage"));
158+
return null;
159+
}).when(client).delete(any(), any());
160+
161+
doAnswer(invocation -> {
162+
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
163+
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
164+
listener.onResponse(response);
165+
return null;
166+
}).when(client).execute(any(), any(), any());
167+
168+
GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
169+
doAnswer(invocation -> {
170+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
171+
actionListener.onResponse(getResponse);
172+
return null;
173+
}).when(client).get(any(), any());
174+
175+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
176+
ArgumentCaptor<ResourceNotFoundException> argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class);
177+
verify(actionListener).onFailure(argumentCaptor.capture());
178+
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
179+
}
180+
103181
public void testDeleteModelChunks_Success() {
104182
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
105183
doAnswer(invocation -> {
@@ -112,7 +190,14 @@ public void testDeleteModelChunks_Success() {
112190
verify(actionListener).onResponse(deleteResponse);
113191
}
114192

115-
public void testDeleteModel_RuntimeException() {
193+
public void testDeleteModel_RuntimeException() throws IOException {
194+
GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
195+
doAnswer(invocation -> {
196+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
197+
actionListener.onResponse(getResponse);
198+
return null;
199+
}).when(client).get(any(), any());
200+
116201
doAnswer(invocation -> {
117202
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
118203
listener.onFailure(new RuntimeException("errorMessage"));
@@ -198,4 +283,13 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
198283
verify(actionListener).onFailure(argumentCaptor.capture());
199284
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
200285
}
286+
287+
public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException {
288+
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build();
289+
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
290+
BytesReference bytesReference = BytesReference.bytes(content);
291+
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
292+
GetResponse getResponse = new GetResponse(getResult);
293+
return getResponse;
294+
}
201295
}

0 commit comments

Comments
 (0)