forked from EuroEval/EuroEval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_estner.py
More file actions
114 lines (91 loc) · 2.9 KB
/
create_estner.py
File metadata and controls
114 lines (91 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# /// script
# requires-python = ">=3.10,<4.0"
# dependencies = [
# "datasets==4.0.0",
# "huggingface-hub==0.34.4",
# "requests==2.32.5",
# ]
# ///
"""Create the Estonian NER dataset and upload to HF Hub."""
from typing import MutableMapping
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
def main() -> None:
    """Build a reduced copy of the EstNER dataset and publish it on the HF Hub.

    The function loads "tartuNLP/EstNER", shuffles it with a fixed seed, keeps
    the leading rows of each split, renames the tag column to "labels", adds a
    "text" column made of the space-joined tokens, rewrites the labels with
    `convert_labels`, and finally replaces the target Hub repository with the
    result.
    """
    hub_repo = "EuroEval/estner-mini"

    # Fetch the complete corpus from its upstream repository.
    corpus = load_dataset("tartuNLP/EstNER")
    assert isinstance(corpus, DatasetDict)

    # Shuffle with a fixed seed before taking the leading rows of each split.
    corpus = corpus.shuffle(seed=42)

    # New split name -> (upstream split name, number of rows to keep).
    split_plan = {
        "train": ("train", 1024),
        "val": ("dev", 256),
        "test": ("test", 2048),
    }
    subsets = {}
    for new_name, (upstream_name, num_rows) in split_plan.items():
        subset = corpus[upstream_name].select(range(num_rows))
        assert isinstance(subset, Dataset)
        subsets[new_name] = subset

    corpus = DatasetDict(subsets)
    corpus = corpus.rename_column("ner_tags", "labels")

    def add_text(row) -> dict[str, str]:
        """Build the text column, which the upstream data does not provide."""
        return {"text": " ".join(row["tokens"])}

    corpus = corpus.map(add_text)
    corpus = corpus.select_columns(["text", "tokens", "labels"])

    # Pass every row through `convert_labels` to rewrite its labels.
    corpus = corpus.map(convert_labels)

    # Drop any previous version of the repository, then upload the new one.
    HfApi().delete_repo(hub_repo, repo_type="dataset", missing_ok=True)
    corpus.push_to_hub(hub_repo, private=True)
def get_label_map() -> dict[str, str]:
    """Get the map from original labels to the EuroEval ones.

    Only labels whose entity type is not one of PER, LOC, ORG or MISC get an
    entry in the mapping; any label missing from it is meant to be kept as-is
    by the caller.

    Returns:
        The mapping. Labels with a DATE, TIME, PERCENT or MONEY entity type map
        to "O". All other non-standard labels map to "B-MISC" or "I-MISC",
        keeping the B/I position prefix of the original label.
    """
    original_labels = {
        "I-DATE",
        "I-PERCENT",
        "I-TITLE",
        "B-PERCENT",
        "B-LOC",
        "B-MONEY",
        "I-GPE",
        "B-ORG",
        "B-PROD",
        "I-LOC",
        "B-GPE",
        "B-TITLE",
        "I-EVENT",
        "I-MONEY",
        "I-TIME",
        "B-EVENT",
        "B-TIME",
        "I-PROD",
        "I-PER",
        "B-DATE",
        "B-PER",
        "I-ORG",
    }
    standard_entity_types = ("PER", "LOC", "ORG", "MISC")
    dropped_entity_types = ("DATE", "TIME", "PERCENT", "MONEY")

    label_map: dict[str, str] = {}
    for label in original_labels:
        position, entity_type = label.split("-")
        if entity_type in standard_entity_types:
            # Already a target label, so no entry is needed.
            continue
        if entity_type in dropped_entity_types:
            label_map[label] = "O"
        else:
            # Collapse every remaining entity type into MISC. Re-using the
            # original entity type here would map the label onto itself and
            # leave the label set unreduced.
            label_map[label] = f"{position}-MISC"
    return label_map
def convert_labels(row: MutableMapping) -> MutableMapping:
    """Rewrite the labels of a single dataset row.

    Each entry of `row["labels"]` is looked up in the mapping returned by
    `get_label_map`; labels without an entry are kept unchanged.

    Args:
        row:
            A row of the original dataset. It is modified in place.

    Returns:
        The same row, with its "labels" value replaced by the converted list.
    """
    mapping = get_label_map()
    converted = []
    for original_label in row["labels"]:
        # Fall back to the label itself when the mapping has no entry for it.
        converted.append(mapping.get(original_label, original_label))
    row["labels"] = converted
    return row
# Run the dataset creation only when this file is executed as a script.
if __name__ == "__main__":
    main()