|
10 | 10 | import static org.opensearch.ml.permission.AccessController.getUserContext; |
11 | 11 | import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; |
12 | 12 | import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT; |
| 13 | +import static org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT; |
| 14 | +import static org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT; |
| 15 | +import static org.opensearch.ml.stats.StatNames.failureCountStat; |
| 16 | +import static org.opensearch.ml.stats.StatNames.requestCountStat; |
13 | 17 |
|
14 | 18 | import java.time.Instant; |
15 | 19 | import java.util.Base64; |
|
46 | 50 | import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; |
47 | 51 | import org.opensearch.ml.engine.MLEngine; |
48 | 52 | import org.opensearch.ml.indices.MLInputDatasetHandler; |
| 53 | +import org.opensearch.ml.stats.ActionName; |
49 | 54 | import org.opensearch.ml.stats.MLStats; |
50 | 55 | import org.opensearch.threadpool.ThreadPool; |
51 | 56 | import org.opensearch.transport.TransportService; |
@@ -146,86 +151,74 @@ private void predict( |
146 | 151 | ) { |
147 | 152 | ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); |
148 | 153 | // track ML task count and add ML task into cache |
149 | | - mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment(); |
| 154 | + mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment(); |
| 155 | + mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment(); |
| 156 | + mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment(); |
150 | 157 | mlTaskManager.add(mlTask); |
151 | | - MLInput mlInput = request.getMlInput(); |
152 | 158 |
|
153 | 159 | // run predict |
154 | | - try { |
| 160 | + if (request.getModelId() != null) { |
155 | 161 | // search model by model id. |
156 | | - Model model = new Model(); |
157 | | - if (request.getModelId() != null) { |
| 162 | + try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { |
| 163 | + MLInput mlInput = request.getMlInput(); |
| 164 | + ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> { |
| 165 | + if (r == null || !r.isExists()) { |
| 166 | + internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); |
| 167 | + return; |
| 168 | + } |
| 169 | + Map<String, Object> source = r.getSourceAsMap(); |
| 170 | + User requestUser = getUserContext(client); |
| 171 | + User resourceUser = User.parse((String) source.get(USER)); |
| 172 | + if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { |
| 173 | + // The backend roles of request user and resource user doesn't have intersection |
| 174 | + OpenSearchException e = new OpenSearchException( |
| 175 | + "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId() |
| 176 | + ); |
| 177 | + log.debug(e); |
| 178 | + handlePredictFailure(mlTask, internalListener, e, false); |
| 179 | + return; |
| 180 | + } |
| 181 | + |
| 182 | + Model model = new Model(); |
| 183 | + model.setName((String) source.get(MLModel.MODEL_NAME)); |
| 184 | + model.setVersion((Integer) source.get(MLModel.MODEL_VERSION)); |
| 185 | + byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT)); |
| 186 | + model.setContent(decoded); |
| 187 | + |
| 188 | + // run predict |
| 189 | + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); |
| 190 | + MLOutput output = MLEngine |
| 191 | + .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); |
| 192 | + if (output instanceof MLPredictionOutput) { |
| 193 | + ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); |
| 194 | + } |
| 195 | + |
| 196 | + // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state |
| 197 | + handleAsyncMLTaskComplete(mlTask); |
| 198 | + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); |
| 199 | + internalListener.onResponse(response); |
| 200 | + }, e -> { |
| 201 | + log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); |
| 202 | + handlePredictFailure(mlTask, internalListener, e, true); |
| 203 | + }); |
158 | 204 | GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); |
159 | | - try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { |
160 | | - ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> { |
161 | | - if (r == null || !r.isExists()) { |
162 | | - internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); |
163 | | - return; |
164 | | - } |
165 | | - Map<String, Object> source = r.getSourceAsMap(); |
166 | | - User requestUser = getUserContext(client); |
167 | | - User resourceUser = User.parse((String) source.get(USER)); |
168 | | - if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { |
169 | | - // The backend roles of request user and resource user doesn't have intersection |
170 | | - OpenSearchException e = new OpenSearchException( |
171 | | - "User: " |
172 | | - + requestUser.getName() |
173 | | - + " does not have permissions to run predict by model: " |
174 | | - + request.getModelId() |
175 | | - ); |
176 | | - log.debug(e); |
177 | | - handlePredictFailure(mlTask, internalListener, e); |
178 | | - return; |
179 | | - } |
180 | | - |
181 | | - model.setName((String) source.get(MLModel.MODEL_NAME)); |
182 | | - model.setVersion((Integer) source.get(MLModel.MODEL_VERSION)); |
183 | | - byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT)); |
184 | | - model.setContent(decoded); |
185 | | - |
186 | | - // run predict |
187 | | - MLOutput output; |
188 | | - try { |
189 | | - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); |
190 | | - output = MLEngine |
191 | | - .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); |
192 | | - if (output instanceof MLPredictionOutput) { |
193 | | - ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); |
194 | | - } |
195 | | - |
196 | | - // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state |
197 | | - handleAsyncMLTaskComplete(mlTask); |
198 | | - } catch (Exception e) { |
199 | | - // todo need to specify what exception |
200 | | - log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + model.getName(), e); |
201 | | - handlePredictFailure(mlTask, internalListener, e); |
202 | | - return; |
203 | | - } |
204 | | - |
205 | | - MLTaskResponse response = MLTaskResponse.builder().output(output).build(); |
206 | | - internalListener.onResponse(response); |
207 | | - }, e -> { |
208 | | - log.error("Failed to predict model " + mlTask.getModelId(), e); |
209 | | - internalListener.onFailure(e); |
210 | | - }); |
211 | | - client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore())); |
212 | | - } catch (Exception e) { |
213 | | - log.error("Failed to get model " + mlTask.getModelId(), e); |
214 | | - internalListener.onFailure(e); |
215 | | - } |
216 | | - } else { |
217 | | - IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid"); |
218 | | - log.error("ModelId is invalid", e); |
219 | | - handlePredictFailure(mlTask, internalListener, e); |
220 | | - return; |
| 205 | + client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore())); |
| 206 | + } catch (Exception e) { |
| 207 | + log.error("Failed to get model " + mlTask.getModelId(), e); |
| 208 | + handlePredictFailure(mlTask, internalListener, e, true); |
221 | 209 | } |
222 | | - } catch (Exception e) { |
223 | | - log.error("Failed to predict " + mlInput.getAlgorithm(), e); |
224 | | - internalListener.onFailure(e); |
| 210 | + } else { |
| 211 | + IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid"); |
| 212 | + log.error("ModelId is invalid", e); |
| 213 | + handlePredictFailure(mlTask, internalListener, e, false); |
225 | 214 | } |
226 | 215 | } |
227 | 216 |
|
// handlePredictFailure: common failure path for a predict task.
// The removed (-) signature took only the task, the listener and the exception; the added (+)
// signature takes an extra trackFailure flag.
// When trackFailure is true, two counters are incremented before the failure is reported:
//   - the stat named by failureCountStat(mlTask.getFunctionName(), ActionName.PREDICT),
//     obtained through mlStats.createCounterStatIfAbsent(...)
//   - the ML_TOTAL_FAILURE_COUNT stat
// In both versions the exception is then passed to handleAsyncMLTaskFailure(mlTask, e) and to
// listener.onFailure(e), in that order.
// Callers visible in this diff pass trackFailure=false for the permission-denied and
// "ModelId is invalid" cases, and trackFailure=true for model lookup / predict exceptions.
// NOTE(review): the "No model found" branch in predict calls internalListener.onFailure directly
// and does not go through this method — confirm that skipping handleAsyncMLTaskFailure there is intended.
228 | | - private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e) { |
| 217 | + private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e, boolean trackFailure) { |
| 218 | + if (trackFailure) { |
| 219 | + mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment(); |
| 220 | + mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment(); |
| 221 | + } |
229 | 222 | handleAsyncMLTaskFailure(mlTask, e); |
230 | 223 | listener.onFailure(e); |
231 | 224 | } |
|
0 commit comments