@@ -27,11 +27,19 @@ def get_args():
         required=True,
         help="path to the parquet dataset folder",
     )
-    parser.add_argument("--save-path-stats-json", type=str, help="Where to save the stats json.")
-    parser.add_argument("--save-path-stats-full-json", type=str, help="Where to save the stats json.")
-    parser.add_argument("--save-batch-size", type=int, required=True, help="Batch size when writing.")
+    parser.add_argument(
+        "--save-path-stats-json", type=str, help="Where to save the stats json."
+    )
+    parser.add_argument(
+        "--save-path-stats-full-json", type=str, help="Where to save the stats json."
+    )
+    parser.add_argument(
+        "--save-batch-size", type=int, required=True, help="Batch size when writing."
+    )
     parser.add_argument("--use-datasets-caching", action="store_true")
-    parser.add_argument("--num-proc", type=int, default=1, help="Number of procs use for preprocessing.")
+    parser.add_argument(
+        "--num-proc", type=int, default=1, help="Number of procs use for preprocessing."
+    )
     parser.add_argument(
         "--seed-id",
         type=int,
@@ -57,12 +65,16 @@ def main():
         level=logging.INFO,
     )
     args = get_args()
-    logger.info(f"** The job is runned with the following arguments: **\n{args}\n **** ")
+    logger.info(
+        f"** The job is runned with the following arguments: **\n{args}\n **** "
+    )

     if not args.use_datasets_caching:
         datasets.set_caching_enabled(False)
     else:
-        logger.info(f"the datasets results will be cached at {config.HF_DATASETS_CACHE}.")
+        logger.info(
+            f"the datasets results will be cached at {config.HF_DATASETS_CACHE}."
+        )

     ds = load_from_disk(args.dataset_path)

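A note on the caching toggle in this hunk: datasets.set_caching_enabled(False) disables the on-disk cache that the datasets library normally uses to memoize map/filter results (newer releases expose the same switch as datasets.enable_caching() / datasets.disable_caching()). A minimal sketch of the load-and-toggle pattern, with a placeholder dataset path:

    import datasets
    from datasets import load_from_disk

    # Skip the on-disk cache for map/filter results, mirroring what the
    # script does when --use-datasets-caching is not passed.
    datasets.set_caching_enabled(False)

    # load_from_disk expects a folder written by Dataset.save_to_disk;
    # "/path/to/dataset" is a placeholder, not a path from the script.
    ds = load_from_disk("/path/to/dataset")
    print(ds)
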
@@ -73,15 +85,19 @@ def main():
     splits = {
         **{
             mime_type: ds.filter(
-                lambda mime_types_: [mime_type_ == mime_type for mime_type_ in mime_types_],
+                lambda mime_types_: [
+                    mime_type_ == mime_type for mime_type_ in mime_types_
+                ],
                 input_columns="content_mime_detected",
                 batched=True,
                 num_proc=args.num_proc,
             )
             for mime_type in selected_mime_types
         },
         "others": ds.filter(
-            lambda mime_types_: [mime_type_ not in selected_mime_types for mime_type_ in mime_types_],
+            lambda mime_types_: [
+                mime_type_ not in selected_mime_types for mime_type_ in mime_types_
+            ],
             input_columns="content_mime_detected",
             batched=True,
             num_proc=args.num_proc,
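The splitting idiom in this hunk leans on two Dataset.filter features: input_columns="content_mime_detected" passes the column values directly to the predicate instead of a whole example dict, and batched=True means the predicate receives a list of values and must return a list of booleans of the same length. A toy sketch of the same call (invented rows, same column name as above):

    from datasets import Dataset

    ds = Dataset.from_dict(
        {"content_mime_detected": ["text/html", "application/pdf", "text/html"]}
    )

    # With batched=True the lambda sees the mime types for a whole batch
    # and returns one boolean per row.
    html_only = ds.filter(
        lambda mime_types_: [mime_type_ == "text/html" for mime_type_ in mime_types_],
        input_columns="content_mime_detected",
        batched=True,
    )
    print(len(html_only))  # 2
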
@@ -96,7 +112,11 @@ def get_length_text(example):
         example["length_text"] = len(example["text"])
         return example

-    cols_to_remove = [col for col in ds.column_names if col not in ["content_languages", "url_host_tld"]]
+    cols_to_remove = [
+        col
+        for col in ds.column_names
+        if col not in ["content_languages", "url_host_tld"]
+    ]
     ds_html = ds_html.map(
         get_length_text,
         batched=False,
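This map call runs get_length_text once per example (batched=False) and, judging by the cols_to_remove list built just above, presumably feeds it to a remove_columns keyword that sits outside the hunk; the next hunk then counts empty rows and averages the rest. A self-contained sketch of that computation on toy data, assuming mean is statistics.mean as in the script's imports (not shown here):

    from statistics import mean

    from datasets import Dataset

    # Toy stand-in for the HTML split, with a "text" column (invented values).
    ds_html = Dataset.from_dict({"text": ["<p>hi</p>", "", "hello world"]})

    def get_length_text(example):
        example["length_text"] = len(example["text"])
        return example

    # batched=False: the function receives one example dict at a time.
    ds_html = ds_html.map(get_length_text, batched=False)

    lengths = ds_html["length_text"]
    print(len([e for e in lengths if e == 0]))   # 1 empty text row
    print(mean([e for e in lengths if e != 0]))  # mean length of non-empty rows
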
@@ -105,10 +125,14 @@ def get_length_text(example):
     )

     data_stats["html_empty_text"] = len([e for e in ds_html["length_text"] if e == 0])
-    data_stats["html_mean_length_non_empty_text"] = mean([e for e in ds_html["length_text"] if e != 0])
+    data_stats["html_mean_length_non_empty_text"] = mean(
+        [e for e in ds_html["length_text"] if e != 0]
+    )
     data_stats["seed_id"] = args.seed_id

-    logger.info(f"There is {data_stats['html_empty_text']} empty text rows out of {len(ds_html)} rows.")
+    logger.info(
+        f"There is {data_stats['html_empty_text']} empty text rows out of {len(ds_html)} rows."
+    )

     save_path = Path(args.save_path_stats_json)
     save_path_tmp = f"{str(save_path.absolute())}.tmp"
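The diff cuts off right after save_path_tmp is built. The usual reason for the ".tmp" suffix is to write the JSON to a temporary file and only then rename it over the final path, so readers never observe a half-written stats file. A sketch of that pattern, under the assumption that this is what the unshown lines do (the stats values here are invented):

    import json
    import os
    from pathlib import Path

    data_stats = {"html_empty_text": 0, "seed_id": 1}  # invented values

    save_path = Path("stats.json")  # placeholder for args.save_path_stats_json
    save_path_tmp = f"{str(save_path.absolute())}.tmp"

    # Write to the .tmp file first, then rename; os.rename is atomic when
    # source and destination live on the same filesystem.
    with open(save_path_tmp, "w", encoding="utf-8") as f:
        json.dump(data_stats, f)
    os.rename(save_path_tmp, str(save_path.absolute()))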