Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
import org.hamcrest.Matchers;
import org.junit.After;

Expand Down Expand Up @@ -540,7 +536,6 @@ public void testWithDatastream() throws Exception {
);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228")
public void testAliasFields() throws Exception {
// The goal of this test is to assert alias fields are included in the analytics job.
// We have a simple dataset with two integer fields: field_1 and field_2.
Expand Down Expand Up @@ -585,10 +580,32 @@ public void testAliasFields() throws Exception {
// Very infrequently this test may fail as the algorithm underestimates the
// required number of trees for this simple problem. This failure is irrelevant
// for non-trivial real-world problem and improving estimation of the number of trees
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
// and use the hyperparameters that are known to work.
long seed = 1000L; // fix seed

Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null);
Regression regression = new Regression(
"field_2",
BoostedTreeParams.builder()
.setDownsampleFactor(0.7520841625652861)
.setAlpha(547.9095715556235)
.setLambda(3.3008189603590044)
.setGamma(1.6082763366825203)
.setSoftTreeDepthLimit(4.733224114945455)
.setSoftTreeDepthTolerance(0.15)
.setEta(0.12371209659057758)
.setEtaGrowthRatePerTree(1.0618560482952888)
.setMaxTrees(30)
.setFeatureBagFraction(0.8)
.build(),
null,
90.0,
seed,
null,
null,
null,
null
);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap()))
.setDest(new DataFrameAnalyticsDest(destIndex, null))
Expand All @@ -604,19 +621,6 @@ public void testAliasFields() throws Exception {

waitUntilAnalyticsIsStopped(jobId);

// obtain addition information for investigation of #90599
String modelId = getModelId(jobId);
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
}
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
int numberTrees = ensemble.getModels().size();

StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599
assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> {
double predictionErrorSum = 0.0;
for (SearchHit hit : sourceData.getHits()) {
Expand All @@ -629,19 +633,12 @@ public void testAliasFields() throws Exception {
int featureValue = (int) destDoc.get("field_1");
double predictionValue = (double) resultsObject.get(predictionField);
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);

// collect the log of targets and predictions for debugging #90599
targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n");
}
// We assert on the mean prediction error in order to reduce the probability
// the test fails compared to asserting on the prediction of each individual doc.
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n";
assertThat(
Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters,
meanPredictionError,
lessThanOrEqualTo(3.0)
);
assertThat(meanPredictionError, lessThanOrEqualTo(3.0));
});

assertProgressComplete(jobId);
Expand Down