import copy
import json
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
from tqdm import tqdm


def create_splits_for_hf_hub(train_dataset: str):
    # Dataset format should be a list of dictionaries, where each dictionary represents a data point.
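    # Illustrative (made-up) example of a single data point in that format:
    # {"tokenized_text": ["John", "lives", "in", "Paris"], "ner": [[0, 0, "person"], [3, 3, "location"]]}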
    path_to_train_data = f"path/to/train/{train_dataset}.json"
    with open(path_to_train_data, "r") as f:
        data = json.load(f)

    for filter_by in ["entropy", "max"]:
        dataset_dict = DatasetDict()
        for setting in ["easy", "medium", "hard"]:
            new_split = create_splits(
                data,
                train_dataset,
                filter_by=filter_by,
                setting=setting,
            )

            hf_format = [convert_to_hf_format(data_point) for data_point in new_split]

            ds = Dataset.from_pandas(pd.DataFrame(data=hf_format))
            dataset_dict[setting] = ds

        dataset_dict.push_to_hub(f"{train_dataset}_{filter_by}_splits")


def convert_to_hf_format(data_point):
    """Convert a span-annotated data point into token-level BIO tags."""
    tags = ["O"] * len(data_point["tokenized_text"])
    spans = []
    for ent in data_point["ner"]:
        start, end, label = ent[0], ent[1], ent[2]
        spans.append({"start": start, "end": end, "label": label})
        if start >= len(tags) or end >= len(tags):
            # Skip spans that fall outside the tokenized text.
            continue
        tags[start] = "B-" + label
        if end > start:
            tags[start + 1 : end + 1] = ["I-" + label] * (end - start)
    return {"tokens": data_point["tokenized_text"], "ner_tags": tags, "spans": spans}

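# A quick illustration of the conversion above, using a made-up data point:
#   convert_to_hf_format({"tokenized_text": ["John", "visited", "New", "York"],
#                         "ner": [[0, 0, "person"], [2, 3, "location"]]})
# returns
#   {"tokens": ["John", "visited", "New", "York"],
#    "ner_tags": ["B-person", "O", "B-location", "I-location"],
#    "spans": [{"start": 0, "end": 0, "label": "person"}, {"start": 2, "end": 3, "label": "location"}]}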

def create_splits(
    dataset: List[Dict],
    dataset_name: str,  # The name of the dataset for which the splits should be created
    filter_by: str = "entropy",
    setting: str = "medium",
):
    try:
        df = pd.read_pickle("new_splits.pkl")
    except FileNotFoundError as err:
        raise FileNotFoundError("Please run the compute_new_splits function first to generate the data.") from err
    df = df[(df["train_dataset"] == dataset_name)]

    selected_entity_types = []
    for benchmark_name in df["eval_dataset"].unique():
        _df = df[(df["eval_dataset"] == benchmark_name)].copy()

        # The thresholds are dataset specific and may need to be adjusted for datasets with different characteristics.
        if filter_by == "entropy":
            low_threshold = df[filter_by].quantile(0.01)
            high_threshold = df[filter_by].quantile(0.95)
        elif filter_by == "max":
            low_threshold = df[filter_by].quantile(0.05)
            high_threshold = df[filter_by].quantile(0.99)
        else:
            raise ValueError(f"filter_by must be 'entropy' or 'max', got '{filter_by}'")

        medium_lower_threshold = df[filter_by].quantile(0.495)
        medium_upper_threshold = df[filter_by].quantile(0.505)

        # Define conditions and choices for categorization
        conditions = [
            _df[filter_by] <= low_threshold,  # Bottom
            _df[filter_by].between(medium_lower_threshold, medium_upper_threshold),  # Middle
            _df[filter_by] >= high_threshold,  # Top
        ]
        # Low entropy (one dominant benchmark match) counts as easy; for max similarity the order is reversed.
        choices = ["easy", "medium", "hard"] if filter_by == "entropy" else ["hard", "medium", "easy"]

        # Use np.select to create the new column based on the conditions
        _df["difficulty"] = np.select(conditions, choices, default="not relevant")

        selected_entity_types.extend(_df[_df["difficulty"] == setting]["entity"].tolist())

    new_dataset = []
    for dp in tqdm(dataset):
        matched_entities = [x for x in dp["ner"] if x[-1].lower().strip() in selected_entity_types]
        if matched_entities:
            new_dp = copy.deepcopy(dp)
            new_dp["ner"] = matched_entities
            new_dataset.append(new_dp)

    return new_dataset


def compute_new_splits():
    # TODO: you need to load the data into two variables: 'benchmarks' and 'training_datasets'.
    # 'benchmarks' should be a dictionary with the benchmark names as keys and the (list of distinct) entity types as values.
    # 'training_datasets' should be a dictionary with the training dataset names as keys and the (list of distinct) entity types as values.
    # We process multiple benchmarks and training datasets in this example, but you can adjust the code to fit your needs.
    # Further, we stick with the following dataset layout: list of dictionaries, where each dictionary represents a data point.
    # For example: [{'tokenized_text': [...], 'ner': [(start, end, entity_type), ...]}, ...]

    benchmarks = {}
    for benchmark_name in ['path/to/eval/dataset1.json', 'path/to/eval/dataset2.json']:
        # Data loading logic here, e.g.:
        # tokens, entity_types = load_eval_dataset(benchmark_name)
        # benchmarks[benchmark_name] = list(entity_types)
        pass

    training_datasets = {}
    for train_dataset_name in ['path/to/train/dataset1.json', 'path/to/train/dataset2.json']:
        # Data loading logic here, e.g.:
        # tokens, entity_types = load_train_dataset(train_dataset_name)
        # training_datasets[train_dataset_name] = list(entity_types)
        pass
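
    # After loading, the two dictionaries are expected to look roughly like this (hypothetical entity types):
    # benchmarks = {"path/to/eval/dataset1.json": ["person", "location", ...], ...}
    # training_datasets = {"path/to/train/dataset1.json": ["company", "city", ...], ...}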

    batch_size = 256
    model = SentenceTransformer("all-mpnet-base-v2").to("cuda")
    eval_encodings = {}
    for benchmark_name, entity_types in benchmarks.items():
        embeddings = model.encode(entity_types, convert_to_tensor=True, device="cuda")
        eval_encodings[benchmark_name] = embeddings

    results = {}
    for dataset_name, entity_types in training_datasets.items():
        # Strip the file extension once so the cleaned name is used consistently as the results key.
        dataset_name = dataset_name.split(".")[0]
        for i in tqdm(range(0, len(entity_types), batch_size)):
            batch = entity_types[i : i + batch_size]
            embeddings = model.encode(batch, convert_to_tensor=True, device="cuda")
            for benchmark_name, eval_embeddings in eval_encodings.items():
                # Pairwise cosine similarity between each training entity type and every benchmark entity type.
                similarities = torch.clamp(
                    cosine_similarity(
                        embeddings.unsqueeze(1),
                        eval_embeddings.unsqueeze(0),
                        dim=2,
                    ),
                    min=0.0,
                    max=1.0,
                )
                # Sharpen the similarity distribution with a low softmax temperature, then score each
                # training entity type by the entropy of that distribution and by its maximum similarity.
                probabilities = torch.nn.functional.softmax(similarities / 0.01, dim=1)
                entropy_values = -torch.sum(probabilities * torch.log(probabilities + 1e-10), dim=1)
                max_values, _ = torch.max(similarities, dim=1)

                if dataset_name not in results:
                    results[dataset_name] = {}
                if benchmark_name not in results[dataset_name]:
                    results[dataset_name][benchmark_name] = {}

                for j, entity in enumerate(batch):
                    if entity not in results[dataset_name][benchmark_name]:
                        results[dataset_name][benchmark_name][entity] = {}
                    results[dataset_name][benchmark_name][entity]["entropy"] = entropy_values[j].item()
                    results[dataset_name][benchmark_name][entity]["max"] = max_values[j].item()

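    # At this point, results has the nested shape
    #   results[train_dataset][eval_dataset][entity_type] = {"entropy": float, "max": float},
    # which is flattened below into one row per (train_dataset, eval_dataset, entity_type).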
    entries = []
    for dataset_name, eval_comparisons in results.items():
        for benchmark_name, mapping in eval_comparisons.items():
            for entity, values in mapping.items():
                entries.append(
                    {
                        "entity": entity,
                        "entropy": values["entropy"],
                        "max": values["max"],
                        "eval_dataset": benchmark_name,
                        "train_dataset": dataset_name,
                    }
                )
    df = pd.DataFrame(entries)
    df.to_pickle("new_splits.pkl")
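

# A minimal usage sketch (hypothetical dataset names and paths): compute the similarity statistics
# once, then build and push the splits for a training dataset. The argument passed to
# create_splits_for_hf_hub is a placeholder; it must match both the JSON file under path/to/train/
# and the train_dataset key written to new_splits.pkl by compute_new_splits.
if __name__ == "__main__":
    compute_new_splits()
    create_splits_for_hf_hub("dataset1")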