|
15 | 15 | import org.elasticsearch.action.search.SearchResponse; |
16 | 16 | import org.elasticsearch.action.support.WriteRequest; |
17 | 17 | import org.elasticsearch.common.settings.Settings; |
| 18 | +import org.elasticsearch.core.Strings; |
18 | 19 | import org.elasticsearch.search.SearchHit; |
19 | 20 | import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; |
20 | 21 | import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; |
21 | 22 | import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; |
| 23 | +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; |
| 24 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; |
| 25 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters; |
| 26 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; |
22 | 27 | import org.junit.After; |
23 | 28 | import org.junit.Before; |
24 | 29 |
|
@@ -1541,6 +1546,7 @@ public void cleanup() { |
1541 | 1546 | } |
1542 | 1547 |
|
1543 | 1548 | public void testFeatureImportanceValues() throws Exception { |
| 1549 | + String predictionField = TARGET_FIELD + "_prediction"; |
1544 | 1550 | initialize("classification_house_pricing_test_feature_importance_values"); |
1545 | 1551 | indexData(sourceIndex); |
1546 | 1552 | DataFrameAnalyticsConfig config = buildAnalytics( |
@@ -1571,12 +1577,34 @@ public void testFeatureImportanceValues() throws Exception { |
1571 | 1577 |
|
1572 | 1578 | client().admin().indices().refresh(new RefreshRequest(destIndex)); |
1573 | 1579 | SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); |
| 1580 | + |
| 1581 | + // obtain addition information for investigation of #90599 |
| 1582 | + String modelId = getModelId(jobId); |
| 1583 | + TrainedModelMetadata modelMetadata = getModelMetadata(modelId); |
| 1584 | + assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0)); |
| 1585 | + StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90019 |
| 1586 | + for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) { |
| 1587 | + hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n"); |
| 1588 | + } |
| 1589 | + TrainedModelDefinition modelDefinition = getModelDefinition(modelId); |
| 1590 | + Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel(); |
| 1591 | + int numberTrees = ensemble.getModels().size(); |
| 1592 | + String str = "Failure: failed for modelId %s numberTrees %d\n"; |
1574 | 1593 | for (SearchHit hit : sourceData.getHits()) { |
1575 | 1594 | Map<String, Object> destDoc = getDestDoc(config, hit); |
| 1595 | + assertNotNull(destDoc); |
1576 | 1596 | Map<String, Object> resultsObject = getFieldValue(destDoc, "ml"); |
| 1597 | + assertThat(resultsObject.containsKey(predictionField), is(true)); |
| 1598 | + String predictionValue = (String) resultsObject.get(predictionField); |
| 1599 | + assertNotNull(predictionValue); |
| 1600 | + assertThat(resultsObject.containsKey("feature_importance"), is(true)); |
1577 | 1601 | @SuppressWarnings("unchecked") |
1578 | 1602 | List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance"); |
1579 | | - assertThat(importanceArray, hasSize(greaterThan(0))); |
| 1603 | + assertThat( |
| 1604 | + Strings.format(str, modelId, numberTrees) + predictionValue + hyperparameters + modelDefinition, |
| 1605 | + importanceArray, |
| 1606 | + hasSize(greaterThan(0)) |
| 1607 | + ); |
1580 | 1608 | } |
1581 | 1609 |
|
1582 | 1610 | } |
|
0 commit comments