Skip to content

Commit 8ca36a3

Browse files
authored
Merge pull request #1995 from VesnaT/fix_ds
[FIX] OWDataSampler: Fix 'Fixed proportion of data' option
2 parents a9c1d28 + 162b0c6 commit 8ca36a3

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

Orange/widgets/data/owdatasampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def __init__(self, p=0, stratified=False, random_state=None):
377377

378378
def __call__(self, table):
379379
n = int(math.ceil(len(table) * self.p))
380-
return SampleRandomN(n, self.stratified, self.random_state)(table)
380+
return SampleRandomN(n, self.stratified,
381+
random_state=self.random_state)(table)
381382

382383

383384
class SampleBootstrap(Reprable):

Orange/widgets/data/tests/test_owdatasampler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class TestOWDataSampler(WidgetTest):
1010
def setUpClass(cls):
1111
super().setUpClass()
1212
cls.iris = Table("iris")
13+
cls.zoo = Table("zoo")
1314

1415
def setUp(self):
1516
self.widget = self.create_widget(OWDataSampler) # type: OWDataSampler
@@ -59,3 +60,27 @@ def test_bootstrap(self):
5960
def select_sampling_type(self, sampling_type):
6061
buttons = self.widget.controls.sampling_type.group.buttons()
6162
buttons[sampling_type].click()
63+
64+
def test_no_intersection_in_outputs(self):
65+
""" Check whether outputs intersect and whether length of outputs sums
66+
to length of original data """
67+
self.send_signal("Data", self.zoo)
68+
w = self.widget
69+
sampling_types = [w.FixedProportion, w.FixedSize, w.CrossValidation]
70+
71+
for replicable in [True, False]:
72+
for stratified in [True, False]:
73+
for sampling_type in sampling_types:
74+
self.widget.cb_seed.setChecked(replicable)
75+
self.widget.cb_stratify.setChecked(stratified)
76+
self.select_sampling_type(sampling_type)
77+
self.widget.commit()
78+
79+
sample = self.get_output("Data Sample")
80+
other = self.get_output("Remaining Data")
81+
self.assertEqual(len(self.zoo), len(sample) + len(other))
82+
self.assertNoIntersection(sample, other)
83+
84+
def assertNoIntersection(self, sample, other):
85+
for inst in sample:
86+
self.assertNotIn(inst, other)

0 commit comments

Comments
 (0)