27
27
import org .elasticsearch .xpack .core .ml .dataframe .analyses .BoostedTreeParams ;
28
28
import org .elasticsearch .xpack .core .ml .dataframe .analyses .Regression ;
29
29
import org .elasticsearch .xpack .core .ml .inference .TrainedModelConfig ;
30
- import org .elasticsearch .xpack .core .ml .inference .TrainedModelDefinition ;
31
30
import org .elasticsearch .xpack .core .ml .inference .preprocessing .OneHotEncoding ;
32
31
import org .elasticsearch .xpack .core .ml .inference .preprocessing .PreProcessor ;
33
- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ensemble .Ensemble ;
34
- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .metadata .Hyperparameters ;
35
- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .metadata .TrainedModelMetadata ;
36
32
import org .hamcrest .Matchers ;
37
33
import org .junit .After ;
38
34
@@ -540,7 +536,6 @@ public void testWithDatastream() throws Exception {
540
536
);
541
537
}
542
538
543
- @ AwaitsFix (bugUrl = "https://github.com/elastic/elasticsearch/issues/93228" )
544
539
public void testAliasFields () throws Exception {
545
540
// The goal of this test is to assert alias fields are included in the analytics job.
546
541
// We have a simple dataset with two integer fields: field_1 and field_2.
@@ -585,10 +580,32 @@ public void testAliasFields() throws Exception {
585
580
// Very infrequently this test may fail as the algorithm underestimates the
586
581
// required number of trees for this simple problem. This failure is irrelevant
587
582
// for non-trivial real-world problems, and improving estimation of the number of trees
588
- // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
583
+ // would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
584
+ // and use the hyperparameters that are known to work.
589
585
long seed = 1000L ; // fix seed
590
586
591
- Regression regression = new Regression ("field_2" , BoostedTreeParams .builder ().build (), null , 90.0 , seed , null , null , null , null );
587
+ Regression regression = new Regression (
588
+ "field_2" ,
589
+ BoostedTreeParams .builder ()
590
+ .setDownsampleFactor (0.7520841625652861 )
591
+ .setAlpha (547.9095715556235 )
592
+ .setLambda (3.3008189603590044 )
593
+ .setGamma (1.6082763366825203 )
594
+ .setSoftTreeDepthLimit (4.733224114945455 )
595
+ .setSoftTreeDepthTolerance (0.15 )
596
+ .setEta (0.12371209659057758 )
597
+ .setEtaGrowthRatePerTree (1.0618560482952888 )
598
+ .setMaxTrees (30 )
599
+ .setFeatureBagFraction (0.8 )
600
+ .build (),
601
+ null ,
602
+ 90.0 ,
603
+ seed ,
604
+ null ,
605
+ null ,
606
+ null ,
607
+ null
608
+ );
592
609
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig .Builder ().setId (jobId )
593
610
.setSource (new DataFrameAnalyticsSource (new String [] { sourceIndex }, null , null , Collections .emptyMap ()))
594
611
.setDest (new DataFrameAnalyticsDest (destIndex , null ))
@@ -604,19 +621,6 @@ public void testAliasFields() throws Exception {
604
621
605
622
waitUntilAnalyticsIsStopped (jobId );
606
623
607
- // obtain additional information for investigation of #90599
608
- String modelId = getModelId (jobId );
609
- TrainedModelMetadata modelMetadata = getModelMetadata (modelId );
610
- assertThat (modelMetadata .getHyperparameters ().size (), greaterThan (0 ));
611
- StringBuilder hyperparameters = new StringBuilder (); // used to investigate #90599
612
- for (Hyperparameters hyperparameter : modelMetadata .getHyperparameters ()) {
613
- hyperparameters .append (hyperparameter .hyperparameterName ).append (": " ).append (hyperparameter .value ).append ("\n " );
614
- }
615
- TrainedModelDefinition modelDefinition = getModelDefinition (modelId );
616
- Ensemble ensemble = (Ensemble ) modelDefinition .getTrainedModel ();
617
- int numberTrees = ensemble .getModels ().size ();
618
-
619
- StringBuilder targetsPredictions = new StringBuilder (); // used to investigate #90599
620
624
assertResponse (prepareSearch (sourceIndex ).setSize (totalDocCount ), sourceData -> {
621
625
double predictionErrorSum = 0.0 ;
622
626
for (SearchHit hit : sourceData .getHits ()) {
@@ -629,19 +633,12 @@ public void testAliasFields() throws Exception {
629
633
int featureValue = (int ) destDoc .get ("field_1" );
630
634
double predictionValue = (double ) resultsObject .get (predictionField );
631
635
predictionErrorSum += Math .abs (predictionValue - 2 * featureValue );
632
-
633
- // collect the log of targets and predictions for debugging #90599
634
- targetsPredictions .append (2 * featureValue ).append (", " ).append (predictionValue ).append ("\n " );
635
636
}
636
637
// We assert on the mean prediction error in order to reduce the probability
637
638
// the test fails compared to asserting on the prediction of each individual doc.
638
639
double meanPredictionError = predictionErrorSum / sourceData .getHits ().getHits ().length ;
639
640
String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n " ;
640
- assertThat (
641
- Strings .format (str , seed , modelId , numberTrees ) + targetsPredictions + hyperparameters ,
642
- meanPredictionError ,
643
- lessThanOrEqualTo (3.0 )
644
- );
641
+ assertThat (meanPredictionError , lessThanOrEqualTo (3.0 ));
645
642
});
646
643
647
644
assertProgressComplete (jobId );
0 commit comments