|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +package org.opensearch.ml.model; |
| 6 | + |
| 7 | +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; |
| 8 | +import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX; |
| 9 | +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; |
| 10 | +import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD; |
| 11 | +import static org.opensearch.ml.common.CommonValue.NODE_ID_FIELD; |
| 12 | +import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; |
| 13 | +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE; |
| 14 | +import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES; |
| 15 | +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; |
| 16 | + |
| 17 | +import java.util.Arrays; |
| 18 | +import java.util.HashMap; |
| 19 | +import java.util.List; |
| 20 | +import java.util.Map; |
| 21 | +import java.util.concurrent.ExecutionException; |
| 22 | +import java.util.stream.Collectors; |
| 23 | + |
| 24 | +import lombok.extern.log4j.Log4j2; |
| 25 | + |
| 26 | +import org.opensearch.action.ActionListener; |
| 27 | +import org.opensearch.action.StepListener; |
| 28 | +import org.opensearch.action.index.IndexAction; |
| 29 | +import org.opensearch.action.index.IndexRequestBuilder; |
| 30 | +import org.opensearch.action.index.IndexResponse; |
| 31 | +import org.opensearch.action.search.SearchAction; |
| 32 | +import org.opensearch.action.search.SearchRequestBuilder; |
| 33 | +import org.opensearch.action.search.SearchResponse; |
| 34 | +import org.opensearch.action.support.WriteRequest; |
| 35 | +import org.opensearch.client.Client; |
| 36 | +import org.opensearch.cluster.node.DiscoveryNode; |
| 37 | +import org.opensearch.cluster.service.ClusterService; |
| 38 | +import org.opensearch.common.settings.Settings; |
| 39 | +import org.opensearch.common.util.CollectionUtils; |
| 40 | +import org.opensearch.common.xcontent.NamedXContentRegistry; |
| 41 | +import org.opensearch.common.xcontent.XContentParser; |
| 42 | +import org.opensearch.index.query.QueryBuilder; |
| 43 | +import org.opensearch.index.query.QueryBuilders; |
| 44 | +import org.opensearch.ml.cluster.DiscoveryNodeHelper; |
| 45 | +import org.opensearch.ml.common.MLTask; |
| 46 | +import org.opensearch.ml.common.exception.MLException; |
| 47 | +import org.opensearch.ml.common.transport.load.MLLoadModelAction; |
| 48 | +import org.opensearch.ml.common.transport.load.MLLoadModelRequest; |
| 49 | +import org.opensearch.ml.utils.MLNodeUtils; |
| 50 | +import org.opensearch.rest.RestStatus; |
| 51 | +import org.opensearch.search.SearchHit; |
| 52 | +import org.opensearch.search.builder.SearchSourceBuilder; |
| 53 | +import org.opensearch.search.sort.FieldSortBuilder; |
| 54 | +import org.opensearch.search.sort.SortBuilder; |
| 55 | +import org.opensearch.search.sort.SortOrder; |
| 56 | +import org.opensearch.threadpool.ThreadPool; |
| 57 | + |
| 58 | +import com.google.common.annotations.VisibleForTesting; |
| 59 | + |
| 60 | +/** |
| 61 | + * Manager class for ML models and nodes. It contains ML model auto reload operations etc. |
| 62 | + */ |
| 63 | +@Log4j2 |
| 64 | +public class MLModelAutoReloader { |
| 65 | + |
| 66 | + private final Client client; |
| 67 | + private final ClusterService clusterService; |
| 68 | + private final NamedXContentRegistry xContentRegistry; |
| 69 | + private final DiscoveryNodeHelper nodeHelper; |
| 70 | + private final ThreadPool threadPool; |
| 71 | + private volatile Boolean enableAutoReloadModel; |
| 72 | + private volatile Integer autoReloadMaxRetryTimes; |
| 73 | + |
| 74 | + /** |
| 75 | + * constructor method, init all the params necessary for model auto reloading |
| 76 | + * |
| 77 | + * @param clusterService clusterService |
| 78 | + * @param threadPool threadPool |
| 79 | + * @param client client |
| 80 | + * @param xContentRegistry xContentRegistry |
| 81 | + * @param nodeHelper nodeHelper |
| 82 | + * @param settings settings |
| 83 | + */ |
| 84 | + public MLModelAutoReloader( |
| 85 | + ClusterService clusterService, |
| 86 | + ThreadPool threadPool, |
| 87 | + Client client, |
| 88 | + NamedXContentRegistry xContentRegistry, |
| 89 | + DiscoveryNodeHelper nodeHelper, |
| 90 | + Settings settings |
| 91 | + ) { |
| 92 | + this.clusterService = clusterService; |
| 93 | + this.client = client; |
| 94 | + this.xContentRegistry = xContentRegistry; |
| 95 | + this.nodeHelper = nodeHelper; |
| 96 | + this.threadPool = threadPool; |
| 97 | + |
| 98 | + enableAutoReloadModel = ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.get(settings); |
| 99 | + autoReloadMaxRetryTimes = ML_MODEL_RELOAD_MAX_RETRY_TIMES.get(settings); |
| 100 | + clusterService |
| 101 | + .getClusterSettings() |
| 102 | + .addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, it -> enableAutoReloadModel = it); |
| 103 | + |
| 104 | + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_MODEL_RELOAD_MAX_RETRY_TIMES, it -> autoReloadMaxRetryTimes = it); |
| 105 | + } |
| 106 | + |
| 107 | + /** |
| 108 | + * the main method: model auto reloading |
| 109 | + */ |
| 110 | + public void autoReloadModel() { |
| 111 | + log.info("auto reload model enabled: {} ", enableAutoReloadModel); |
| 112 | + |
| 113 | + // if we don't need to reload automatically, just return without doing anything |
| 114 | + if (!enableAutoReloadModel) { |
| 115 | + return; |
| 116 | + } |
| 117 | + |
| 118 | + // At opensearch startup, get local node id, if not ml node,we ignored, just return without doing anything |
| 119 | + if (!MLNodeUtils.isMLNode(clusterService.localNode())) { |
| 120 | + return; |
| 121 | + } |
| 122 | + |
| 123 | + String localNodeId = clusterService.localNode().getId(); |
| 124 | + // auto reload all models of this local ml node |
| 125 | + threadPool.executor(LOAD_THREAD_POOL).execute(() -> { |
| 126 | + try { |
| 127 | + autoReloadModelByNodeId(localNodeId); |
| 128 | + } catch (ExecutionException | InterruptedException e) { |
| 129 | + log.error("the model auto-reloading has exception,and the root cause message is: {}", e); |
| 130 | + throw new MLException(e); |
| 131 | + } |
| 132 | + }); |
| 133 | + } |
| 134 | + |
| 135 | + /** |
| 136 | + * auto reload all the models under the node id<br> the node must be a ml node<br> |
| 137 | + * |
| 138 | + * @param localNodeId node id |
| 139 | + */ |
| 140 | + @VisibleForTesting |
| 141 | + void autoReloadModelByNodeId(String localNodeId) throws ExecutionException, InterruptedException { |
| 142 | + StepListener<SearchResponse> queryTaskStep = new StepListener<>(); |
| 143 | + StepListener<SearchResponse> getRetryTimesStep = new StepListener<>(); |
| 144 | + StepListener<IndexResponse> saveLatestRetryTimesStep = new StepListener<>(); |
| 145 | + |
| 146 | + if (!clusterService.state().metadata().indices().containsKey(ML_TASK_INDEX)) { |
| 147 | + // ML_TASK_INDEX did not exist,do nothing |
| 148 | + return; |
| 149 | + } |
| 150 | + |
| 151 | + queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); |
| 152 | + |
| 153 | + getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure)); |
| 154 | + |
| 155 | + queryTaskStep.whenComplete(searchResponse -> { |
| 156 | + SearchHit[] hits = searchResponse.getHits().getHits(); |
| 157 | + if (CollectionUtils.isEmpty(hits)) { |
| 158 | + return; |
| 159 | + } |
| 160 | + |
| 161 | + getRetryTimesStep.whenComplete(getReTryTimesResponse -> { |
| 162 | + int retryTimes = 0; |
| 163 | + // if getReTryTimesResponse is null,it means we get retryTimes at the first time,and the index |
| 164 | + // .plugins-ml-model-reload doesn't exist,so we should let retryTimes be zero(init value) |
| 165 | + // we don't do anything |
| 166 | + // if getReTryTimesResponse is not null,it means we have saved the value of retryTimes into the index |
| 167 | + // .plugins-ml-model-reload,so we get the value of the field MODEL_LOAD_RETRY_TIMES_FIELD |
| 168 | + if (getReTryTimesResponse != null) { |
| 169 | + Map<String, Object> sourceAsMap = getReTryTimesResponse.getHits().getHits()[0].getSourceAsMap(); |
| 170 | + retryTimes = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD); |
| 171 | + } |
| 172 | + |
| 173 | + // According to the node id to get retry times, if more than the max retry times, don't need to retry |
| 174 | + // that the number of unsuccessful reload has reached the maximum number of times, do not need to reload |
| 175 | + if (retryTimes > autoReloadMaxRetryTimes) { |
| 176 | + log.info("Node: {} has reached to the max retry limit, failed to load models", localNodeId); |
| 177 | + return; |
| 178 | + } |
| 179 | + |
| 180 | + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, hits[0].getSourceRef())) { |
| 181 | + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); |
| 182 | + MLTask mlTask = MLTask.parse(parser); |
| 183 | + |
| 184 | + autoReloadModelByNodeAndModelId(localNodeId, mlTask.getModelId()); |
| 185 | + |
| 186 | + // if reload the model successfully,the number of unsuccessful reload should be reset to zero. |
| 187 | + retryTimes = 0; |
| 188 | + } catch (MLException e) { |
| 189 | + retryTimes++; |
| 190 | + log.error("Can't auto reload model in node id {} ,has tried {} times\nThe reason is:{}", localNodeId, retryTimes, e); |
| 191 | + } |
| 192 | + |
| 193 | + // Store the latest value of the retryTimes and node id under the index ".plugins-ml-model-reload" |
| 194 | + saveLatestRetryTimes( |
| 195 | + localNodeId, |
| 196 | + retryTimes, |
| 197 | + ActionListener.wrap(saveLatestRetryTimesStep::onResponse, saveLatestRetryTimesStep::onFailure) |
| 198 | + ); |
| 199 | + }, getRetryTimesStep::onFailure); |
| 200 | + }, queryTaskStep::onFailure); |
| 201 | + |
| 202 | + saveLatestRetryTimesStep.whenComplete(response -> log.info("successfully complete all steps"), saveLatestRetryTimesStep::onFailure); |
| 203 | + } |
| 204 | + |
| 205 | + /** |
| 206 | + * auto reload 1 model under the node id |
| 207 | + * |
| 208 | + * @param localNodeId node id |
| 209 | + * @param modelId model id |
| 210 | + */ |
| 211 | + @VisibleForTesting |
| 212 | + void autoReloadModelByNodeAndModelId(String localNodeId, String modelId) throws MLException { |
| 213 | + List<String> allMLNodeIdList = Arrays |
| 214 | + .stream(nodeHelper.getAllNodes()) |
| 215 | + .filter(MLNodeUtils::isMLNode) |
| 216 | + .map(DiscoveryNode::getId) |
| 217 | + .collect(Collectors.toList()); |
| 218 | + |
| 219 | + if (!allMLNodeIdList.contains(localNodeId)) { |
| 220 | + allMLNodeIdList.add(localNodeId); |
| 221 | + } |
| 222 | + MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest(modelId, allMLNodeIdList.toArray(new String[] {}), false, false); |
| 223 | + |
| 224 | + client |
| 225 | + .execute( |
| 226 | + MLLoadModelAction.INSTANCE, |
| 227 | + mlLoadModelRequest, |
| 228 | + ActionListener |
| 229 | + .wrap(response -> log.info("the model {} is auto reloading under the node {} ", modelId, localNodeId), exception -> { |
| 230 | + log.error("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception); |
| 231 | + throw new MLException( |
| 232 | + "fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception |
| 233 | + ); |
| 234 | + }) |
| 235 | + ); |
| 236 | + } |
| 237 | + |
| 238 | + /** |
| 239 | + * query task index, and get the result of "task_type"="LOAD_MODEL" and "state"="COMPLETED" and |
| 240 | + * "worker_node" match nodeId |
| 241 | + * |
| 242 | + * @param localNodeId one of query condition |
| 243 | + */ |
| 244 | + @VisibleForTesting |
| 245 | + void queryTask(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { |
| 246 | + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().from(0).size(1); |
| 247 | + |
| 248 | + QueryBuilder queryBuilder = QueryBuilders |
| 249 | + .boolQuery() |
| 250 | + .must(QueryBuilders.matchPhraseQuery("task_type", "LOAD_MODEL")) |
| 251 | + .must(QueryBuilders.matchPhraseQuery("worker_node", localNodeId)) |
| 252 | + .must( |
| 253 | + QueryBuilders |
| 254 | + .boolQuery() |
| 255 | + .should(QueryBuilders.matchPhraseQuery("state", "COMPLETED")) |
| 256 | + .should(QueryBuilders.matchPhraseQuery("state", "COMPLETED_WITH_ERROR")) |
| 257 | + ); |
| 258 | + searchSourceBuilder.query(queryBuilder); |
| 259 | + |
| 260 | + SortBuilder<FieldSortBuilder> sortBuilderOrder = new FieldSortBuilder("create_time").order(SortOrder.DESC); |
| 261 | + searchSourceBuilder.sort(sortBuilderOrder); |
| 262 | + |
| 263 | + SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) |
| 264 | + .setIndices(ML_TASK_INDEX) |
| 265 | + .setSource(searchSourceBuilder); |
| 266 | + |
| 267 | + searchRequestBuilder.execute(ActionListener.wrap(searchResponseActionListener::onResponse, exception -> { |
| 268 | + log.error("index {} not found, the reason is {}", ML_TASK_INDEX, exception); |
| 269 | + throw new MLException("index " + ML_TASK_INDEX + " not found"); |
| 270 | + })); |
| 271 | + } |
| 272 | + |
| 273 | + /** |
| 274 | + * get retry times from the index ".plugins-ml-model-reload" by 1 ml node |
| 275 | + * |
| 276 | + * @param localNodeId the filter condition to query |
| 277 | + */ |
| 278 | + @VisibleForTesting |
| 279 | + void getRetryTimes(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { |
| 280 | + if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_RELOAD_INDEX)) { |
| 281 | + // ML_MODEL_RELOAD_INDEX did not exist, it means it is our first time to do model auto-reloading operation |
| 282 | + searchResponseActionListener.onResponse(null); |
| 283 | + } |
| 284 | + |
| 285 | + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); |
| 286 | + searchSourceBuilder.fetchSource(new String[] { MODEL_LOAD_RETRY_TIMES_FIELD }, null); |
| 287 | + QueryBuilder queryBuilder = QueryBuilders.idsQuery().addIds(localNodeId); |
| 288 | + searchSourceBuilder.query(queryBuilder); |
| 289 | + |
| 290 | + SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) |
| 291 | + .setIndices(ML_MODEL_RELOAD_INDEX) |
| 292 | + .setSource(searchSourceBuilder); |
| 293 | + |
| 294 | + searchRequestBuilder.execute(ActionListener.wrap(searchResponse -> { |
| 295 | + SearchHit[] hits = searchResponse.getHits().getHits(); |
| 296 | + if (CollectionUtils.isEmpty(hits)) { |
| 297 | + searchResponseActionListener.onResponse(null); |
| 298 | + return; |
| 299 | + } |
| 300 | + |
| 301 | + searchResponseActionListener.onResponse(searchResponse); |
| 302 | + }, searchResponseActionListener::onFailure)); |
| 303 | + } |
| 304 | + |
| 305 | + /** |
| 306 | + * save retry times |
| 307 | + * @param localNodeId node id |
| 308 | + * @param retryTimes actual retry times |
| 309 | + */ |
| 310 | + @VisibleForTesting |
| 311 | + void saveLatestRetryTimes(String localNodeId, int retryTimes, ActionListener<IndexResponse> indexResponseActionListener) { |
| 312 | + Map<String, Object> content = new HashMap<>(2); |
| 313 | + content.put(NODE_ID_FIELD, localNodeId); |
| 314 | + content.put(MODEL_LOAD_RETRY_TIMES_FIELD, retryTimes); |
| 315 | + |
| 316 | + IndexRequestBuilder indexRequestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_RELOAD_INDEX) |
| 317 | + .setId(localNodeId) |
| 318 | + .setSource(content) |
| 319 | + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); |
| 320 | + |
| 321 | + indexRequestBuilder.execute(ActionListener.wrap(indexResponse -> { |
| 322 | + if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) { |
| 323 | + log.info("node id:{} insert retry times successfully", localNodeId); |
| 324 | + indexResponseActionListener.onResponse(indexResponse); |
| 325 | + } |
| 326 | + }, e -> { |
| 327 | + log.error("node id:" + localNodeId + " insert retry times unsuccessfully", e); |
| 328 | + indexResponseActionListener.onFailure(new MLException(e)); |
| 329 | + })); |
| 330 | + } |
| 331 | +} |
0 commit comments