@@ -15,7 +15,7 @@ def get_data_dir():
     return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")


-def create_data_dir(path):
+def _create_data_dir(path):
     """This function creates a directory, which stores the datasets.

     Args:
@@ -26,7 +26,7 @@ def create_data_dir(path):
         os.makedirs(path)


-def download(url, dest_path):
+def _download(url, dest_path):
     '''Download the file from url and save it locally

     Args:
@@ -47,7 +47,7 @@ def download(url, dest_path):
         raise TypeError("URL must be a string")


-def get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
+def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
     """Return the path to the dataset.

     Args:
@@ -71,13 +71,13 @@ def get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
     else:
         data_dir = os.path.join(os.path.abspath(data_home), dest_subdir)

-    create_data_dir(data_dir)
+    _create_data_dir(data_dir)

     dest_path = os.path.join(data_dir, dest_filename)

     if not os.path.isfile(dest_path):
         if download_if_missing:
-            download(url, dest_path)
+            _download(url, dest_path)
         else:
             raise IOError("Dataset missing")
     return dest_path
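For orientation, a minimal sketch of how the renamed private helper is meant to be called from the public fetch_* functions below; the URL and filename are placeholders, not values from this diff:

    # Hypothetical call: resolve the cache directory, download the archive only
    # if it is missing, and return the local path to the cached file.
    csv_path = _get_data(data_home=None,          # defaults to ~/scikit-uplift-data
                         url='https://example.com/some_dataset.csv.gz',   # placeholder
                         dest_subdir=None,
                         dest_filename='some_dataset.csv.gz',             # placeholder
                         download_if_missing=True)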
@@ -95,15 +95,16 @@ def clear_data_dir(path=None):
     shutil.rmtree(path, ignore_errors=True)


-def fetch_lenta(return_X_y_t=False, data_home=None, dest_subdir=None, download_if_missing=True):
-    ''' Fetch the Lenta dataset.
+def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, return_X_y_t=False, as_frame=False):
+    """Fetch the Lenta dataset.

     Args:
-        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
-            See below for more information about the data and target object.
         data_home (str, unicode): The path to the folder where datasets are stored.
         dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
         download_if_missing (bool): Download the data if not present. Raises an IOError if False and data is missing.
+        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
+            See below for more information about the data and target object.
+        as_frame (bool): If True, returns data, target and treatment as pandas objects; otherwise as numpy arrays.

     Returns:
         * dataset ('~sklearn.utils.Bunch'): Dictionary-like object, with the following attributes.
@@ -113,69 +114,101 @@ def fetch_lenta(return_X_y_t=False, data_home=None, dest_subdir=None, download_i
             * DESCR (str): Description of the Lenta dataset.

         * (data, target, treatment): tuple if 'return_X_y_t' is True.
-    '''
-    url = 'https:/winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
-    filename = 'lentadataset.csv.gz'
-    csv_path = get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
-                        dest_filename=filename,
-                        download_if_missing=download_if_missing)
+    """
+
+    url = 'https://winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
+    filename = 'lentadataset.csv.gz'
+    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+                         dest_filename=filename,
+                         download_if_missing=download_if_missing)
     data = pd.read_csv(csv_path)
-    target = data['response_att']
-    treatment = data['group']
-    data = data.drop(['response_att', 'group'], axis=1)
+    if as_frame:
+        target = data['response_att']
+        treatment = data['group']
+        data = data.drop(['response_att', 'group'], axis=1)
+        feature_names = list(data.columns)
+    else:
+        target = data[['response_att']].to_numpy()
+        treatment = data[['group']].to_numpy()
+        data = data.drop(['response_att', 'group'], axis=1)
+        feature_names = list(data.columns)
+        data = data.to_numpy()

     module_path = os.path.dirname(__file__)
     with open(os.path.join(module_path, 'descr', 'lenta.rst')) as rst_file:
         fdescr = rst_file.read()

-    if return_X_y_t == True:
+    if return_X_y_t:
         return data, target, treatment

-    return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr)
+    return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
+                 feature_names=feature_names, target_name='response_att', treatment_name='group')


-def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
+def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True, as_frame=False):
     """Fetch the X5 dataset.

-    Args:
-        '~sklearn.utils.Bunch': dataset
+    Args:
+        data_home (string): Specify a download and cache folder for the datasets.
+        dest_subdir (string, unicode): The name of the folder in which the dataset is stored.
+        download_if_missing (bool, default=True): If False, raise an IOError if the data is not locally available
+            instead of trying to download the data from the source site.
+        as_frame (bool, default=False): If True, returns the data as pandas objects; otherwise as numpy arrays.
+
+    Returns:
+        '~sklearn.utils.Bunch': dataset
             Dictionary-like object, with the following attributes.
-            data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
-            target (Series object): Values of the target column.
-            treatment (Series object): Values of the treatment column.
-            DESCR (str): Description of the X5 dataset.
-            train (DataFrame object): Dataset with target and treatment.
+            data ('~sklearn.utils.Bunch'): Dataset without target and treatment.
+            target (Series object): Values of the target column.
+            treatment (Series object): Values of the treatment column.
+            DESCR (str): Description of the X5 dataset.
+            train (DataFrame object): Dataset with target and treatment.
+            data_names ('~sklearn.utils.Bunch'): Names of the features.
+            treatment_name (string): The name of the treatment column.
     """
     url_clients = 'https://timds.s3.eu-central-1.amazonaws.com/clients.csv.gz'
     file_clients = 'clients.csv.gz'
-    csv_clients_path = get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
+    csv_clients_path = _get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
                                 dest_filename=file_clients,
                                 download_if_missing=download_if_missing)
     clients = pd.read_csv(csv_clients_path)
+    clients_names = list(clients.columns)

     url_train = 'https://timds.s3.eu-central-1.amazonaws.com/uplift_train.csv.gz'
     file_train = 'uplift_train.csv.gz'
-    csv_train_path = get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
+    csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
                               dest_filename=file_train,
                               download_if_missing=download_if_missing)
     train = pd.read_csv(csv_train_path)
+    train_names = list(train.columns)

     url_purchases = 'https://timds.s3.eu-central-1.amazonaws.com/purchases.csv.gz'
     file_purchases = 'purchases.csv.gz'
-    csv_purchases_path = get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
+    csv_purchases_path = _get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
                                   dest_filename=file_purchases,
                                   download_if_missing=download_if_missing)
     purchases = pd.read_csv(csv_purchases_path)
+    purchases_names = list(purchases.columns)

-    target = train['target']
-    treatment = train['treatment_flg']
+    if as_frame:
+        target = train['target']
+        treatment = train['treatment_flg']
+    else:
+        target = train[['target']].to_numpy()
+        treatment = train[['treatment_flg']].to_numpy()
+        train = train.to_numpy()
+        clients = clients.to_numpy()
+        purchases = purchases.to_numpy()

     module_path = os.path.dirname(__file__)
     with open(os.path.join(module_path, 'descr', 'x5.rst')) as rst_file:
         fdescr = rst_file.read()

     return Bunch(data=Bunch(clients=clients, train=train, purchases=purchases),
-                 target=target, treatment=treatment, DESCR=fdescr)
+                 target=target, treatment=treatment, DESCR=fdescr,
+                 data_names=Bunch(clients_names=clients_names, train_names=train_names,
+                                  purchases_names=purchases_names),
+                 treatment_name='treatment_flg')


 def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, percent10=True,
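To make the new flags concrete, here is a rough usage sketch for fetch_lenta and fetch_x5 with the return_X_y_t/as_frame parameters added above. The import path is assumed (not shown in this diff), and only attributes that the patch itself adds are used:

    from sklift.datasets import fetch_lenta, fetch_x5  # assumed import path

    # Bunch of pandas objects plus the new metadata fields added by this patch.
    lenta = fetch_lenta(as_frame=True)
    print(lenta.feature_names, lenta.target_name, lenta.treatment_name)

    # Plain numpy arrays, suitable for feeding straight into an estimator.
    X, y, t = fetch_lenta(return_X_y_t=True, as_frame=False)

    # X5 keeps its nested Bunch of tables; data_names mirrors their column names.
    x5 = fetch_x5(as_frame=True)
    clients_df = x5.data.clients
    print(x5.data_names.train_names, x5.treatment_name)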
@@ -209,14 +242,14 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
209242 """
210243 if percent10 :
211244 url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
212- csv_path = get_data (data_home = data_home , url = url , dest_subdir = dest_subdir ,
245+ csv_path = _get_data (data_home = data_home , url = url , dest_subdir = dest_subdir ,
213246 dest_filename = 'criteo10.csv.gz' ,
214247 download_if_missing = download_if_missing )
215248 else :
216249 url = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
217- csv_path = get_data (data_home = data_home , url = url , dest_subdir = dest_subdir ,
218- dest_filename = 'criteo.csv.gz' ,
219- download_if_missing = download_if_missing )
250+ csv_path = _get_data (data_home = data_home , url = url , dest_subdir = dest_subdir ,
251+ dest_filename = 'criteo.csv.gz' ,
252+ download_if_missing = download_if_missing )
220253
221254 if treatment_feature == 'exposure' :
222255 data = pd .read_csv (csv_path , usecols = [i for i in range (12 )])
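A quick, hedged usage sketch for fetch_criteo with the parameters visible in this hunk (percent10 and treatment_feature='exposure'); the import path is assumed, and the printed fields mirror the Bunch attributes shown further down in this diff:

    from sklift.datasets import fetch_criteo  # assumed import path

    # 10% sample of the Criteo uplift dataset, using 'exposure' as the treatment column.
    criteo = fetch_criteo(percent10=True, treatment_feature='exposure')
    print(criteo.target_name, criteo.treatment_name)
    print(len(criteo.feature_names))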
@@ -264,22 +297,21 @@ def fetch_criteo(data_home=None, dest_subdir=None, download_if_missing=True, per
                 feature_names=feature_names, target_name=target_name, treatment_name=treatment_name)


-def fetch_hillstrom(target='visit',
-                    data_home=None,
-                    dest_subdir=None,
-                    download_if_missing=True,
-                    return_X_y=False):
+def fetch_hillstrom(data_home=None, dest_subdir=None, download_if_missing=True, target_column='visit',
+                    return_X_y_t=False, as_frame=False):
     """Load the hillstrom dataset.

     Args:
-        target : str, desfault=visit.
-            Can also be conversion, and spend
         data_home : str, default=None
             Specify another download and cache folder for the datasets.
         dest_subdir : str, default=None
         download_if_missing : bool, default=True
             If False, raise an IOError if the data is not locally available
             instead of trying to download the data from the source site.
+        target_column (string, 'visit', 'conversion' or 'spend', default='visit'): Selects which column of the
+            dataset will be used as the target.
+        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.
+        as_frame (bool): If True, returns data, target and treatment as pandas objects; otherwise as numpy arrays.

     Returns:
         Dictionary-like object, with the following attributes.
@@ -288,24 +320,41 @@ def fetch_hillstrom(target='visit',
         target : {ndarray, series} of shape (64000,)
             The regression target for each sample.
         treatment : {ndarray, series} of shape (64000,)
+        feature_names (list): The names of the feature columns.
+        target_name (string): The name of the target column.
+        treatment_name (string): The name of the treatment column.
     """

     url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
-    csv_path = get_data(data_home=data_home,
+    csv_path = _get_data(data_home=data_home,
                         url=url,
                         dest_subdir=dest_subdir,
                         dest_filename='hillstorm_no_indices.csv.gz',
                         download_if_missing=download_if_missing)
-    hillstrom = pd.read_csv(csv_path)
-    hillstrom_data = hillstrom.drop(columns=['segment', target])
+
+    if target_column not in ('visit', 'conversion', 'spend'):
+        raise ValueError(f"target_column value must be one of {['visit', 'conversion', 'spend']}. "
+                         f"Got value {target_column}.")
+
+    data = pd.read_csv(csv_path, usecols=[i for i in range(8)])
+    feature_names = list(data.columns)
+    treatment = pd.read_csv(csv_path, usecols=['segment'])
+    target = pd.read_csv(csv_path, usecols=[target_column])
+    if as_frame:
+        target = target[target_column]
+        treatment = treatment['segment']
+    else:
+        data = data.to_numpy()
+        target = target.to_numpy()
+        treatment = treatment.to_numpy()

     module_path = os.path.dirname('__file__')
     with open(os.path.join(module_path, 'descr', 'hillstrom.rst')) as rst_file:
         fdescr = rst_file.read()

-    if return_X_y:
-        return treatment, data, target
-
-    return Bunch(treatment=hillstrom['segment'],
-                 target=hillstrom[target],
-                 data=hillstrom_data, DESCR=fdescr)
+    if return_X_y_t:
+        return data, target, treatment
+    else:
+        target_name = target_column
+        return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
+                     feature_names=feature_names, target_name=target_name, treatment_name='segment')
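Finally, a hedged usage sketch for fetch_hillstrom with the new target_column, return_X_y_t and as_frame parameters; the import path is assumed, and the attributes used are the ones added by this patch:

    from sklift.datasets import fetch_hillstrom  # assumed import path

    # Default: Bunch with numpy arrays and the 'visit' target.
    hillstrom = fetch_hillstrom()
    print(hillstrom.target_name, hillstrom.treatment_name)   # 'visit', 'segment'

    # Pandas objects with a different target column, unpacked as a tuple.
    X, y, t = fetch_hillstrom(target_column='spend', return_X_y_t=True, as_frame=True)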