Skip to content

Commit 4d3bb06

Browse files
authored
fix weird edge case in iterative CV (#1121)
1 parent f518e9a commit 4d3bb06

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

autosklearn/evaluation/train_evaluator.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,21 +386,23 @@ def fit_predict_and_loss(self, iterative: bool = False) -> None:
386386

387387
# Compute weights of each fold based on the number of samples in each
388388
# fold.
389-
train_fold_weights = [w / sum(train_fold_weights)
390-
for w in train_fold_weights]
391-
opt_fold_weights = [w / sum(opt_fold_weights)
392-
for w in opt_fold_weights]
389+
train_fold_weights_percentage = [
390+
w / sum(train_fold_weights) for w in train_fold_weights
391+
]
392+
opt_fold_weights_percentage = [
393+
w / sum(opt_fold_weights) for w in opt_fold_weights
394+
]
393395

394396
# train_losses is a list of either scalars or dicts. If it contains
395397
# dicts, then train_loss is computed using the target metric
396398
# (self.metric).
397399
if all(isinstance(elem, dict) for elem in train_losses):
398400
train_loss = np.average([train_losses[i][str(self.metric)]
399401
for i in range(self.num_cv_folds)],
400-
weights=train_fold_weights,
402+
weights=train_fold_weights_percentage,
401403
)
402404
else:
403-
train_loss = np.average(train_losses, weights=train_fold_weights)
405+
train_loss = np.average(train_losses, weights=train_fold_weights_percentage)
404406

405407
# if all_scoring_function is true, return a dict of opt_loss.
406408
# Otherwise, return a scalar.
@@ -412,10 +414,10 @@ def fit_predict_and_loss(self, iterative: bool = False) -> None:
412414
opt_losses[i][metric]
413415
for i in range(self.num_cv_folds)
414416
],
415-
weights=opt_fold_weights,
417+
weights=opt_fold_weights_percentage,
416418
)
417419
else:
418-
opt_loss = np.average(opt_losses, weights=opt_fold_weights)
420+
opt_loss = np.average(opt_losses, weights=opt_fold_weights_percentage)
419421

420422
Y_targets = self.Y_targets
421423
Y_train_targets = self.Y_train_targets

0 commit comments

Comments
 (0)