Commit 2365250

added Instances.cv_splits method to generate list of train/test tuples as used by cross-validation
1 parent 8dd162f commit 2365250

2 files changed: 33 additions and 1 deletion

CHANGES.rst

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ Changelog
 
 - added `logging_level` parameter to the `start` method of the `weka.core.jvm` module, enabling the user
   to turn off debugging output in an easy way (https://github.com/fracpete/python-weka-wrapper3/issues/40)
+- added method `cv_splits` to class `Instances` from module `weka.core.dataset` to return a list of
+  train/test tuples as used by cross-validation
 - ...
 
 

python/weka/core/dataset.py

Lines changed: 31 additions & 1 deletion
@@ -17,7 +17,7 @@
 import javabridge
 import logging
 import numpy as np
-from weka.core.classes import JavaObject
+from weka.core.classes import JavaObject, Random
 import weka.core.typeconv as typeconv
 
 # logging setup
@@ -558,6 +558,36 @@ def train_test_split(self, percentage, rnd=None):
         test_inst = Instances.copy_instances(data, train_size, test_size)
         return train_inst, test_inst
 
+    def cv_splits(self, folds=10, rnd=None, stratify=True):
+        """
+        Generates a list of train/test pairs used in cross-validation.
+        Creates a copy of the dataset beforehand when randomizing.
+
+        :param folds: the number of folds to use, >= 2
+        :type folds: int
+        :param rnd: the random number generator to use for randomization, skips randomization if None
+        :type rnd: Random
+        :param stratify: whether to stratify the data after randomization
+        :type stratify: bool
+        :return: the list of train/test split tuples
+        :rtype: list
+        """
+        result = []
+        if rnd is not None:
+            data = Instances.copy_instances(self)
+            data.randomize(rnd)
+            if stratify:
+                data.stratify(folds)
+        else:
+            data = self
+
+        for i in range(folds):
+            train = data.train_cv(folds, i, random=rnd)
+            test = data.test_cv(folds, i)
+            result.append((train, test))
+
+        return result
+
     @classmethod
     def summary(cls, inst):
         """

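Illustration (not part of the commit): the shape of the list returned by `cv_splits` can be sketched in plain Python on row indices. This is a minimal sketch under stated assumptions, not the wrapper's or Weka's code. It assumes each test fold is a contiguous block of rows and that any remainder is spread over the first folds. The helper name `cv_split_indices` is hypothetical, and Python's `random.Random` stands in for the wrapper's `Random` class. Stratification and any per-fold reshuffling of the training rows are left out.

```python
import random


def cv_split_indices(n, folds=10, rnd=None):
    """Hypothetical helper: return a list of (train, test) index lists, one per fold."""
    if folds < 2 or folds > n:
        raise ValueError("folds must be between 2 and the number of rows")
    order = list(range(n))
    if rnd is not None:
        # Shuffle a private copy of the row order; the caller's data stays untouched.
        rnd.shuffle(order)
    base, extra = divmod(n, folds)
    result = []
    for i in range(folds):
        # Assumption: the first `extra` folds receive one additional row each.
        size = base + (1 if i < extra else 0)
        first = i * base + min(i, extra)
        test = order[first:first + size]
        train = order[:first] + order[first + size:]
        result.append((train, test))
    return result


# 10 rows split into 3 folds: test folds of 4, 3 and 3 rows.
for train, test in cv_split_indices(10, folds=3, rnd=random.Random(1)):
    print(len(train), len(test))
```

As in the committed method, passing `rnd=None` skips the shuffle and splits the rows in their existing order. Each tuple pairs the training rows with the held-out test rows for one fold, and every row appears in exactly one test fold.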