diff --git a/xgboostlss/model.py b/xgboostlss/model.py index 131dd822..baf43820 100644 --- a/xgboostlss/model.py +++ b/xgboostlss/model.py @@ -319,6 +319,7 @@ def hyper_opt( dtrain: DMatrix, num_boost_round=500, nfold=10, + folds=None, early_stopping_rounds=20, max_minutes=10, n_trials=None, @@ -340,6 +341,13 @@ def hyper_opt( Number of boosting iterations. nfold: int Number of folds in CV. + folds: a KFold or StratifiedKFold instance or list of fold indices + Scikit-learn KFold or StratifiedKFold object. + Alternatively, sample indices may be passed explicitly for each fold. + For ``n`` folds, **folds** should be a length ``n`` list of tuples. + Each tuple is ``(in, out)`` where ``in`` is a list of indices to be used + as the training samples for that fold and ``out`` is a list of + indices to be used as the testing samples for that fold. early_stopping_rounds: int Activates early stopping. Cross-Validation metric (average of validation metric computed over CV folds) needs to improve at least once in @@ -417,7 +425,8 @@ def objective(trial): early_stopping_rounds=early_stopping_rounds, callbacks=[pruning_callback], seed=seed, - verbose_eval=False + verbose_eval=False, + folds=folds, ) # Add the optimal number of rounds