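# Compute statistics for one seed of a crawled-pages dataset: row counts per
# detected MIME type, the number of HTML rows whose extracted text is empty, and
# the mean text length of the non-empty HTML rows. The summary is written to a
# small stats JSON and the per-example values to a gzipped JSON dump.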
import json
import logging
import subprocess
import sys
from argparse import ArgumentParser
from pathlib import Path
from statistics import mean

import datasets
from bs4 import BeautifulSoup
from bs4.dammit import EncodingDetector
from datasets import config, load_from_disk
from datasets.utils.logging import set_verbosity_info

set_verbosity_info()
logger = logging.getLogger(__name__)

# For `soup.decode_content` that can hit the limit
sys.setrecursionlimit(10000)


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset-path",
        type=str,
        required=True,
        help="Path to the dataset folder (loadable with `datasets.load_from_disk`).",
    )
    parser.add_argument("--save-path-stats-json", type=str, help="Where to save the stats JSON.")
    parser.add_argument("--save-path-stats-full-json", type=str, help="Where to save the full per-example stats JSON.")
    parser.add_argument("--save-batch-size", type=int, required=True, help="Batch size when writing.")
    parser.add_argument("--use-datasets-caching", action="store_true")
    parser.add_argument("--num-proc", type=int, default=1, help="Number of processes to use for preprocessing.")
    parser.add_argument(
        "--seed-id",
        type=int,
        required=True,
        help="Value of the seed id.",
    )
    parser.add_argument(
        "--num-examples",
        type=int,
        default=None,
        help="Optional argument to select a subset (used for debugging purposes). Example: `10`.",
    )
    args = parser.parse_args()

    return args


def main():
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    args = get_args()
    logger.info(f"** The job is run with the following arguments: **\n{args}\n **** ")

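    # Unless caching is explicitly requested, disable the `datasets` cache so the
    # filter/map calls below do not leave Arrow cache files behind.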
    if not args.use_datasets_caching:
        datasets.set_caching_enabled(False)
    else:
        logger.info(f"The datasets results will be cached at {config.HF_DATASETS_CACHE}.")

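    # `load_from_disk` expects a dataset previously saved with `Dataset.save_to_disk`;
    # the columns used below (`content_mime_detected`, `text`) must be present.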
    ds = load_from_disk(args.dataset_path)

    if args.num_examples:
        ds = ds.select(range(args.num_examples))

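    # Partition the rows by detected MIME type: one split per selected type
    # (currently only "text/html") plus an "others" split for everything else.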
    selected_mime_types = ["text/html"]
    splits = {
        **{
            mime_type: ds.filter(
                lambda mime_types_: [mime_type_ == mime_type for mime_type_ in mime_types_],
                input_columns="content_mime_detected",
                batched=True,
                num_proc=args.num_proc,
            )
            for mime_type in selected_mime_types
        },
        "others": ds.filter(
            lambda mime_types_: [mime_type_ not in selected_mime_types for mime_type_ in mime_types_],
            input_columns="content_mime_detected",
            batched=True,
            num_proc=args.num_proc,
        ),
    }

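    # Record the total number of rows in each split.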
    data_stats = {f"{split_name}_total": len(split_ds) for split_name, split_ds in splits.items()}

    ds_html = splits[selected_mime_types[0]]

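    # Compute the length of the extracted text for every HTML row, dropping all
    # columns except `content_languages` and `url_host_tld` to keep the dump small.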
    def get_length_text(example):
        example["length_text"] = len(example["text"])
        return example

    cols_to_remove = [col for col in ds.column_names if col not in ["content_languages", "url_host_tld"]]
    ds_html = ds_html.map(
        get_length_text,
        batched=False,
        num_proc=args.num_proc,
        remove_columns=cols_to_remove,
    )

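    # Summarize: how many HTML rows have empty text, and the mean text length of
    # the non-empty ones.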
    data_stats["html_empty_text"] = len([e for e in ds_html["length_text"] if e == 0])
    data_stats["html_mean_length_non_empty_text"] = mean([e for e in ds_html["length_text"] if e != 0])
    data_stats["seed_id"] = args.seed_id

    logger.info(f"There are {data_stats['html_empty_text']} empty-text rows out of {len(ds_html)} rows.")

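    # Write the summary to a `.tmp` file first and then move it into place, so an
    # interrupted job never leaves a partially written JSON at the final path.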
    save_path = Path(args.save_path_stats_json)
    save_path_tmp = f"{str(save_path.absolute())}.tmp"
    logger.info(f"Saving the stats at {save_path_tmp}")
    with open(save_path_tmp, "w", encoding="utf-8") as f:
        json.dump(data_stats, f, ensure_ascii=False, indent=4)
    logger.info(f"Moving the saved stats to {str(save_path.absolute())}")
    subprocess.run(["mv", save_path_tmp, str(save_path.absolute())])

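    # Dump the full per-example stats (gzipped JSON lines) with the same
    # write-to-tmp-then-move pattern.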
    save_path = Path(args.save_path_stats_full_json)
    save_path_tmp = f"{str(save_path.absolute())}.tmp"
    logger.info(f"Saving the dataset at {save_path_tmp}")
    ds_html.to_json(
        save_path_tmp,
        batch_size=args.save_batch_size,
        num_proc=args.num_proc,
        compression="gzip",
    )
    logger.info(f"Moving the saved dataset to {str(save_path.absolute())}")
    subprocess.run(["mv", save_path_tmp, str(save_path.absolute())])


if __name__ == "__main__":
    main()
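
# Example invocation (the script name and paths below are illustrative, not part of the repo):
#   python compute_html_text_stats.py \
#       --dataset-path /path/to/seed_dataset \
#       --save-path-stats-json stats.json \
#       --save-path-stats-full-json stats_full.jsonl.gz \
#       --save-batch-size 1000 \
#       --num-proc 4 \
#       --seed-id 123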