@@ -1,6 +1,6 @@
+import hashlib
 import os
 import shutil
-import hashlib
 
 import pandas as pd
 import requests
@@ -17,7 +17,6 @@ def get_data_dir():
 
     Returns:
         string: The path to scikit-uplift data dir.
-
     """
     return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")
 
@@ -27,13 +26,12 @@ def _create_data_dir(path):
 
     Args:
         path (str): The path to scikit-uplift data dir.
-
     """
     if not os.path.isdir(path):
         os.makedirs(path)
 
 
-def _download(url, dest_path, content_length_header_key='Content-Length'):
+def _download(url, dest_path, content_length_header_key='Content-Length', desc=None):
     """Download the file from url and save it locally.
 
     Args:
@@ -48,7 +46,7 @@ def _download(url, dest_path, content_length_header_key='Content-Length'):
 
     with open(dest_path, "wb") as fd:
         total_size_in_bytes = int(req.headers.get(content_length_header_key, 0))
-        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+        progress_bar = tqdm(desc=desc, total=total_size_in_bytes, unit='iB', unit_scale=True)
         for chunk in req.iter_content(chunk_size=2 ** 20):
             progress_bar.update(len(chunk))
             fd.write(chunk)
@@ -57,7 +55,7 @@ def _download(url, dest_path, content_length_header_key='Content-Length'):
 
 
 def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
-              content_length_header_key='Content-Length'):
+              content_length_header_key='Content-Length', desc=None):
     """Return the path to the dataset.
 
     Args:
@@ -72,7 +70,6 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
 
     Returns:
         string: The path to the dataset.
-
     """
     if data_home is None:
         if dest_subdir is None:
@@ -91,13 +88,19 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
 
     if not os.path.isfile(dest_path):
         if download_if_missing:
-            _download(url, dest_path, content_length_header_key)
+            _download(url, dest_path, content_length_header_key, desc)
         else:
             raise IOError("Dataset missing")
     return dest_path
 
-def _get_file_hash(csv_path):
-    with open(csv_path, 'rb') as file_to_check:
+
+def _get_file_hash(path):
+    """Compute the hash value for a file using the MD5 algorithm.
+
+    Args:
+        path (str): The path to the file.
+    """
+    with open(path, 'rb') as file_to_check:
         data = file_to_check.read()
     return hashlib.md5(data).hexdigest()
 
@@ -107,7 +110,6 @@ def clear_data_dir(path=None):
 
     Args:
         path (str): The path to scikit-uplift data dir
-
     """
     if path is None:
         path = get_data_dir()
@@ -175,20 +177,21 @@ def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, retu
 
     :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-
     lenta_metadata = {
+        'desc': 'Lenta dataset',
         'url': 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz',
         'hash': '6ab28ff0989ed8b8647f530e2e86452f'
     }
 
     filename = lenta_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=lenta_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=lenta_metadata['desc'])
 
     if _get_file_hash(csv_path) != lenta_metadata['hash']:
-        raise ValueError(f"The {filename} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     target_col = 'response_att'
     treatment_col = 'group'
@@ -276,23 +279,26 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
 
     :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-
     x5_metadata = {
+        'desc_train': 'Part 1: X5 train',
+        'desc_clients': 'Part 2: X5 clients',
+        'desc_purchases': 'Part 3: X5 purchases',
         'url_train': 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz',
         'url_clients': 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz',
         'url_purchases': 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz',
-        'uplift_hash': '2720bbb659daa9e0989b2777b6a42d19',
-        'clients_hash': 'b9cdeb2806b732771de03e819b3354c5',
-        'purchases_hash': '48d2de13428e24e8b61d66fef02957a8'
+        'hash_train': '2720bbb659daa9e0989b2777b6a42d19',
+        'hash_clients': 'b9cdeb2806b732771de03e819b3354c5',
+        'hash_purchases': '48d2de13428e24e8b61d66fef02957a8'
     }
     file_train = x5_metadata['url_train'].split('/')[-1]
     csv_train_path = _get_data(data_home=data_home, url=x5_metadata['url_train'], dest_subdir=dest_subdir,
                                dest_filename=file_train,
-                               download_if_missing=download_if_missing)
+                               download_if_missing=download_if_missing,
+                               desc=x5_metadata['desc_train'])
 
-    if _get_file_hash(csv_train_path) != x5_metadata['uplift_hash']:
-        raise ValueError(f"The {file_train} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_train_path) != x5_metadata['hash_train']:
+        raise ValueError(f"The {file_train} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     train = pd.read_csv(csv_train_path)
     train_features = list(train.columns)
@@ -307,24 +313,26 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
     file_clients = x5_metadata['url_clients'].split('/')[-1]
     csv_clients_path = _get_data(data_home=data_home, url=x5_metadata['url_clients'], dest_subdir=dest_subdir,
                                  dest_filename=file_clients,
-                                 download_if_missing=download_if_missing)
+                                 download_if_missing=download_if_missing,
+                                 desc=x5_metadata['desc_clients'])
 
-    if _get_file_hash(csv_clients_path) != x5_metadata['clients_hash']:
-        raise ValueError(f"The {file_clients} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_clients_path) != x5_metadata['hash_clients']:
+        raise ValueError(f"The {file_clients} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     clients = pd.read_csv(csv_clients_path)
     clients_features = list(clients.columns)
 
     file_purchases = x5_metadata['url_purchases'].split('/')[-1]
     csv_purchases_path = _get_data(data_home=data_home, url=x5_metadata['url_purchases'], dest_subdir=dest_subdir,
                                    dest_filename=file_purchases,
-                                   download_if_missing=download_if_missing)
+                                   download_if_missing=download_if_missing,
+                                   desc=x5_metadata['desc_purchases'])
 
-    if _get_file_hash(csv_clients_path) != x5_metadata['purchases_hash']:
-        raise ValueError(f"The {file_purchases} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
-
+    if _get_file_hash(csv_purchases_path) != x5_metadata['hash_purchases']:
+        raise ValueError(f"The {file_purchases} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
+
     purchases = pd.read_csv(csv_purchases_path)
     purchases_features = list(purchases.columns)
 
@@ -426,26 +434,28 @@ def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None,
         raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. "
                          f"Got value target_col={target_col}.")
 
-    criteo_metadata = {
-        'url': '',
-        'criteo_hash': ''
-    }
-
     if percent10:
-        criteo_metadata['url'] = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz'
-        criteo_metadata['criteo_hash'] = 'fe159bcee2cea57548e48eb2d7d5d00c'
+        criteo_metadata = {
+            'desc': 'Criteo dataset (10 percent)',
+            'url': 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz',
+            'hash': 'fe159bcee2cea57548e48eb2d7d5d00c'
+        }
     else:
-        criteo_metadata['url'] = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz"
-        criteo_metadata['criteo_hash'] = 'd2236769ef69e9be52556110102911ec'
+        criteo_metadata = {
+            'desc': 'Criteo dataset',
+            'url': 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz',
+            'hash': 'd2236769ef69e9be52556110102911ec'
+        }
 
     filename = criteo_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=criteo_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=criteo_metadata['desc'])
 
-    if _get_file_hash(csv_path) != criteo_metadata['criteo_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != criteo_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     dtypes = {
         'exposure': 'Int8',
@@ -544,18 +554,20 @@ def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, downlo
                          f"Got value target_col={target_col}.")
 
     hillstrom_metadata = {
+        'desc': 'Hillstrom dataset',
         'url': 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz',
-        'hillstrom_hash': 'a68a81291f53a14f4e29002629803ba3'
+        'hash': 'a68a81291f53a14f4e29002629803ba3'
     }
 
     filename = hillstrom_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=hillstrom_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename,
-                         download_if_missing=download_if_missing)
+                         download_if_missing=download_if_missing,
+                         desc=hillstrom_metadata['desc'])
 
-    if _get_file_hash(csv_path) != hillstrom_metadata['hillstrom_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != hillstrom_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     treatment_col = 'segment'
 
@@ -634,21 +646,22 @@ def fetch_megafon(data_home=None, dest_subdir=None, download_if_missing=True,
     :func:`.fetch_criteo`: Load and return the Criteo Uplift Prediction Dataset (classification).
 
     :func:`.fetch_hillstrom`: Load and return Kevin Hillstrom Dataset MineThatData (classification or regression).
-
     """
     megafon_metadata = {
+        'desc': 'Megafon dataset',
         'url': 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz',
-        'megafon_hash': 'ee8d45a343d4d2cf90bb756c93959ecd'
+        'hash': 'ee8d45a343d4d2cf90bb756c93959ecd'
     }
 
     filename = megafon_metadata['url'].split('/')[-1]
     csv_path = _get_data(data_home=data_home, url=megafon_metadata['url'], dest_subdir=dest_subdir,
-                        dest_filename=filename,
-                        download_if_missing=download_if_missing)
+                         dest_filename=filename,
+                         download_if_missing=download_if_missing,
+                         desc=megafon_metadata['desc'])
 
-    if _get_file_hash(csv_path) != megafon_metadata['megafon_hash']:
-        raise ValueError(f"The {filename} file is broken,\
-        please clean the directory with the clean_data_dir function, and run the function again")
+    if _get_file_hash(csv_path) != megafon_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clean_data_dir() function, and run the function again")
 
     train = pd.read_csv(csv_path)
 
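Usage sketch (not part of the diff): the new 'desc' strings surface as tqdm progress-bar labels the first time a dataset is downloaded. A minimal example, assuming the public sklift.datasets API; the Bunch-style return value and the progress line quoted in the comment are illustrative, not guaranteed output:

from sklift.datasets import fetch_lenta

# The first call downloads the archive into ~/scikit-uplift-data, verifies
# its md5 hash via _get_file_hash, and labels the tqdm bar with the 'desc'
# text, e.g. "Lenta dataset: 100%|##########| ...". Later calls hit the cache.
dataset = fetch_lenta()
print(dataset.data.shape)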
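A related note: _get_file_hash reads the whole file into memory before hashing. For the larger archives, a chunked variant yields the same digest with constant memory; a sketch of a hypothetical helper (not part of this change):

import hashlib

def _get_file_hash_streaming(path, chunk_size=2 ** 20):
    # Hypothetical alternative to _get_file_hash: feed md5 in 1 MiB chunks
    # so large dataset archives never have to fit in memory at once.
    md5 = hashlib.md5()
    with open(path, 'rb') as file_to_check:
        for chunk in iter(lambda: file_to_check.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()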