
Commit 85dd893

Support Multi-Supervised All Datasets
1 parent 2ac056f commit 85dd893

File tree

4 files changed: +424 −0 lines changed

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Modified from: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/paraphrases/MultiDatasetDataLoader.py

import math
import logging
import random


class MultiDatasetDataLoader:
    def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1):
        self.allow_swap = True
        # collate_fn is assigned externally (e.g. SentenceTransformer.fit sets its smart batching collate);
        # it must exist here, otherwise __iter__ would raise AttributeError.
        self.collate_fn = None
        self.batch_size_pairs = batch_size_pairs
        self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets

        # Compute dataset weights
        self.dataset_lengths = list(map(len, datasets))
        self.dataset_lengths_sum = sum(self.dataset_lengths)

        weights = []
        if dataset_size_temp > 0:  # Scale probability with dataset size
            for dataset in datasets:
                prob = len(dataset) / self.dataset_lengths_sum
                weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000)))
        else:  # Equal weighting of all datasets
            weights = [100] * len(datasets)

        logging.info("Dataset lengths and weights: {}".format(list(zip(self.dataset_lengths, weights))))

        self.dataset_idx = []
        self.dataset_idx_pointer = 0

        for idx, weight in enumerate(weights):
            self.dataset_idx.extend([idx] * weight)
        random.shuffle(self.dataset_idx)

        self.datasets = []
        for dataset in datasets:
            random.shuffle(dataset)
            self.datasets.append(
                {
                    "elements": dataset,
                    "pointer": 0,
                }
            )

    def __iter__(self):
        for _ in range(int(self.__len__())):
            # Select dataset
            if self.dataset_idx_pointer >= len(self.dataset_idx):
                self.dataset_idx_pointer = 0
                random.shuffle(self.dataset_idx)

            dataset_idx = self.dataset_idx[self.dataset_idx_pointer]
            self.dataset_idx_pointer += 1

            # Select batch from this dataset
            dataset = self.datasets[dataset_idx]
            batch_size = self.batch_size_pairs if len(dataset["elements"][0].texts) == 2 else self.batch_size_triplets

            batch = []
            while len(batch) < batch_size:
                example = dataset["elements"][dataset["pointer"]]

                if self.allow_swap and random.random() > 0.5:
                    example.texts[0], example.texts[1] = example.texts[1], example.texts[0]

                batch.append(example)

                dataset["pointer"] += 1
                if dataset["pointer"] >= len(dataset["elements"]):
                    dataset["pointer"] = 0
                    random.shuffle(dataset["elements"])

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        return int(self.dataset_lengths_sum / self.batch_size_pairs)

training/all/README.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# All Supervised Datasets

Inspired by [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), we fine-tuned Indonesian sentence embedding models on a collection of existing supervised datasets. The tasks covered by the training data are question answering, textual entailment, retrieval, commonsense reasoning, and natural language inference. Currently, our script simply concatenates these datasets, and our models are trained conventionally with `MultipleNegativesRankingLoss`.
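The full training loop lives in `train_all_mnrl.py`; the snippet below is only a minimal sketch of how the pieces could fit together, not the script's verbatim contents. It assumes the data loader in this commit is importable as `MultiDatasetDataLoader` and reuses the base model and hyperparameters from the command further down; the output path is a placeholder.

```python
# Hedged sketch of the training setup (assumed wiring, not train_all_mnrl.py itself).
from sentence_transformers import SentenceTransformer, models, losses

from all_datasets import WReTE, IndoLEMNTP, TyDiQA, FacQA, mMARCO, MIRACL, IndoStoryCloze, IndoNLI
from MultiDatasetDataLoader import MultiDatasetDataLoader  # assumed module name

# IndoBERT bi-encoder with mean pooling
word_embedding_model = models.Transformer("indobenchmark/indobert-base-p1", max_seq_length=128)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# One list of InputExample objects per source dataset; each list is all pairs or all
# triplets, which the loader uses to choose between the two batch sizes.
datasets = [
    cls.train_samples()
    for cls in (WReTE, IndoLEMNTP, TyDiQA, FacQA, mMARCO, MIRACL, IndoStoryCloze, IndoNLI)
]

# The loader draws each batch from a single dataset, so in-batch negatives stay within one task.
train_dataloader = MultiDatasetDataLoader(datasets, batch_size_pairs=384, batch_size_triplets=256)
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=5,
    optimizer_params={"lr": 2e-5},
    output_path="outputs/indobert-base-all-supervised",  # placeholder path
)
```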
## Training Data

| Dataset   |            Task             |   Type   | Number of Training Tuples |
| --------- | :-------------------------: | :------: | :-----------------------: |
| indonli   | Natural Language Inference | triplets |           3,914           |
|           |                             |          |                           |
|           |                             |          |                           |
|           |                             |          |                           |
|           |                             |          |                           |
|           |                             |          |                           |
|           |                             |          |                           |
|           |                             |          |                           |
| **Total** |                             |          |        **135,258**        |

## All Supervised Datasets with MultipleNegativesRankingLoss

### IndoBERT Base

```sh
python train_all_mnrl.py \
    --model-name indobenchmark/indobert-base-p1 \
    --max-seq-length 128 \
    --num-epochs 5 \
    --train-batch-size-pairs 384 \
    --train-batch-size-triplets 256 \
    --learning-rate 2e-5
```

## References

```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```

training/all/all_datasets.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
from typing import List
from dataclasses import dataclass
import random

from datasets import load_dataset
from sentence_transformers import InputExample

##############
# PAIRS
##############


@dataclass
class WReTE:
    dataset = load_dataset("SEACrowd/wrete", split="train", trust_remote_code=True)
    # filter for entailment pairs
    dataset = dataset.filter(lambda example: example["label"] == "Entail_or_Paraphrase")

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in WReTE.dataset:
            train_samples.append(InputExample(texts=[datum["sent_A"], datum["sent_B"]]))

        return train_samples


@dataclass
class IndoLEMNTP:
    dataset = load_dataset("SEACrowd/indolem_ntp", split="train", trust_remote_code=True)
    # filter for entailment pairs
    dataset = dataset.filter(lambda example: example["label"] == 1)

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in IndoLEMNTP.dataset:
            train_samples.append(InputExample(texts=[datum["tweets"], datum["next_tweet"]]))

        return train_samples


@dataclass
class TyDiQA:
    dataset = load_dataset("khalidalt/tydiqa-goldp", "indonesian", split="train", trust_remote_code=True).shuffle(
        seed=42
    )

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in TyDiQA.dataset:
            train_samples.append(InputExample(texts=[datum["question_text"], datum["passage_text"]]))
            train_samples.append(InputExample(texts=[datum["question_text"], datum["answers"]["text"][0]]))

        return train_samples


@dataclass
class FacQA:
    dataset = load_dataset("SEACrowd/facqa", split="train", trust_remote_code=True)

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in FacQA.dataset:
            question = " ".join(datum["question"])
            passage = " ".join(datum["passage"])
            answer = " ".join(t for t, l in zip(datum["passage"], datum["seq_label"]) if l != "O")

            train_samples.append(InputExample(texts=[question, passage]))
            train_samples.append(InputExample(texts=[question, answer]))

        return train_samples


##############
# TRIPLETS
##############


@dataclass
class mMARCO:
    dataset = load_dataset("unicamp-dl/mmarco", "indonesian", split="train", trust_remote_code=True)
    # limit to only 100,000 rows
    dataset = dataset.shuffle(seed=42).select(range(100_000))

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in mMARCO.dataset:
            train_samples.append(
                InputExample(
                    texts=[
                        datum["query"],
                        datum["positive"],
                        datum["negative"],
                    ]
                )
            )

        return train_samples


@dataclass
class MIRACL:
    dataset = load_dataset("miracl/miracl", "id", split="train", trust_remote_code=True)

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in MIRACL.dataset:
            query = datum["query"]
            positives = [doc["text"] for doc in datum["positive_passages"]]
            negatives = [doc["text"] for doc in datum["negative_passages"]]

            if len(negatives) > 0:
                train_samples.append(InputExample(texts=[query, random.choice(positives), random.choice(negatives)]))
                train_samples.append(InputExample(texts=[random.choice(positives), query, random.choice(negatives)]))

        return train_samples


@dataclass
class IndoStoryCloze:
    dataset = load_dataset("indolem/indo_story_cloze", split="train", trust_remote_code=True)

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_samples = []

        for datum in IndoStoryCloze.dataset:
            context = ". ".join([datum["sentence-1"], datum["sentence-2"], datum["sentence-3"], datum["sentence-4"]])
            train_samples.append(
                InputExample(
                    texts=[
                        context,
                        datum["correct_ending"],
                        datum["incorrect_ending"],
                    ]
                )
            )

        return train_samples


@dataclass
class IndoNLI:
    dataset = load_dataset("indonli", split="train", trust_remote_code=True)
    id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}

    @staticmethod
    def train_samples() -> List[InputExample]:
        def add_to_samples(sent1, sent2, label):
            if sent1 not in train_data:
                train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()}
            train_data[sent1][label].add(sent2)

        train_data = {}
        train_samples = []

        for datum in IndoNLI.dataset:
            sent1 = datum["premise"].strip()
            sent2 = datum["hypothesis"].strip()

            add_to_samples(sent1, sent2, IndoNLI.id2label[datum["label"]])
            add_to_samples(sent2, sent1, IndoNLI.id2label[datum["label"]])  # Also add the opposite

        for sent1, others in train_data.items():
            if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0:
                train_samples.append(
                    InputExample(
                        texts=[
                            sent1,
                            random.choice(list(others["entailment"])),
                            random.choice(list(others["contradiction"])),
                        ]
                    )
                )
                train_samples.append(
                    InputExample(
                        texts=[
                            random.choice(list(others["entailment"])),
                            sent1,
                            random.choice(list(others["contradiction"])),
                        ]
                    )
                )

        return train_samples
