Commit 45c38b3

Fix test for datasets submodule (#151)
* 🐛 Fix bug: incorrect value
* ♻️ Add test for fetch_megafon
* 📝 Fix example in fetch_megafon docstring
* 📝 Fix megafon test
* 📝 Fix CI pipeline
* 📝 Fix CI pipeline: install requirements.txt during the Sphinx build
* 📝 Clear test dir
1 parent 32880af · commit 45c38b3

File tree

3 files changed (+84, -36 lines):

.github/workflows/ci-test.yml
sklift/datasets/datasets.py
sklift/tests/test_datasets.py

.github/workflows/ci-test.yml

Lines changed: 5 additions & 3 deletions
@@ -2,6 +2,7 @@ name: Python package
 
 on:
   push:
+    branches: [ master ]
   pull_request_target:
 
 jobs:
@@ -10,12 +11,13 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       USING_COVERAGE_PY: '3.8'
-      USING_COVERAGE_OS: 'ubuntu-latest'
+      USING_COVERAGE_OS: 'macos-latest'
 
     strategy:
       matrix:
         os: ['ubuntu-latest', 'windows-latest', 'macos-latest']
         python-version: ['3.6', '3.7', '3.8', '3.9']
+        platform: 'x64'
       fail-fast: false
 
     steps:
@@ -51,6 +53,6 @@ jobs:
       - name: Update pip
         run: python -m pip install --upgrade pip
       - name: Install dependencies
-        run: pip install -r docs/requirements.txt
+        run: pip install -r docs/requirements.txt -r requirements.txt
      - name: Run Sphinx
-        run: sphinx-build -b html docs /tmp/_docs_build
+        run: sphinx-build -W -b html docs /tmp/_docs_build
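
For context, the updated docs job now installs both the docs and package requirement files and runs Sphinx with -W, which promotes warnings to build failures. A rough local reproduction of that step, assuming it is run from the repository root with pip and Sphinx available, could look like this sketch:

import subprocess

# Install the docs requirements plus the package requirements, mirroring the updated CI step.
subprocess.run(
    ["python", "-m", "pip", "install", "-r", "docs/requirements.txt", "-r", "requirements.txt"],
    check=True,
)

# Build the HTML docs; -W turns Sphinx warnings into errors so problems fail the build.
subprocess.run(
    ["sphinx-build", "-W", "-b", "html", "docs", "/tmp/_docs_build"],
    check=True,
)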

sklift/datasets/datasets.py

Lines changed: 7 additions & 7 deletions
@@ -381,15 +381,15 @@ def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None,
     if treatment_col == 'all':
         treatment_col = treatment_cols
     elif treatment_col not in treatment_cols:
-        raise ValueError(f"treatment_col value must be in {treatment_cols + ['all']}. "
-                         f"Got value {treatment_col}.")
+        raise ValueError(f"The treatment_col must be an element of {treatment_cols + ['all']}. "
+                         f"Got value target_col={treatment_col}.")
 
     target_cols = ['visit', 'conversion']
     if target_col == 'all':
         target_col = target_cols
     elif target_col not in target_cols:
-        raise ValueError(f"target_col value must be from {target_cols + ['all']}. "
-                         f"Got value {target_col}.")
+        raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
+                         f"Got value target_col={target_col}.")
 
     if percent10:
         url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
@@ -494,8 +494,8 @@ def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, downlo
     if target_col == 'all':
         target_col = target_cols
     elif target_col not in target_cols:
-        raise ValueError(f"target_col value must be from {target_cols + ['all']}. "
-                         f"Got value {target_col + ['all']}.")
+        raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
+                         f"Got value target_col={target_col}.")
 
     url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
     filename = url.split('/')[-1]
@@ -566,7 +566,7 @@ def fetch_megafon(data_home=None, dest_subdir=None, download_if_missing=True,
 
 
         dataset = fetch_megafon()
-        data, treatment, target = dataset.data, dataset.treatment, dataset.target
+        data, target, treatment = dataset.data, dataset.target, dataset.treatment
 
         # alternative option
         data, target, treatment = fetch_megafon(return_X_y_t=True)
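
The docstring fix above makes the example unpack the Bunch in the same data, target, treatment order as the tuple returned by return_X_y_t=True. A minimal usage sketch of both access patterns, assuming the MegaFon dataset downloads successfully:

from sklift.datasets import fetch_megafon

# Bunch access, in the corrected data, target, treatment order.
dataset = fetch_megafon()
data, target, treatment = dataset.data, dataset.target, dataset.treatment

# Alternative: request the same three objects directly as a tuple.
data, target, treatment = fetch_megafon(return_X_y_t=True)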

sklift/tests/test_datasets.py

Lines changed: 72 additions & 26 deletions
@@ -3,19 +3,20 @@
 
 from functools import partial
 
-from ..datasets import (fetch_lenta, fetch_x5,
-                        fetch_criteo, fetch_hillstrom)
+from ..datasets import (
+    clear_data_dir,
+    fetch_lenta, fetch_x5,
+    fetch_criteo, fetch_hillstrom,
+    fetch_megafon
+)
 
 
 fetch_criteo10 = partial(fetch_criteo, percent10=True)
 
-
-def check_return_X_y_t(bunch, dataset_func):
-    X_y_t_tuple = dataset_func(return_X_y_t=True)
-    assert isinstance(X_y_t_tuple, tuple)
-    assert X_y_t_tuple[0].shape == bunch.data.shape
-    assert X_y_t_tuple[1].shape == bunch.target.shape
-    assert X_y_t_tuple[2].shape == bunch.treatment.shape
+@pytest.fixture(scope="session", autouse=True)
+def clear():
+    # prepare something ahead of all tests
+    clear_data_dir()
 
 
 @pytest.fixture
@@ -53,20 +54,11 @@ def test_fetch_x5(x5_dataset):
     assert data.treatment.shape == x5_dataset['treatment.shape']
 
 
-@pytest.mark.parametrize(
-    'target_col, target_shape',
-    [('visit', (64_000,)),
-     ('conversion', (64_000,)),
-     ('spend', (64_000,)),
-     ('all', (64_000, 3))]
-)
-def test_fetch_hillstrom(
-        target_col, target_shape
-):
-    data = fetch_hillstrom(target_col=target_col)
-    assert data.data.shape == (64_000, 8)
-    assert data.target.shape == target_shape
-    assert data.treatment.shape == (64_000,)
+@pytest.fixture
+def criteo10_dataset() -> dict:
+    data = {'keys': ['data', 'target', 'treatment', 'DESCR', 'feature_names', 'target_name', 'treatment_name'],
+            'data.shape': (1397960, 12)}
+    return data
 
 
 @pytest.mark.parametrize(
@@ -82,15 +74,69 @@ def test_fetch_hillstrom(
      ('all', (1397960, 2))]
 )
 def test_fetch_criteo10(
-        target_col, target_shape, treatment_col, treatment_shape
+        criteo10_dataset,
+        target_col, target_shape,
+        treatment_col, treatment_shape
 ):
     data = fetch_criteo10(target_col=target_col, treatment_col=treatment_col)
-    assert data.data.shape == (1397960, 12)
+    assert isinstance(data, sklearn.utils.Bunch)
+    assert set(data.keys()) == set(criteo10_dataset['keys'])
+    assert data.data.shape == criteo10_dataset['data.shape']
     assert data.target.shape == target_shape
     assert data.treatment.shape == treatment_shape
 
 
-@pytest.mark.parametrize("fetch_func", [fetch_hillstrom, fetch_criteo10, fetch_lenta])
+@pytest.fixture
+def hillstrom_dataset() -> dict:
+    data = {'keys': ['data', 'target', 'treatment', 'DESCR', 'feature_names', 'target_name', 'treatment_name'],
+            'data.shape': (64000, 8), 'treatment.shape': (64000,)}
+    return data
+
+
+@pytest.mark.parametrize(
+    'target_col, target_shape',
+    [('visit', (64_000,)),
+     ('conversion', (64_000,)),
+     ('spend', (64_000,)),
+     ('all', (64_000, 3))]
+)
+def test_fetch_hillstrom(
+        hillstrom_dataset,
+        target_col, target_shape
+):
+    data = fetch_hillstrom(target_col=target_col)
+    assert isinstance(data, sklearn.utils.Bunch)
+    assert set(data.keys()) == set(hillstrom_dataset['keys'])
+    assert data.data.shape == hillstrom_dataset['data.shape']
+    assert data.target.shape == target_shape
+    assert data.treatment.shape == hillstrom_dataset['treatment.shape']
+
+
+@pytest.fixture
+def megafon_dataset() -> dict:
+    data = {'keys': ['data', 'target', 'treatment', 'DESCR', 'feature_names', 'target_name', 'treatment_name'],
+            'data.shape': (600000, 50), 'target.shape': (600000,), 'treatment.shape': (600000,)}
+    return data
+
+
+def test_fetch_megafon(megafon_dataset):
+    data = fetch_megafon()
+    assert isinstance(data, sklearn.utils.Bunch)
+    assert set(data.keys()) == set(megafon_dataset['keys'])
+    assert data.data.shape == megafon_dataset['data.shape']
+    assert data.target.shape == megafon_dataset['target.shape']
+    assert data.treatment.shape == megafon_dataset['treatment.shape']
+
+
+def check_return_X_y_t(bunch, dataset_func):
+    X_y_t_tuple = dataset_func(return_X_y_t=True)
+    assert isinstance(X_y_t_tuple, tuple)
+    assert X_y_t_tuple[0].shape == bunch.data.shape
+    assert X_y_t_tuple[1].shape == bunch.target.shape
+    assert X_y_t_tuple[2].shape == bunch.treatment.shape
+
+
+@pytest.mark.parametrize("fetch_func", [fetch_hillstrom, fetch_criteo10, fetch_lenta, fetch_megafon])
 def test_return_X_y_t(fetch_func):
     data = fetch_func()
     check_return_X_y_t(data, fetch_func)
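
The new clear fixture relies on pytest's session-scoped, autouse fixtures so the data directory is wiped once before any test in the session runs. A standalone sketch of the same pattern, using a hypothetical stand-in for clear_data_dir and assuming only that pytest is installed:

import pytest


def clear_data_dir():
    # Hypothetical stand-in for sklift.datasets.clear_data_dir, kept local so the sketch is self-contained.
    print("data directory cleared")


@pytest.fixture(scope="session", autouse=True)
def clear():
    # autouse=True applies the fixture to every test without listing it as an argument;
    # scope="session" means this body runs once, before the first test of the session.
    clear_data_dir()


def test_runs_after_cleanup():
    # By the time any test body executes, clear() has already run.
    assert True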
