Skip to content

Commit 98c8fba

Browse files
sidorovTV (Tim Sidorov)
and authored
Added new version of def 'fetch_x5' (#62)
* X5_download * Revert "X5_download * Added def fetch_x5 * Added new version def 'fetch_x5' * Added 'purchases' * NewCommit * NewAdded Co-authored-by: Tim Sidorov <[email protected]>
1 parent c05ad3b commit 98c8fba

File tree

1 file changed

+51
-5
lines changed

1 file changed

+51
-5
lines changed

sklift/datasets/datasets.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import os
21
import shutil
2+
import os
33
import pandas as pd
44
import requests
55
from sklearn.utils import Bunch
@@ -41,7 +41,7 @@ def download(url, dest_path):
4141
req.raise_for_status()
4242

4343
with open(dest_path, "wb") as fd:
44-
for chunk in req.iter_content(chunk_size=2**20):
44+
for chunk in req.iter_content(chunk_size=2 ** 20):
4545
fd.write(chunk)
4646
else:
4747
raise TypeError("URL must be a string")
@@ -96,6 +96,54 @@ def clear_data_dir(path=None):
9696
shutil.rmtree(path, ignore_errors=True)
9797

9898

99+
def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
    """Fetch the X5 dataset.

    Downloads (if needed) and loads the three X5 retail-hero tables —
    ``clients``, ``uplift_train`` and ``purchases`` — and assembles them
    into a scikit-learn style Bunch.

    Args:
        data_home (str): The path to the folder where datasets are stored.
        dest_subdir (str): The name of the folder in which the dataset is stored.
        download_if_missing (bool): Download the data if not present.
            Raises an IOError if False and data is missing.

    Returns:
        '~sklearn.utils.Bunch': dataset.
        Dictionary-like object, with the following attributes.

            data ('~sklearn.utils.Bunch'): Dataset without target and treatment
                (``clients``, ``train`` and ``purchases`` DataFrames).
            target (Series object): Column 'target' values from the train table.
            treatment (Series object): Column 'treatment_flg' values from the train table.
            DESCR (str): Description of the X5 dataset.
    """
    base_url = 'https://timds.s3.eu-central-1.amazonaws.com/'

    # The three tables differ only in file name, so fetch them uniformly
    # instead of repeating the download/read sequence per table.
    frames = {}
    for table in ('clients', 'uplift_train', 'purchases'):
        filename = table + '.csv.gz'
        csv_path = get_data(data_home=data_home,
                            url=base_url + filename,
                            dest_subdir=dest_subdir,
                            dest_filename=filename,
                            download_if_missing=download_if_missing)
        frames[table] = pd.read_csv(csv_path)

    train = frames['uplift_train']
    target = train['target']
    treatment = train['treatment_flg']

    # The dataset description ships as an .rst file next to this module.
    module_path = os.path.dirname(__file__)
    with open(os.path.join(module_path, 'descr', 'x5.rst')) as rst_file:
        fdescr = rst_file.read()

    return Bunch(data=Bunch(clients=frames['clients'], train=train,
                            purchases=frames['purchases']),
                 target=target, treatment=treatment, DESCR=fdescr)
146+
99147

100148
def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, percent10=True,
101149
treatment_feature='treatment', target_column='visit', return_X_y_t=False, as_frame=False):
@@ -188,7 +236,6 @@ def fetch_hillstrom(target='visit',
188236
dest_subdir=None,
189237
download_if_missing=True,
190238
return_X_y=False):
191-
192239
"""Load the hillstrom dataset.
193240
194241
Args:
@@ -208,8 +255,7 @@ def fetch_hillstrom(target='visit',
208255
target : {ndarray, series} of shape (64000,)
209256
The regression target for each sample.
210257
treatment : {ndarray, series} of shape (64000,)
211-
212-
"""
258+
"""
213259

214260
url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
215261
csv_path = get_data(data_home=data_home,

0 commit comments

Comments
 (0)