Skip to content

Commit 28e81b2

Browse files
authored
Merge pull request #29 from tqtg/master
Add tests for base_strategy and ratio
2 parents d361d42 + 023bb81 commit 28e81b2

File tree

6 files changed

+102
-11
lines changed

6 files changed

+102
-11
lines changed

cornac/eval_strategies/ratio_split.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313

1414
class RatioSplit(BaseStrategy):
15-
1615
"""Train-Test Split Evaluation Strategy.
1716
1817
Parameters
@@ -48,15 +47,15 @@ class RatioSplit(BaseStrategy):
4847
Output running log
4948
"""
5049

51-
def __init__(self, data, data_format='UIR', val_size=0.0, test_size=0.2, rating_threshold=1., shuffle=True, random_state=None,
52-
exclude_unknowns=False, verbose=False):
53-
BaseStrategy.__init__(self, data = data, data_format='UIR', rating_threshold=rating_threshold, exclude_unknowns=exclude_unknowns, verbose=verbose)
50+
def __init__(self, data, data_format='UIR', val_size=0.0, test_size=0.2, rating_threshold=1., shuffle=True,
51+
random_state=None, exclude_unknowns=False, verbose=False):
52+
BaseStrategy.__init__(self, data=data, data_format=data_format, rating_threshold=rating_threshold,
53+
exclude_unknowns=exclude_unknowns, verbose=verbose)
5454

5555
self._shuffle = shuffle
5656
self._random_state = random_state
5757
self._train_size, self._val_size, self._test_size = self._validate_sizes(val_size, test_size, len(self._data))
58-
self._split_run = False
59-
58+
self._split_ran = False
6059

6160

6261
@staticmethod
@@ -93,6 +92,11 @@ def _validate_sizes(val_size, test_size, num_ratings):
9392

9493

9594
def split(self):
95+
if self._split_ran:
96+
if self.verbose:
97+
print('Data is already split!')
98+
return
99+
96100
if self.verbose:
97101
print("Splitting the data")
98102

@@ -114,15 +118,13 @@ def split(self):
114118
if self._data_format == 'UIR':
115119
self.build_from_uir_format(train_data, val_data, test_data)
116120

117-
self._split_run = True
121+
self._split_ran = True
118122

119123
if self.verbose:
120124
print('Total users = {}'.format(self.total_users))
121125
print('Total items = {}'.format(self.total_items))
122126

123127

124128
def evaluate(self, model, metrics, user_based):
125-
if not self._split_run:
126-
self.split()
127-
129+
self.split()
128130
return BaseStrategy.evaluate(self, model, metrics, user_based)

cornac/eval_strategies/tests/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
5+
"""
6+
7+
from ..base_strategy import BaseStrategy
8+
9+
10+
def test_init():
11+
bs = BaseStrategy(None, verbose=True)
12+
13+
assert not bs.exclude_unknowns
14+
assert 1. == bs.rating_threshold
15+
16+
17+
def test_trainset_none():
18+
bs = BaseStrategy(None, verbose=True)
19+
20+
try:
21+
bs.evaluate(None, {}, False)
22+
except ValueError:
23+
assert True
24+
25+
26+
def test_testset_none():
27+
bs = BaseStrategy(None, train_set=[], verbose=True)
28+
29+
try:
30+
bs.evaluate(None, {}, False)
31+
except ValueError:
32+
assert True
33+
34+
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
5+
"""
6+
7+
from ..ratio_split import RatioSplit
8+
9+
10+
def test_validate_size():
11+
train_size, val_size, test_size = RatioSplit._validate_sizes(0.1, 0.2, 10)
12+
assert 7 == train_size
13+
assert 1 == val_size
14+
assert 2 == test_size
15+
16+
train_size, val_size, test_size = RatioSplit._validate_sizes(None, 0.5, 10)
17+
assert 5 == train_size
18+
assert 0 == val_size
19+
assert 5 == test_size
20+
21+
train_size, val_size, test_size = RatioSplit._validate_sizes(None, None, 10)
22+
assert 10 == train_size
23+
assert 0 == val_size
24+
assert 0 == test_size
25+
26+
train_size, val_size, test_size = RatioSplit._validate_sizes(2, 2, 10)
27+
assert 6 == train_size
28+
assert 2 == val_size
29+
assert 2 == test_size
30+
31+
try:
32+
RatioSplit._validate_sizes(-1, 0.2, 10)
33+
except ValueError:
34+
assert True
35+
36+
try:
37+
RatioSplit._validate_sizes(11, 0.2, 10)
38+
except ValueError:
39+
assert True
40+
41+
try:
42+
RatioSplit._validate_sizes(0, 11, 10)
43+
except ValueError:
44+
assert True
45+
46+
try:
47+
RatioSplit._validate_sizes(3, 8, 10)
48+
except ValueError:
49+
assert True

cornac/experiment/tests/test_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_with_ratio_split():
1313
from ..experiment import Experiment
1414

1515
data = reader.txt_to_uir_triplets('./cornac/data/tests/data.txt')
16-
exp = Experiment(eval_strategy=RatioSplit(data),
16+
exp = Experiment(eval_strategy=RatioSplit(data, verbose=True),
1717
models=[PMF(1, 0)],
1818
metrics=[MAE(), RMSE(), Recall(1), FMeasure(1)],
1919
verbose=True)

cornac/utils/generic_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def safe_indexing(X, indices):
8787

8888

8989
def validate_data_format(data_format):
90+
"""Check the input data format is supported or not
91+
- UIR: (user, item, rating) triplet data
92+
- UIRT: (user, item , rating, timestamp) quadruplet data
93+
94+
:raise ValueError if not supported
95+
"""
9096
data_format = str(data_format).upper()
9197
if not data_format in ['UIR', 'UIRT']:
9298
raise ValueError('{} data format is not supported!'.format(data_format))

0 commit comments

Comments
 (0)