Skip to content

Commit 8bb2491

Browse files
authored
Merge pull request #2348 from jerneju/various-testlearners
[FIX] Test & Learn: number of folds causes many errors
2 parents 554b885 + ac7c093 commit 8bb2491

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

Orange/evaluation/testing.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,11 @@ def get_fold(self, fold):
228228

229229
return results
230230

231-
def get_augmented_data(self, model_names, include_attrs=True, include_predictions=True, include_probabilities=True):
231+
def get_augmented_data(self, model_names, include_attrs=True, include_predictions=True,
232+
include_probabilities=True):
232233
"""
233-
Return the data, augmented with predictions, probabilities (if the task is classification) and folds info.
234-
Predictions, probabilities and folds are inserted as meta attributes.
234+
Return the data, augmented with predictions, probabilities (if the task is classification)
235+
and folds info. Predictions, probabilities and folds are inserted as meta attributes.
235236
236237
Args:
237238
model_names (list): A list of strings containing learners' names.
@@ -276,7 +277,9 @@ def get_augmented_data(self, model_names, include_attrs=True, include_prediction
276277

277278
# add fold info
278279
if self.folds is not None:
279-
new_meta_attr.append(DiscreteVariable(name="Fold", values=[i+1 for i, s in enumerate(self.folds)]))
280+
new_meta_attr.append(
281+
DiscreteVariable(name="Fold",
282+
values=[str(i+1) for i, _ in enumerate(self.folds)]))
280283
fold = np.empty((len(data), 1))
281284
for i, s in enumerate(self.folds):
282285
fold[s, 0] = i
@@ -356,7 +359,7 @@ def prepare_arrays(self, test_data):
356359
row_indices = []
357360

358361
ptr = 0
359-
for train, test in self.indices:
362+
for _, test in self.indices:
360363
self.folds.append(slice(ptr, ptr + len(test)))
361364
row_indices.append(test)
362365
ptr += len(test)
@@ -459,7 +462,7 @@ class CrossValidationFeature(Results):
459462
460463
"""
461464
def __init__(self, data, learners, feature, store_data=False, store_models=False,
462-
preprocessor=None, callback=None, n_jobs=1):
465+
preprocessor=None, callback=None, n_jobs=1):
463466
self.feature = feature
464467
super().__init__(data, learners=learners, store_data=store_data,
465468
store_models=store_models, preprocessor=preprocessor,
@@ -570,7 +573,7 @@ def __init__(self, data, learners, store_data=False, store_models=False,
570573

571574

572575
def sample(table, n=0.7, stratified=False, replace=False,
573-
random_state=None):
576+
random_state=None):
574577
"""
575578
Samples data instances from a data table. Returns the sample and
576579
a data set from input data table that are not in the sample. Also

0 commit comments

Comments
 (0)