 import os
 import shutil
+import hashlib

 import pandas as pd
 import requests
@@ -95,6 +96,11 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
         raise IOError("Dataset missing")
     return dest_path

+def _get_file_hash(csv_path):
+    with open(csv_path, 'rb') as file_to_check:
+        data = file_to_check.read()
+    return hashlib.md5(data).hexdigest()
+

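Note: `_get_file_hash` reads the whole archive into memory before hashing, which is fine for files of this size. For larger downloads, a chunked read keeps memory use flat; a minimal sketch of that variant (the helper name `_get_file_hash_chunked` is hypothetical, not part of this commit):

```python
import hashlib

def _get_file_hash_chunked(csv_path, chunk_size=2 ** 20):
    """Stream the file in 1 MiB blocks so memory use stays constant."""
    md5 = hashlib.md5()
    with open(csv_path, 'rb') as file_to_check:
        for chunk in iter(lambda: file_to_check.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()
```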
 def clear_data_dir(path=None):
     """Delete all the content of the data home cache.
@@ -170,11 +176,19 @@ def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, retu
         :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """

-    url = 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz'
-    filename = url.split('/')[-1]
-    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+    lenta_metadata = {
+        'url': 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz',
+        'hash': '6ab28ff0989ed8b8647f530e2e86452f'
+    }
+
+    filename = lenta_metadata['url'].split('/')[-1]
+    csv_path = _get_data(data_home=data_home, url=lenta_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
                          download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_path) != lenta_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")

     target_col = 'response_att'
     treatment_col = 'group'
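For context, the new check surfaces a corrupted cache as a `ValueError`, so the recovery path is to clear the data home and fetch again. A hedged usage sketch (assuming `fetch_lenta` and `clear_data_dir` keep their public signatures and the Bunch-style return documented elsewhere in sklift):

```python
from sklift.datasets import clear_data_dir, fetch_lenta

try:
    dataset = fetch_lenta()  # downloads, then verifies the MD5 checksum
except ValueError:
    # A corrupted cached file is the usual cause: clear and retry once.
    clear_data_dir()
    dataset = fetch_lenta()

print(dataset.data.shape)
```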
@@ -262,11 +276,24 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):

         :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-    url_train = 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz'
-    file_train = url_train.split('/')[-1]
-    csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
+
+    x5_metadata = {
+        'url_train': 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz',
+        'url_clients': 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz',
+        'url_purchases': 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz',
+        'uplift_hash': '2720bbb659daa9e0989b2777b6a42d19',
+        'clients_hash': 'b9cdeb2806b732771de03e819b3354c5',
+        'purchases_hash': '48d2de13428e24e8b61d66fef02957a8'
+    }
+    file_train = x5_metadata['url_train'].split('/')[-1]
+    csv_train_path = _get_data(data_home=data_home, url=x5_metadata['url_train'], dest_subdir=dest_subdir,
                                dest_filename=file_train,
                                download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_train_path) != x5_metadata['uplift_hash']:
+        raise ValueError(f"The {file_train} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")
+
     train = pd.read_csv(csv_train_path)
     train_features = list(train.columns)
272299
@@ -277,19 +304,27 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):

     train = train.drop([target_col, treatment_col], axis=1)

-    url_clients = 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz'
-    file_clients = url_clients.split('/')[-1]
-    csv_clients_path = _get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir,
+    file_clients = x5_metadata['url_clients'].split('/')[-1]
+    csv_clients_path = _get_data(data_home=data_home, url=x5_metadata['url_clients'], dest_subdir=dest_subdir,
                                  dest_filename=file_clients,
                                  download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_clients_path) != x5_metadata['clients_hash']:
+        raise ValueError(f"The {file_clients} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")
+
     clients = pd.read_csv(csv_clients_path)
     clients_features = list(clients.columns)

-    url_purchases = 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz'
-    file_purchases = url_purchases.split('/')[-1]
-    csv_purchases_path = _get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir,
+    file_purchases = x5_metadata['url_purchases'].split('/')[-1]
+    csv_purchases_path = _get_data(data_home=data_home, url=x5_metadata['url_purchases'], dest_subdir=dest_subdir,
                                    dest_filename=file_purchases,
                                    download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_purchases_path) != x5_metadata['purchases_hash']:
+        raise ValueError(f"The {file_purchases} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")
+
     purchases = pd.read_csv(csv_purchases_path)
     purchases_features = list(purchases.columns)

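The download-then-verify pattern now repeats once per file; it could be factored into a single helper so each fetcher stays one call long. A sketch under that assumption (`_get_verified_data` is hypothetical; the commit keeps the checks inline):

```python
def _get_verified_data(data_home, url, dest_subdir, download_if_missing, ref_hash):
    # Download (or reuse) the cached file, then fail loudly on a bad checksum.
    filename = url.split('/')[-1]
    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
                         dest_filename=filename,
                         download_if_missing=download_if_missing)
    if _get_file_hash(csv_path) != ref_hash:
        raise ValueError(f"The {filename} file is broken, please clean the directory "
                         "with the clear_data_dir function, and run the function again")
    return csv_path
```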
@@ -391,16 +426,27 @@ def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None,
         raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
                          f"Got value target_col={target_col}.")

+    criteo_metadata = {
+        'url': '',
+        'criteo_hash': ''
+    }
+
     if percent10:
-        url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
+        criteo_metadata['url'] = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
+        criteo_metadata['criteo_hash'] = 'fe159bcee2cea57548e48eb2d7d5d00c'
     else:
-        url = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
+        criteo_metadata['url'] = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
+        criteo_metadata['criteo_hash'] = 'd2236769ef69e9be52556110102911ec'

-    filename = url.split('/')[-1]
-    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+    filename = criteo_metadata['url'].split('/')[-1]
+    csv_path = _get_data(data_home=data_home, url=criteo_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
                          download_if_missing=download_if_missing)

+    if _get_file_hash(csv_path) != criteo_metadata['criteo_hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")
+
     dtypes = {
         'exposure': 'Int8',
         'treatment': 'Int8',
@@ -497,11 +543,19 @@ def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, downlo
         raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
                          f"Got value target_col={target_col}.")

-    url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
-    filename = url.split('/')[-1]
-    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+    hillstrom_metadata = {
+        'url': 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz',
+        'hillstrom_hash': 'a68a81291f53a14f4e29002629803ba3'
+    }
+
+    filename = hillstrom_metadata['url'].split('/')[-1]
+    csv_path = _get_data(data_home=data_home, url=hillstrom_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
                          download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_path) != hillstrom_metadata['hillstrom_hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")

     treatment_col = 'segment'

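The reference checksums hard-coded above can be regenerated with the same primitive whenever a dataset archive is re-uploaded; a quick maintainer-side sketch (the local path is illustrative):

```python
import hashlib

# Print the MD5 reference checksum for a freshly built archive.
with open('hillstorm_no_indices.csv.gz', 'rb') as f:
    print(hashlib.md5(f.read()).hexdigest())
```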
@@ -582,12 +636,21 @@ def fetch_megafon(data_home=None, dest_subdir=None, download_if_missing=True,
         :func:`.fetch_hillstrom`: Load and return Kevin Hillstrom Dataset MineThatData (classification or regression).

     """
-    url_train = 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz'
-    file_train = url_train.split('/')[-1]
-    csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir,
-                               dest_filename=file_train,
+    megafon_metadata = {
+        'url': 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz',
+        'megafon_hash': 'ee8d45a343d4d2cf90bb756c93959ecd'
+    }
+
+    filename = megafon_metadata['url'].split('/')[-1]
+    csv_path = _get_data(data_home=data_home, url=megafon_metadata['url'], dest_subdir=dest_subdir,
+                         dest_filename=filename,
                          download_if_missing=download_if_missing)
-    train = pd.read_csv(csv_train_path)
+
+    if _get_file_hash(csv_path) != megafon_metadata['megafon_hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         "with the clear_data_dir function, and run the function again")
+
+    train = pd.read_csv(csv_path)

     target_col = 'conversion'
     treatment_col = 'treatment_group'
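A note on the choice of digest: MD5 here guards against truncated or corrupted downloads, not against tampering, and it is fast and sufficient for that job. If tamper resistance were ever required, the helper could swap in SHA-256 at the cost of regenerating every stored reference hash; a sketch of that variant (hypothetical, not in this commit):

```python
import hashlib

def _get_file_hash_sha256(csv_path):
    # Same interface as _get_file_hash, but collision-resistant.
    with open(csv_path, 'rb') as file_to_check:
        return hashlib.sha256(file_to_check.read()).hexdigest()
```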