Skip to content

Commit bb7e148

Browse files
authored
[ML] RegressionIT: Fix hyperparameters for regression tests and unmute the test (#135541) (#135770)
This PR fixes the flaky test muted in #93228 by pinning the hyperparameters to values that are known to work. Since the test targets alias fields rather than the training algorithm, pinning the hyperparameters does not weaken what the test verifies. Closes #93228.
1 parent 0e3b3f3 commit bb7e148

File tree

1 file changed

+25
-28
lines changed
  • x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration

1 file changed

+25
-28
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@
2727
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
2828
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
2929
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
30-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
3130
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
3231
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
33-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
34-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
35-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
3632
import org.hamcrest.Matchers;
3733
import org.junit.After;
3834

@@ -540,7 +536,6 @@ public void testWithDatastream() throws Exception {
540536
);
541537
}
542538

543-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228")
544539
public void testAliasFields() throws Exception {
545540
// The goal of this test is to assert alias fields are included in the analytics job.
546541
// We have a simple dataset with two integer fields: field_1 and field_2.
@@ -585,10 +580,32 @@ public void testAliasFields() throws Exception {
585580
// Very infrequently this test may fail as the algorithm underestimates the
586581
// required number of trees for this simple problem. This failure is irrelevant
587582
// for non-trivial real-world problems and improving estimation of the number of trees
588-
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
583+
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
584+
// and use the hyperparameters that are known to work.
589585
long seed = 1000L; // fix seed
590586

591-
Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null);
587+
Regression regression = new Regression(
588+
"field_2",
589+
BoostedTreeParams.builder()
590+
.setDownsampleFactor(0.7520841625652861)
591+
.setAlpha(547.9095715556235)
592+
.setLambda(3.3008189603590044)
593+
.setGamma(1.6082763366825203)
594+
.setSoftTreeDepthLimit(4.733224114945455)
595+
.setSoftTreeDepthTolerance(0.15)
596+
.setEta(0.12371209659057758)
597+
.setEtaGrowthRatePerTree(1.0618560482952888)
598+
.setMaxTrees(30)
599+
.setFeatureBagFraction(0.8)
600+
.build(),
601+
null,
602+
90.0,
603+
seed,
604+
null,
605+
null,
606+
null,
607+
null
608+
);
592609
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId)
593610
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap()))
594611
.setDest(new DataFrameAnalyticsDest(destIndex, null))
@@ -604,19 +621,6 @@ public void testAliasFields() throws Exception {
604621

605622
waitUntilAnalyticsIsStopped(jobId);
606623

607-
// obtain addition information for investigation of #90599
608-
String modelId = getModelId(jobId);
609-
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
610-
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
611-
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599
612-
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
613-
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
614-
}
615-
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
616-
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
617-
int numberTrees = ensemble.getModels().size();
618-
619-
StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599
620624
assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> {
621625
double predictionErrorSum = 0.0;
622626
for (SearchHit hit : sourceData.getHits()) {
@@ -629,19 +633,12 @@ public void testAliasFields() throws Exception {
629633
int featureValue = (int) destDoc.get("field_1");
630634
double predictionValue = (double) resultsObject.get(predictionField);
631635
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
632-
633-
// collect the log of targets and predictions for debugging #90599
634-
targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n");
635636
}
636637
// We assert on the mean prediction error in order to reduce the probability
637638
// the test fails compared to asserting on the prediction of each individual doc.
638639
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
639640
String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n";
640-
assertThat(
641-
Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters,
642-
meanPredictionError,
643-
lessThanOrEqualTo(3.0)
644-
);
641+
assertThat(meanPredictionError, lessThanOrEqualTo(3.0));
645642
});
646643

647644
assertProgressComplete(jobId);

0 commit comments

Comments
 (0)