-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutils_rover.py
More file actions
143 lines (127 loc) · 4.33 KB
/
utils_rover.py
File metadata and controls
143 lines (127 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os, sys
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import *
# Hugging Face checkpoint name used to instantiate the tokenizer below.
pretrained_model = "bert-base-cased"
# Module-level tokenizer shared by the collate/load functions in this file.
# NOTE: from_pretrained downloads/loads vocab files at import time.
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
# Maximum token length for encoded sequence pairs (see custom_collate_fn).
MAX_LEN = 512
# Three-way FEVER verdict <-> integer id maps (inverses of each other).
fever2id = {"NEI": 0, "REFUTES": 1, "SUPPORTS": 2}
id2fever = {0: "NEI", 1: "REFUTES", 2: "SUPPORTS"}
class customDataset(Dataset):
    """Minimal map-style Dataset over parallel sample/label sequences.

    Stores the two sequences as-is; indexing yields the ``(sample, label)``
    pair at that position. ``data`` and ``labels`` are expected to be the
    same length (one label per sample).
    """

    def __init__(self, data, labels):
        super().__init__()
        # Keep references to the parallel sequences; no copying.
        self.data = data
        self.labels = labels

    def __len__(self):
        # One entry per sample.
        return len(self.data)

    def __getitem__(self, index):
        # Pair up the sample and its label at the requested position.
        sample = self.data[index]
        label = self.labels[index]
        return sample, label
def custom_collate_fn(batch):
    """Collate ``(sample, label)`` pairs into a padded BERT input batch.

    Each ``sample`` is a dict with ``"context"`` and ``"text"`` keys; the
    pair is encoded as a two-segment (sentence-pair) input, padded and
    truncated to ``MAX_LEN``.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``token_type_ids``
        (from the tokenizer, as ``pt`` tensors) and int64 ``labels``.
    """
    batch_samples = []
    batch_labels = []
    for sample, label in batch:
        # [context, text] becomes segment A / segment B of the pair.
        batch_samples.append([sample["context"], sample["text"]])
        batch_labels.append(label)
    # `pad_to_max_length=True` and `truncation_strategy="longest_first"`
    # are deprecated in transformers; `padding="max_length"` and
    # `truncation="longest_first"` are the documented replacements with
    # the same behavior.
    batch_tokens = tokenizer.batch_encode_plus(
        batch_samples,
        max_length=MAX_LEN,
        padding="max_length",
        return_tensors="pt",
        add_special_tokens=True,
        truncation="longest_first",
    )
    batch_labels = torch.tensor(batch_labels, dtype=torch.int64)
    out_dict = {
        "input_ids": batch_tokens["input_ids"],
        "attention_mask": batch_tokens["attention_mask"],
        "token_type_ids": batch_tokens["token_type_ids"],
        "labels": batch_labels,
    }
    return out_dict
def load_dataset(file_path, batch_size, shuffle=False, num_workers=4):
    """Build a DataLoader from a jsonl file of context/question records.

    Each line is a JSON object with a ``"context"`` and a ``"questions"``
    list; every question contributes one sample (``context`` + question
    ``text``) labeled via ``fever2id``.

    Args:
        file_path: path to the jsonl input file.
        batch_size: DataLoader batch size.
        shuffle: whether to shuffle samples each epoch.
        num_workers: DataLoader worker processes (was hard-coded to 4).

    Returns:
        (DataLoader, number of samples).
    """
    data = []
    labels = []
    # Explicit encoding so behavior doesn't depend on the platform default.
    with open(file_path, "r", encoding="utf-8") as rf:
        for line in rf:
            json_dict = json.loads(line.strip())
            context = json_dict["context"]
            for question in json_dict["questions"]:
                data.append({"context": context, "text": question["text"]})
                labels.append(fever2id[question["label"]])
    dataset = customDataset(data, labels)
    data_loader = DataLoader(
        dataset,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=custom_collate_fn,
    )
    return data_loader, len(data)
def load_dataset_fever(file_path, batch_size, shuffle=False, num_workers=4):
    """Build a binary-label DataLoader from a FEVER-style jsonl file.

    Label mapping: ``SUPPORTS`` -> True; ``REFUTES`` / ``NOT ENOUGH INFO``
    -> False; any other label skips the example entirely. The evidence
    context is the space-joined evidence sentences (``evd_sent[2]``); the
    wiki title (``evd_sent[0]``) is intentionally not included.

    Args:
        file_path: path to the jsonl input file.
        batch_size: DataLoader batch size.
        shuffle: whether to shuffle samples each epoch.
        num_workers: DataLoader worker processes (was hard-coded to 4).

    Returns:
        (DataLoader, number of samples).
    """
    data = []
    labels = []
    with open(file_path, "r", encoding="utf-8") as rf:
        for line in rf:
            json_dict = json.loads(line.strip())
            label_str = json_dict["label"]
            if label_str == "SUPPORTS":
                label = True
            elif label_str in ("REFUTES", "NOT ENOUGH INFO"):
                label = False
            else:
                # Unknown verdict: drop the example.
                continue
            labels.append(label)
            claim = json_dict["claim"]
            # Join just the evidence sentences (index 2). To also prepend
            # the wiki title, include evd_sent[0] in the joined parts.
            # join + rstrip replaces the original quadratic `+=` loop and
            # keeps the same trailing-whitespace behavior.
            evidence = " ".join(
                evd_sent[2] for evd_sent in json_dict["ranked_evidence"]
            ).rstrip()
            data.append({"context": evidence, "text": claim})
    dataset = customDataset(data, labels)
    data_loader = DataLoader(
        dataset,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=custom_collate_fn,
    )
    return data_loader, len(data)
def load_dataset_symmetric_fever(file_path, batch_size, shuffle=False, num_workers=4):
    """Build a binary-label DataLoader from a symmetric-FEVER jsonl file.

    Each record must have matching ``"gold_label"`` and ``"label"`` fields
    of either ``SUPPORTS`` (-> True) or ``REFUTES`` (-> False); the
    ``"evidence"`` string is the context and ``"claim"`` the text.

    Args:
        file_path: path to the jsonl input file.
        batch_size: DataLoader batch size.
        shuffle: whether to shuffle samples each epoch.
        num_workers: DataLoader worker processes (was hard-coded to 4).

    Returns:
        (DataLoader, number of samples).

    Raises:
        ValueError: if the two label fields disagree or a label is neither
            SUPPORTS nor REFUTES. (Was `assert`, which is silently
            stripped under ``python -O``.)
    """
    data = []
    labels = []
    with open(file_path, "r", encoding="utf-8") as rf:
        for line in rf:
            json_dict = json.loads(line.strip())
            if json_dict["gold_label"] != json_dict["label"]:
                raise ValueError("two labels in the input")
            if json_dict["label"] == "SUPPORTS":
                label = True
            elif json_dict["label"] == "REFUTES":
                label = False
            else:
                raise ValueError("unexpected label")
            labels.append(label)
            claim = json_dict["claim"]
            evidence = json_dict["evidence"]
            data.append({"context": evidence, "text": claim})
    dataset = customDataset(data, labels)
    data_loader = DataLoader(
        dataset,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=custom_collate_fn,
    )
    return data_loader, len(data)