Skip to content

Commit 76a8756

Browse files
committed
Merge branch 'dev' of https://github.com/maks-sh/scikit-uplift into dev
2 parents d06e3e5 + 2a6c54b commit 76a8756

File tree

10 files changed

+120
-82
lines changed

10 files changed

+120
-82
lines changed

docs/api/datasets/create_data_dir.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/api/datasets/download.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/api/datasets/fetch_criteo.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
***********************************
1+
**************************************
22
`sklift.datasets <./>`_.fetch_criteo
3-
***********************************
3+
**************************************
44

5-
.. autofunction:: sklift.datasets.datasets.fetch_criteo
5+
.. autofunction:: sklift.datasets.datasets.fetch_criteo
6+
7+
.. include:: ../../../sklift/datasets/descr/criteo.rst

docs/api/datasets/fetch_hillstorm.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
****************************************
2+
`sklift.datasets <./>`_.fetch_hillstrom
3+
****************************************
4+
5+
.. autofunction:: sklift.datasets.datasets.fetch_hillstrom
6+
7+
.. include:: ../../../sklift/datasets/descr/hillstrom.rst

docs/api/datasets/fetch_lenta.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
***********************************
44

55
.. autofunction:: sklift.datasets.datasets.fetch_lenta
6+
67
.. include:: ../../../sklift/datasets/descr/lenta.rst

docs/api/datasets/fetch_x5.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22
`sklift.datasets <./>`_.fetch_x5
33
***********************************
44

5-
.. autofunction:: sklift.datasets.datasets.fetch_x5
5+
.. autofunction:: sklift.datasets.datasets.fetch_x5
6+
7+
.. include:: ../../../sklift/datasets/descr/x5.rst

docs/api/datasets/get_data.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/api/datasets/index.rst

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
:maxdepth: 3
77

88
./clear_data_dir
9-
./create_data_dir
10-
./download
119
./get_data_dir
12-
./get_data
1310
./fetch_lenta
1411
./fetch_x5
1512
./fetch_criteo
16-
./fetch_hillstorm
13+
./fetch_hillstrom

sklift/datasets/datasets.py

Lines changed: 103 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def get_data_dir():
1515
return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")
1616

1717

18-
def create_data_dir(path):
18+
def _create_data_dir(path):
1919
"""This function creates a directory, which stores the datasets.
2020
2121
Args:
@@ -26,7 +26,7 @@ def create_data_dir(path):
2626
os.makedirs(path)
2727

2828

29-
def download(url, dest_path):
29+
def _download(url, dest_path):
3030
'''Download the file from url and save it localy
3131
3232
Args:
@@ -47,7 +47,7 @@ def download(url, dest_path):
4747
raise TypeError("URL must be a string")
4848

4949

50-
def get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
50+
def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
5151
"""Return the path to the dataset.
5252
5353
Args:
@@ -71,13 +71,13 @@ def get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
7171
else:
7272
data_dir = os.path.join(os.path.abspath(data_home), dest_subdir)
7373

74-
create_data_dir(data_dir)
74+
_create_data_dir(data_dir)
7575

7676
dest_path = os.path.join(data_dir, dest_filename)
7777

7878
if not os.path.isfile(dest_path):
7979
if download_if_missing:
80-
download(url, dest_path)
80+
_download(url, dest_path)
8181
else:
8282
raise IOError("Dataset missing")
8383
return dest_path
@@ -95,15 +95,16 @@ def clear_data_dir(path=None):
9595
shutil.rmtree(path, ignore_errors=True)
9696

9797

