-import os
 import shutil
+import os
 import pandas as pd
 import requests
 from sklearn.utils import Bunch
@@ -41,7 +41,7 @@ def download(url, dest_path):
         req.raise_for_status()
 
         with open(dest_path, "wb") as fd:
-            for chunk in req.iter_content(chunk_size=2 ** 20):
+            for chunk in req.iter_content(chunk_size=2 ** 20):
                 fd.write(chunk)
     else:
         raise TypeError("URL must be a string")
@@ -96,6 +96,54 @@ def clear_data_dir(path=None):
     shutil.rmtree(path, ignore_errors=True)
 
 
+def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
+    """Fetch the X5 dataset.
+
+    Args:
+        data_home (str): The path to the folder where datasets are stored.
+        dest_subdir (str): The name of the folder in which the dataset is stored.
+        download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.
+
+    Returns:
+        :obj:`~sklearn.utils.Bunch`: dataset.
+            Dictionary-like object with the following attributes:
+            data (:obj:`~sklearn.utils.Bunch`): Dictionary-like object with the clients, train and purchases DataFrames.
+            target (pandas.Series): The target values (the 'target' column of the train table).
+            treatment (pandas.Series): The treatment flags (the 'treatment_flg' column of the train table).
+            DESCR (str): Description of the X5 dataset.
+            data.train (pandas.DataFrame): The raw train table, including the target and treatment columns.
+    """
+    url_clients = 'https://timds.s3.eu-central-1.amazonaws.com/clients.csv.gz'
+    file_clients = 'clients.csv.gz'
+    csv_clients_path = get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
+                                dest_filename=file_clients,
+                                download_if_missing=download_if_missing)
+    clients = pd.read_csv(csv_clients_path)
+
+    url_train = 'https://timds.s3.eu-central-1.amazonaws.com/uplift_train.csv.gz'
+    file_train = 'uplift_train.csv.gz'
+    csv_train_path = get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
+                              dest_filename=file_train,
+                              download_if_missing=download_if_missing)
+    train = pd.read_csv(csv_train_path)
+
+    url_purchases = 'https://timds.s3.eu-central-1.amazonaws.com/purchases.csv.gz'
+    file_purchases = 'purchases.csv.gz'
+    csv_purchases_path = get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
+                                  dest_filename=file_purchases,
+                                  download_if_missing=download_if_missing)
+    purchases = pd.read_csv(csv_purchases_path)
+
+    target = train['target']
+    treatment = train['treatment_flg']
+
+    module_path = os.path.dirname(__file__)
+    with open(os.path.join(module_path, 'descr', 'x5.rst')) as rst_file:
+        fdescr = rst_file.read()
+
+    return Bunch(data=Bunch(clients=clients, train=train, purchases=purchases),
+                 target=target, treatment=treatment, DESCR=fdescr)
+
 
 def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, percent10=True,
                  treatment_feature='treatment', target_column='visit', return_X_y_t=False, as_frame=False):
@@ -188,7 +236,6 @@ def fetch_hillstrom(target='visit',
                     dest_subdir=None,
                     download_if_missing=True,
                     return_X_y=False):
-
     """Load the hillstrom dataset.
 
     Args:
@@ -208,8 +255,7 @@ def fetch_hillstrom(target='visit',
         target : {ndarray, series} of shape (64000,)
             The regression target for each sample.
         treatment : {ndarray, series} of shape (64000,)
-
-    """
+    """
 
     url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
     csv_path = get_data(data_home=data_home,
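A minimal usage sketch of the fetch_x5 loader added above. The import path sklift.datasets is an assumption (it is not shown in this diff); the attribute names follow the Bunch returned by the new function.

from sklift.datasets import fetch_x5  # assumed import path, not part of this diff

# Fetch the three X5 tables (or reuse previously downloaded copies) and unpack the Bunch.
dataset = fetch_x5()

clients = dataset.data.clients      # client profiles
purchases = dataset.data.purchases  # purchase history
train = dataset.data.train          # raw train table

print(dataset.target.value_counts())     # 'target' column of the train table
print(dataset.treatment.value_counts())  # 'treatment_flg' column of the train table
print(dataset.DESCR[:200])               # beginning of the x5.rst description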