|
7 | 7 | package org.elasticsearch.xpack.ml.action;
|
8 | 8 |
|
9 | 9 | import org.elasticsearch.ElasticsearchException;
|
| 10 | +import org.elasticsearch.ElasticsearchSecurityException; |
10 | 11 | import org.elasticsearch.ElasticsearchStatusException;
|
11 | 12 | import org.elasticsearch.ResourceNotFoundException;
|
12 | 13 | import org.elasticsearch.TransportVersion;
|
|
37 | 38 | import org.elasticsearch.license.License;
|
38 | 39 | import org.elasticsearch.license.LicenseUtils;
|
39 | 40 | import org.elasticsearch.license.XPackLicenseState;
|
| 41 | +import org.elasticsearch.rest.RestRequest; |
40 | 42 | import org.elasticsearch.rest.RestStatus;
|
41 | 43 | import org.elasticsearch.search.builder.SearchSourceBuilder;
|
42 | 44 | import org.elasticsearch.tasks.Task;
|
@@ -143,61 +145,7 @@ protected void masterOperation(
|
143 | 145 | // NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
|
144 | 146 | boolean hasModelDefinition = config.getModelDefinition() != null;
|
145 | 147 | if (hasModelDefinition) {
|
146 |
| - try { |
147 |
| - config.getModelDefinition().getTrainedModel().validate(); |
148 |
| - } catch (ElasticsearchException ex) { |
149 |
| - finalResponseListener.onFailure( |
150 |
| - ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId()) |
151 |
| - ); |
152 |
| - return; |
153 |
| - } |
154 |
| - |
155 |
| - TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel()); |
156 |
| - if (trainedModelType == null) { |
157 |
| - finalResponseListener.onFailure( |
158 |
| - ExceptionsHelper.badRequestException( |
159 |
| - "Unknown trained model definition class [{}]", |
160 |
| - config.getModelDefinition().getTrainedModel().getName() |
161 |
| - ) |
162 |
| - ); |
163 |
| - return; |
164 |
| - } |
165 |
| - |
166 |
| - if (config.getModelType() == null) { |
167 |
| - // Set the model type from the definition |
168 |
| - config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build(); |
169 |
| - } else if (trainedModelType != config.getModelType()) { |
170 |
| - finalResponseListener.onFailure( |
171 |
| - ExceptionsHelper.badRequestException( |
172 |
| - "{} [{}] does not match the model definition type [{}]", |
173 |
| - TrainedModelConfig.MODEL_TYPE.getPreferredName(), |
174 |
| - config.getModelType(), |
175 |
| - trainedModelType |
176 |
| - ) |
177 |
| - ); |
178 |
| - return; |
179 |
| - } |
180 |
| - |
181 |
| - if (config.getInferenceConfig().isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) { |
182 |
| - finalResponseListener.onFailure( |
183 |
| - ExceptionsHelper.badRequestException( |
184 |
| - "Model [{}] inference config type [{}] does not support definition target type [{}]", |
185 |
| - config.getModelId(), |
186 |
| - config.getInferenceConfig().getName(), |
187 |
| - config.getModelDefinition().getTrainedModel().targetType() |
188 |
| - ) |
189 |
| - ); |
190 |
| - return; |
191 |
| - } |
192 |
| - |
193 |
| - TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion(); |
194 |
| - if (state.getMinTransportVersion().before(minCompatibilityVersion)) { |
195 |
| - finalResponseListener.onFailure( |
196 |
| - ExceptionsHelper.badRequestException( |
197 |
| - "Cannot create model [{}] while cluster upgrade is in progress.", |
198 |
| - config.getModelId() |
199 |
| - ) |
200 |
| - ); |
| 148 | + if (validateModelDefinition(config, state, licenseState, finalResponseListener) == false) { |
201 | 149 | return;
|
202 | 150 | }
|
203 | 151 | }
|
@@ -507,6 +455,85 @@ private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> li
|
507 | 455 | );
|
508 | 456 | }
|
509 | 457 |
|
| 458 | + public static boolean validateModelDefinition( |
| 459 | + TrainedModelConfig config, |
| 460 | + ClusterState state, |
| 461 | + XPackLicenseState licenseState, |
| 462 | + ActionListener<Response> finalResponseListener |
| 463 | + ) { |
| 464 | + try { |
| 465 | + config.getModelDefinition().getTrainedModel().validate(); |
| 466 | + } catch (ElasticsearchException ex) { |
| 467 | + finalResponseListener.onFailure( |
| 468 | + ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId()) |
| 469 | + ); |
| 470 | + return false; |
| 471 | + } |
| 472 | + |
| 473 | + TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel()); |
| 474 | + if (trainedModelType == null) { |
| 475 | + finalResponseListener.onFailure( |
| 476 | + ExceptionsHelper.badRequestException( |
| 477 | + "Unknown trained model definition class [{}]", |
| 478 | + config.getModelDefinition().getTrainedModel().getName() |
| 479 | + ) |
| 480 | + ); |
| 481 | + return false; |
| 482 | + } |
| 483 | + |
| 484 | + var configModelType = config.getModelType(); |
| 485 | + if (configModelType == null) { |
| 486 | + // Set the model type from the definition |
| 487 | + config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build(); |
| 488 | + } else if (trainedModelType != configModelType) { |
| 489 | + finalResponseListener.onFailure( |
| 490 | + ExceptionsHelper.badRequestException( |
| 491 | + "{} [{}] does not match the model definition type [{}]", |
| 492 | + TrainedModelConfig.MODEL_TYPE.getPreferredName(), |
| 493 | + configModelType, |
| 494 | + trainedModelType |
| 495 | + ) |
| 496 | + ); |
| 497 | + return false; |
| 498 | + } |
| 499 | + |
| 500 | + var inferenceConfig = config.getInferenceConfig(); |
| 501 | + if (inferenceConfig.isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) { |
| 502 | + finalResponseListener.onFailure( |
| 503 | + ExceptionsHelper.badRequestException( |
| 504 | + "Model [{}] inference config type [{}] does not support definition target type [{}]", |
| 505 | + config.getModelId(), |
| 506 | + config.getInferenceConfig().getName(), |
| 507 | + config.getModelDefinition().getTrainedModel().targetType() |
| 508 | + ) |
| 509 | + ); |
| 510 | + return false; |
| 511 | + } |
| 512 | + |
| 513 | + var minLicenseSupported = inferenceConfig.getMinLicenseSupportedForAction(RestRequest.Method.PUT); |
| 514 | + if (licenseState.isAllowedByLicense(minLicenseSupported) == false) { |
| 515 | + finalResponseListener.onFailure( |
| 516 | + new ElasticsearchSecurityException( |
| 517 | + "Model of type [{}] requires [{}] license level", |
| 518 | + RestStatus.FORBIDDEN, |
| 519 | + config.getInferenceConfig().getName(), |
| 520 | + minLicenseSupported |
| 521 | + ) |
| 522 | + ); |
| 523 | + return false; |
| 524 | + } |
| 525 | + |
| 526 | + TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion(); |
| 527 | + if (state.getMinTransportVersion().before(minCompatibilityVersion)) { |
| 528 | + finalResponseListener.onFailure( |
| 529 | + ExceptionsHelper.badRequestException("Cannot create model [{}] while cluster upgrade is in progress.", config.getModelId()) |
| 530 | + ); |
| 531 | + return false; |
| 532 | + } |
| 533 | + |
| 534 | + return true; |
| 535 | + } |
| 536 | + |
510 | 537 | @Override
|
511 | 538 | protected ClusterBlockException checkBlock(Request request, ClusterState state) {
|
512 | 539 | return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
|
|
0 commit comments