98-
def fetch_lenta(return_X_y_t=False, data_home=None, dest_subdir=None, download_if_missing=True):
99-
'''Fetch the Lenta dataset.
98+
def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, return_X_y_t=False, as_frame=False):
99+
"""Fetch the Lenta dataset.
100100
101101
Args:
102-
return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
103-
See below for more information about the data and target object.
104102
data_home (str, unicode): The path to the folder where datasets are stored.
105103
dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
106104
download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.
105+
return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
106+
See below for more information about the data and target object.
107+
as_frame (bool):
107108
108109
Returns:
109110
* dataset ('~sklearn.utils.Bunch'): Dictionary-like object, with the following attributes.
@@ -113,69 +114,101 @@ def fetch_lenta(return_X_y_t=False, data_home=None, dest_subdir=None, download_i
113114
* DESCR (str): Description of the Lenta dataset.
114115
115116
* (data,target,treatment): tuple if 'return_X_y_t' is True.
116-
'''
117-
url='https:/winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
118-
filename='lentadataset.csv.gz'
119-
csv_path=get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
120-
dest_filename=filename,
121-
download_if_missing=download_if_missing)
117+
"""
118+
119+
url = 'https://winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
120+
filename = 'lentadataset.csv.gz'
121+
csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
122+
dest_filename=filename,
123+
download_if_missing=download_if_missing)
122124
data = pd.read_csv(csv_path)
123-
target=data['response_att']
124-
treatment=data['group']
125-
data=data.drop(['response_att', 'group'], axis=1)
125+
if as_frame:
126+
target=data['response_att']
127+
treatment=data['group']
128+
data=data.drop(['response_att', 'group'], axis=1)
129+
feature_names = list(data.columns)
130+
else:
131+
target = data[['response_att']].to_numpy()
132+
treatment = data[['group']].to_numpy()
133+
data = data.drop(['response_att', 'group'], axis=1)
134+
feature_names = list(data.columns)
135+
data = data.to_numpy()
126136

127137
module_path = os.path.dirname(__file__)
128138
with open(os.path.join(module_path, 'descr', 'lenta.rst')) as rst_file:
129139
fdescr = rst_file.read()
130140

131-
if return_X_y_t == True:
141+
if return_X_y_t:
132142
return data, target, treatment
133143

134-
return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr)
144+
return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
145+
feature_names=feature_names, target_name='response_att', treatment_name='group')
135146

136147

137-
def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
148+
def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True, as_frame=False):
138149
"""Fetch the X5 dataset.
139150
140-
Args:
141-
'~sklearn.utils.Bunch': dataset
151+
Args:
152+
data_home (string): Specify a download and cache folder for the datasets.
153+
dest_subdir (string, unicode): The name of the folder in which the dataset is stored.
154+
download_if_missing (bool, default=True): If False, raise an IOError if the data is not locally available
155+
instead of trying to download the data from the source site.
156+
as_frame (bool, default=False):
157+
158+
Returns:
159+
'~sklearn.utils.Bunch': dataset
142160
Dictionary-like object, with the following attributes.
143-
data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
144-
target (Series object): Column target by values
145-
treatment (Series object): Column treatment by values
146-
DESCR (str): Description of the X5 dataset.
147-
train (DataFrame object): Dataset with target and treatment.
161+
data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
162+
target (Series object): Column target by values
163+
treatment (Series object): Column treatment by values
164+
DESCR (str): Description of the X5 dataset.
165+
train (DataFrame object): Dataset with target and treatment.
166+
data_names ('~sklearn.utils.Bunch'): Names of features.
167+
treatment_name (string): The name of the treatment column.
148168
"""
149169
url_clients = 'https://timds.s3.eu-central-1.amazonaws.com/clients.csv.gz'
150170
file_clients = 'clients.csv.gz'
151-
csv_clients_path = get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
171+
csv_clients_path = _get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
152172
dest_filename=file_clients,
153173
download_if_missing=download_if_missing)
154174
clients = pd.read_csv(csv_clients_path)
175+
clients_names = list(clients.column)
155176

156177
url_train = 'https://timds.s3.eu-central-1.amazonaws.com/uplift_train.csv.gz'
157178
file_train = 'uplift_train.csv.gz'
158-
csv_train_path = get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
179+
csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
159180
dest_filename=file_train,
160181
download_if_missing=download_if_missing)
161182
train = pd.read_csv(csv_train_path)
183+
train_names = list(train.columns)
162184

163185
url_purchases = 'https://timds.s3.eu-central-1.amazonaws.com/purchases.csv.gz'
164186
file_purchases = 'purchases.csv.gz'
165-
csv_purchases_path = get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
187+
csv_purchases_path = _get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
166188
dest_filename=file_purchases,
167189
download_if_missing=download_if_missing)
168190
purchases = pd.read_csv(csv_purchases_path)
191+
purchases_names = list(purchases.columns)
169192

170-
target = train['target']
171-
treatment = train['treatment_flg']
193+
if as_frame:
194+
target = train['target']
195+
treatment = train['treatment_flg']
196+
else:
197+
target = train[['target']].to_numpy()
198+
treatment = train[['treatment_flg']].to_numpy()
199+
train = train.to_numpy()
200+
clients = clients.to_numpy()
201+
purchases = purchases.to_numpy()
172202

173203
module_path = os.path.dirname(__file__)
174204
with open(os.path.join(module_path, 'descr', 'x5.rst')) as rst_file:
175205
fdescr = rst_file.read()
176206

177207
return Bunch(data=Bunch(clients=clients, train=train, purchases=purchases),
178-
target=target, treatment=treatment, DESCR=fdescr)
208+
target=target, treatment=treatment, DESCR=fdescr,
209+
data_names=Bunch(clients_names=clients_names, train_names=train_names,
210+
purchases_names=purchases_names),
211+
treatment_name='treatment_flg')
179212

180213

181214
def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, percent10=True,
@@ -209,14 +242,14 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
209242
"""
210243
if percent10:
211244
url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
212-
csv_path = get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
245+
csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
213246
dest_filename='criteo10.csv.gz',
214247
download_if_missing=download_if_missing)
215248
else:
216249
url = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
217-
csv_path = get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
218-
dest_filename='criteo.csv.gz',
219-
download_if_missing=download_if_missing)
250+
csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
251+
dest_filename='criteo.csv.gz',
252+
download_if_missing=download_if_missing)
220253

