Skip to content

Commit e49602c

Browse files
committed
DataSampler: Fix Bootstrap signature
1 parent 4dad547 commit e49602c

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

Orange/widgets/data/owdatasampler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class OWDataSampler(OWWidget):
4747
selectedFold = Setting(1)
4848

4949
class Warning(OWWidget.Warning):
50-
could_not_stratify = Msg("Could not stratify\n{}")
50+
could_not_stratify = Msg("Stratification failed\n{}")
5151

5252
class Error(OWWidget.Error):
5353
too_many_folds = Msg("Number of folds exceeds data size")
@@ -385,7 +385,14 @@ def __init__(self, size=0, random_state=None):
385385
self.size = size
386386
self.random_state = random_state
387387

388-
def __call__(self):
388+
def __call__(self, table=None):
389+
"""Bootstrap indices
390+
391+
Args:
392+
table: Not used (but part of the signature)
393+
Returns:
394+
tuple (out_of_sample, sample) indices
395+
"""
389396
rgen = np.random.RandomState(self.random_state)
390397
sample = rgen.randint(0, self.size, self.size)
391398
sample.sort() # not needed for the code below, just for the user

Orange/widgets/data/tests/test_owdatasampler.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def setUpClass(cls):
1212
cls.iris = Table("iris")
1313

1414
def setUp(self):
15-
self.widget = self.create_widget(OWDataSampler)
15+
self.widget = self.create_widget(OWDataSampler) # type: OWDataSampler
1616

1717
def test_error_message(self):
1818
""" Check if error message appears and then disappears when
@@ -33,3 +33,29 @@ def test_stratified_on_unbalanced_data(self):
3333
self.widget.controls.stratify.setChecked(True)
3434
self.send_signal("Data", unbalanced_data)
3535
self.assertTrue(self.widget.Warning.could_not_stratify.is_shown())
36+
37+
def test_bootstrap(self):
38+
self.select_sampling_type(self.widget.Bootstrap)
39+
40+
self.send_signal("Data", self.iris)
41+
42+
in_input = set(self.iris.ids)
43+
sample = self.get_output("Data Sample")
44+
in_sample = set(sample.ids)
45+
in_remaining = set(self.get_output("Remaining Data").ids)
46+
47+
# Bootstrap should sample len(input) instances
48+
self.assertEqual(len(sample), len(self.iris))
49+
# Sample and remaining should cover all instances, while none
50+
# should be present in both
51+
self.assertEqual(len(in_sample | in_remaining), len(in_input))
52+
self.assertEqual(len(in_sample & in_remaining), 0)
53+
# Sampling with replacement will always produce at least one distinct
54+
# instance in sample, and at least one instance in remaining with
55+
# high probability (1-(1/150*2/150*...*145/150) ~= 1-2e-64)
56+
self.assertGreater(len(in_sample), 0)
57+
self.assertGreater(len(in_remaining), 0)
58+
59+
def select_sampling_type(self, sampling_type):
60+
buttons = self.widget.controls.sampling_type.group.buttons()
61+
buttons[sampling_type].click()

0 commit comments

Comments
 (0)