From 7728a7fd4d2fc00a4c7dc205a83bfcdc44ef094f Mon Sep 17 00:00:00 2001 From: Valeriy Khakhutskyy <1292899+valeriy42@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:48:26 +0200 Subject: [PATCH] [ML] RegressionIT: Fix hyperparameters for regression tests and unmute the test (#135541) This PR fixes the flaky test muted in #93228 by fixing hyperparameters to the values that always work. Since the test is for alias fields and not for the training algorithm, fixing the hyperparameters is not dangerous. Closes #93228 --- .../xpack/ml/integration/RegressionIT.java | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 8a6499ec3bb6a..a06b85131814d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -27,12 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.hamcrest.Matchers; import org.junit.After; @@ -540,7 +536,6 @@ public void testWithDatastream() throws Exception { ); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228") public void testAliasFields() throws Exception { // The goal of this test is to assert alias fields are included in the analytics job. // We have a simple dataset with two integer fields: field_1 and field_2. @@ -585,10 +580,32 @@ public void testAliasFields() throws Exception { // Very infrequently this test may fail as the algorithm underestimates the // required number of trees for this simple problem. This failure is irrelevant // for non-trivial real-world problem and improving estimation of the number of trees - // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed. + // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed + // and use the hyperparameters that are known to work. long seed = 1000L; // fix seed - Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null); + Regression regression = new Regression( + "field_2", + BoostedTreeParams.builder() + .setDownsampleFactor(0.7520841625652861) + .setAlpha(547.9095715556235) + .setLambda(3.3008189603590044) + .setGamma(1.6082763366825203) + .setSoftTreeDepthLimit(4.733224114945455) + .setSoftTreeDepthTolerance(0.15) + .setEta(0.12371209659057758) + .setEtaGrowthRatePerTree(1.0618560482952888) + .setMaxTrees(30) + .setFeatureBagFraction(0.8) + .build(), + null, + 90.0, + seed, + null, + null, + null, + null + ); DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId) .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap())) .setDest(new DataFrameAnalyticsDest(destIndex, null)) @@ -604,19 +621,6 @@ public void testAliasFields() throws Exception { waitUntilAnalyticsIsStopped(jobId); - // obtain addition information for investigation of #90599 - String modelId = getModelId(jobId); - TrainedModelMetadata modelMetadata = getModelMetadata(modelId); - assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0)); - StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599 - for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) { - hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n"); - } - TrainedModelDefinition modelDefinition = getModelDefinition(modelId); - Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel(); - int numberTrees = ensemble.getModels().size(); - - StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599 assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> { double predictionErrorSum = 0.0; for (SearchHit hit : sourceData.getHits()) { @@ -629,19 +633,12 @@ public void testAliasFields() throws Exception { int featureValue = (int) destDoc.get("field_1"); double predictionValue = (double) resultsObject.get(predictionField); predictionErrorSum += Math.abs(predictionValue - 2 * featureValue); - - // collect the log of targets and predictions for debugging #90599 - targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n"); } // We assert on the mean prediction error in order to reduce the probability // the test fails compared to asserting on the prediction of each individual doc. double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length; String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n"; - assertThat( - Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters, - meanPredictionError, - lessThanOrEqualTo(3.0) - ); + assertThat(meanPredictionError, lessThanOrEqualTo(3.0)); }); assertProgressComplete(jobId);