33
44from functools import partial
55
6- from ..datasets import (fetch_lenta , fetch_x5 ,
7- fetch_criteo , fetch_hillstrom )
6+ from ..datasets import (
7+ clear_data_dir ,
8+ fetch_lenta , fetch_x5 ,
9+ fetch_criteo , fetch_hillstrom ,
10+ fetch_megafon
11+ )
812
913
1014fetch_criteo10 = partial (fetch_criteo , percent10 = True )
1115
12-
13- def check_return_X_y_t (bunch , dataset_func ):
14- X_y_t_tuple = dataset_func (return_X_y_t = True )
15- assert isinstance (X_y_t_tuple , tuple )
16- assert X_y_t_tuple [0 ].shape == bunch .data .shape
17- assert X_y_t_tuple [1 ].shape == bunch .target .shape
18- assert X_y_t_tuple [2 ].shape == bunch .treatment .shape
16+ @pytest .fixture (scope = "session" , autouse = True )
17+ def clear ():
18+ # prepare something ahead of all tests
19+ clear_data_dir ()
1920
2021
2122@pytest .fixture
@@ -53,20 +54,11 @@ def test_fetch_x5(x5_dataset):
5354 assert data .treatment .shape == x5_dataset ['treatment.shape' ]
5455
5556
56- @pytest .mark .parametrize (
57- 'target_col, target_shape' ,
58- [('visit' , (64_000 ,)),
59- ('conversion' , (64_000 ,)),
60- ('spend' , (64_000 ,)),
61- ('all' , (64_000 , 3 ))]
62- )
63- def test_fetch_hillstrom (
64- target_col , target_shape
65- ):
66- data = fetch_hillstrom (target_col = target_col )
67- assert data .data .shape == (64_000 , 8 )
68- assert data .target .shape == target_shape
69- assert data .treatment .shape == (64_000 ,)
57+ @pytest .fixture
58+ def criteo10_dataset () -> dict :
59+ data = {'keys' : ['data' , 'target' , 'treatment' , 'DESCR' , 'feature_names' , 'target_name' , 'treatment_name' ],
60+ 'data.shape' : (1397960 , 12 )}
61+ return data
7062
7163
7264@pytest .mark .parametrize (
@@ -82,15 +74,69 @@ def test_fetch_hillstrom(
8274 ('all' , (1397960 , 2 ))]
8375)
8476def test_fetch_criteo10 (
85- target_col , target_shape , treatment_col , treatment_shape
77+ criteo10_dataset ,
78+ target_col , target_shape ,
79+ treatment_col , treatment_shape
8680):
8781 data = fetch_criteo10 (target_col = target_col , treatment_col = treatment_col )
88- assert data .data .shape == (1397960 , 12 )
82+ assert isinstance (data , sklearn .utils .Bunch )
83+ assert set (data .keys ()) == set (criteo10_dataset ['keys' ])
84+ assert data .data .shape == criteo10_dataset ['data.shape' ]
8985 assert data .target .shape == target_shape
9086 assert data .treatment .shape == treatment_shape
9187
9288
93- @pytest .mark .parametrize ("fetch_func" , [fetch_hillstrom , fetch_criteo10 , fetch_lenta ])
89+ @pytest .fixture
90+ def hillstrom_dataset () -> dict :
91+ data = {'keys' : ['data' , 'target' , 'treatment' , 'DESCR' , 'feature_names' , 'target_name' , 'treatment_name' ],
92+ 'data.shape' : (64000 , 8 ), 'treatment.shape' : (64000 ,)}
93+ return data
94+
95+
96+ @pytest .mark .parametrize (
97+ 'target_col, target_shape' ,
98+ [('visit' , (64_000 ,)),
99+ ('conversion' , (64_000 ,)),
100+ ('spend' , (64_000 ,)),
101+ ('all' , (64_000 , 3 ))]
102+ )
103+ def test_fetch_hillstrom (
104+ hillstrom_dataset ,
105+ target_col , target_shape
106+ ):
107+ data = fetch_hillstrom (target_col = target_col )
108+ assert isinstance (data , sklearn .utils .Bunch )
109+ assert set (data .keys ()) == set (hillstrom_dataset ['keys' ])
110+ assert data .data .shape == hillstrom_dataset ['data.shape' ]
111+ assert data .target .shape == target_shape
112+ assert data .treatment .shape == hillstrom_dataset ['treatment.shape' ]
113+
114+
115+ @pytest .fixture
116+ def megafon_dataset () -> dict :
117+ data = {'keys' : ['data' , 'target' , 'treatment' , 'DESCR' , 'feature_names' , 'target_name' , 'treatment_name' ],
118+ 'data.shape' : (600000 , 50 ), 'target.shape' : (600000 ,), 'treatment.shape' : (600000 ,)}
119+ return data
120+
121+
122+ def test_fetch_megafon (megafon_dataset ):
123+ data = fetch_megafon ()
124+ assert isinstance (data , sklearn .utils .Bunch )
125+ assert set (data .keys ()) == set (megafon_dataset ['keys' ])
126+ assert data .data .shape == megafon_dataset ['data.shape' ]
127+ assert data .target .shape == megafon_dataset ['target.shape' ]
128+ assert data .treatment .shape == megafon_dataset ['treatment.shape' ]
129+
130+
131+ def check_return_X_y_t (bunch , dataset_func ):
132+ X_y_t_tuple = dataset_func (return_X_y_t = True )
133+ assert isinstance (X_y_t_tuple , tuple )
134+ assert X_y_t_tuple [0 ].shape == bunch .data .shape
135+ assert X_y_t_tuple [1 ].shape == bunch .target .shape
136+ assert X_y_t_tuple [2 ].shape == bunch .treatment .shape
137+
138+
139+ @pytest .mark .parametrize ("fetch_func" , [fetch_hillstrom , fetch_criteo10 , fetch_lenta , fetch_megafon ])
94140def test_return_X_y_t (fetch_func ):
95141 data = fetch_func ()
96142 check_return_X_y_t (data , fetch_func )
0 commit comments