221254
if treatment_feature == 'exposure':
222255
data = pd.read_csv(csv_path, usecols=[i for i in range(12)])
@@ -264,22 +297,21 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
264297
feature_names=feature_names, target_name=target_name, treatment_name=treatment_name)
265298

266299

267-
def fetch_hillstrom(target='visit',
268-
data_home=None,
269-
dest_subdir=None,
270-
download_if_missing=True,
271-
return_X_y=False):
300+
def fetch_hillstrom(data_home=None, dest_subdir=None, download_if_missing=True, target_column='visit',
301+
return_X_y_t=False, as_frame=False):
272302
"""Load the hillstrom dataset.
273303
274304
Args:
275-
target : str, desfault=visit.
276-
Can also be conversion, and spend
277305
data_home : str, default=None
278306
Specify another download and cache folder for the datasets.
279307
dest_subdir : str, default=None
280308
download_if_missing : bool, default=True
281309
If False, raise a IOError if the data is not locally available
282310
instead of trying to download the data from the source site.
311+
target_column (string, 'visit' or 'conversion' or 'spend', default='visit'): Selects which column from dataset
312+
will be target
313+
return_X_y_t (bool):
314+
as_frame (bool):
283315
284316
Returns:
285317
Dictionary-like object, with the following attributes.
@@ -288,24 +320,41 @@ def fetch_hillstrom(target='visit',
288320
target : {ndarray, series} of shape (64000,)
289321
The regression target for each sample.
290322
treatment : {ndarray, series} of shape (64000,)
323+
feature_names (list): The names of the feature columns
324+
target_name (string): The name of the target column.
325+
treatment_name (string): The name of the treatment column
291326
"""
292327

293328
url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
294-
csv_path = get_data(data_home=data_home,
329+
csv_path = _get_data(data_home=data_home,
295330
url=url,
296331
dest_subdir=dest_subdir,
297332
dest_filename='hillstorm_no_indices.csv.gz',
298333
download_if_missing=download_if_missing)
299-
hillstrom = pd.read_csv(csv_path)
300-
hillstrom_data = hillstrom.drop(columns=['segment', target])
334+
335+
if target_column != ('visit' or 'conversion' or 'spend'):
336+
raise ValueError(f"Target_column value must be from {['visit', 'conversion', 'spend']}. "
337+
f"Got value {target_column}.")
338+
339+
data = pd.read_csv(csv_path, usecols=[i for i in range(8)])
340+
feature_names = list(data.columns)
341+
treatment = pd.read_csv(csv_path, usecols=['segment'])
342+
target = pd.read_csv(csv_path, usecols=[target_column])
343+
if as_frame:
344+
target = target[target_column]
345+
treatment = treatment['segment']
346+
else:
347+
data = data.to_numpy()
348+
target = target.to_numpy()
349+
treatment = treatment.to_numpy()
301350

302351
module_path = os.path.dirname('__file__')
303352
with open(os.path.join(module_path, 'descr', 'hillstrom.rst')) as rst_file:
304353
fdescr = rst_file.read()
305354

306-
if return_X_y:
307-
return treatment, data, target
308-
309-
return Bunch(treatment=hillstrom['segment'],
310-
target=hillstrom[target],
311-
data=hillstrom_data, DESCR=fdescr)
355+
if return_X_y_t:
356+
return data, target, treatment
357+
else:
358+
target_name = target_column
359+
return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
360+
feature_names=feature_names, target_name=target_name, treatment_name='segment')

0 commit comments

Comments
 (0)