@@ -423,7 +423,7 @@ def _build_interactions(self, native_ebm):
423 423  
424 424      def _fit_main(self, native_ebm, main_attr_sets):
425 425          log.debug("Train main effects")
426    -         self.current_metric_ = self._cyclic_gradient_boost(
    426+         self.current_metric_, self.main_episode_idx_ = self._cyclic_gradient_boost(
427 427              native_ebm, main_attr_sets, "Main"
428 428          )
429 429          log.debug("Main Metric: {0}".format(self.current_metric_))
@@ -438,11 +438,13 @@ def _fit_main(self, native_ebm, main_attr_sets):
438 438      def staged_fit_interactions(self, X, y, inter_indices=[]):
439 439          check_is_fitted(self, "has_fitted_")
440 440  
441    -         log.debug("Train interactions")
442    - 
    441+         self.inter_episode_idx_ = 0
443 442          if len(inter_indices) == 0:
    443+             log.debug("No interactions to train")
444 444              return self
445 445  
    446+         log.debug("Training interactions")
    447+ 
446 448          # Split data into train/val
447 449          X_train, X_val, y_train, y_val = train_test_split(
448 450              X,
@@ -488,7 +490,7 @@ def staged_fit_interactions(self, X, y, inter_indices=[]):
488 490              )
489 491          ) as native_ebm:
490 492              log.debug("Train interactions")
491    -             self.current_metric_ = self._cyclic_gradient_boost(
    493+             self.current_metric_, self.inter_episode_idx_ = self._cyclic_gradient_boost(
492 494                  native_ebm, inter_attr_sets, "Pair"
493 495              )
494 496              log.debug("Interaction Metric: {0}".format(self.current_metric_))
@@ -513,15 +515,17 @@ def _cyclic_gradient_boost(self, native_ebm, attribute_sets, name=None):
513 515          min_metric = np.inf
514 516          bp_metric = np.inf
515 517          log.debug("Start boosting {0}".format(name))
    518+         curr_episode_index = 0
516 519          for data_episode_index in range(self.data_n_episodes):
    520+             curr_episode_index = data_episode_index
    521+ 
517 522              if data_episode_index % 10 == 0:
518 523                  log.debug("Sweep Index for {0}: {1}".format(name, data_episode_index))
519 524                  log.debug("Metric: {0}".format(curr_metric))
520 525  
521 526              if len(attribute_sets) == 0:
522 527                  log.debug("No sets to boost for {0}".format(name))
523 528  
524    -             log.debug("Start boosting {0}".format(name))
525 529              for index, attribute_set in enumerate(attribute_sets):
526 530                  curr_metric = native_ebm.training_step(
527 531                      index,
@@ -533,6 +537,7 @@ def _cyclic_gradient_boost(self, native_ebm, attribute_sets, name=None):
533 537                      validation_weights=0,
534 538                  )
535 539  
    540+             # NOTE: Out of per-feature boosting on purpose.
536 541              min_metric = min(curr_metric, min_metric)
537 542  
538 543              if no_change_run_length == 0:
@@ -541,12 +546,16 @@ def _cyclic_gradient_boost(self, native_ebm, attribute_sets, name=None):
541 546                  no_change_run_length = 0
542 547              else:
543 548                  no_change_run_length += 1
544    -             if no_change_run_length >= self.early_stopping_run_length:
    549+ 
    550+             if (
    551+                 self.early_stopping_run_length >= 0
    552+                 and no_change_run_length >= self.early_stopping_run_length
    553+             ):
545 554                  log.debug("Early break {0}: {1}".format(name, data_episode_index))
546 555                  break
547 556          log.debug("End boosting {0}".format(name))
548 557  
549    -         return curr_metric
    558+         return curr_metric, curr_episode_index
550 559  
551 560  
552 561  class CoreEBMClassifier(BaseCoreEBM, ClassifierMixin):
@@ -826,6 +835,13 @@ def staged_fit_fn(estimator, X, y, inter_indices=[]):
826 835          self.attribute_set_models_.append(averaged_model)
827 836          self.model_errors_.append(model_errors)
828 837  
    838+         # Get episode indexes for base estimators.
    839+         self.main_episode_idxs_ = []
    840+         self.inter_episode_idxs_ = []
    841+         for estimator in estimators:
    842+             self.main_episode_idxs_.append(estimator.main_episode_idx_)
    843+             self.inter_episode_idxs_.append(estimator.inter_episode_idx_)
    844+ 
829 845          # Extract feature names and feature types.
830 846          self.feature_names = []
831 847          self.feature_types = []
@@ -844,6 +860,8 @@ def staged_fit_fn(estimator, X, y, inter_indices=[]):
844 860              X, self.attribute_sets_, self.attribute_set_models_, []
845 861          )
846 862          self._attrib_set_model_means_ = []
    863+ 
    864+         # TODO: Clean this up before release.
847 865          for set_idx, attribute_set, scores in scores_gen:
848 866              score_mean = np.mean(scores)
849 867  