-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
163 lines (130 loc) · 5.56 KB
/
training.py
File metadata and controls
163 lines (130 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Required Imports
import datasets
import numpy as np
from transformers import BertTokenizerFast, ElectraTokenizerFast
from transformers import DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification
from transformers import TrainingArguments, Trainer
class NERTraining:
    """Fine-tunes a Hugging Face token-classification (NER) model end to end:
    dataset download, tokenization/label alignment, training, and saving."""

    def __init__(self, model_name, label_list=None) -> None:
        """
        Args:
            model_name: Hub identifier of the base model
                (e.g. "bert-base-uncased" or an Electra checkpoint).
            label_list: optional list of NER tag names (index == label id),
                required by ``compute_metrics`` to map ids back to tag strings.
                Typically ``dataset['train'].features['ner_tags'].feature.names``.
        """
        self.model_name = model_name
        self.label_list = label_list

    def get_dataset(self, dataset_name: str):
        """Downloads a dataset from the Hugging Face hub.

        Returns the DatasetDict on success, or None if the download fails
        (best-effort: the error is printed, not raised).
        """
        dataset = None
        try:
            dataset = datasets.load_dataset(dataset_name)
        except Exception as ex:
            print("Unable to download dataset - ", ex)
        return dataset

    def get_tokenizer(self):
        """Returns the fast tokenizer matching the selected model.

        Returns None if the model name is unrecognized or the download fails.
        """
        tokenizer = None
        try:
            if self.model_name.lower() == "bert-base-uncased":
                tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
            elif "electra" in self.model_name:
                # BUG FIX: previously this loaded the "bert-base-uncased"
                # tokenizer for Electra models; load the model's own checkpoint.
                tokenizer = ElectraTokenizerFast.from_pretrained(self.model_name)
        except Exception as ex:
            # BUG FIX: the original passed a "%s" format string to print(),
            # which printed the placeholder literally instead of the name.
            print("Unable to get tokenizer for the model -", self.model_name, ex)
        return tokenizer

    def format_labels(self, data, label_all=True):
        """Tokenizes ``data['tokens']`` and aligns ``data['ner_tags']`` to
        the resulting word pieces.

        Special tokens (word_id None) get label -100 so the loss ignores them.
        For continuation pieces of the same word, the word's label is repeated
        when ``label_all`` is True, otherwise -100 (only the first piece is
        scored).

        Returns the tokenized batch with an added ``labels`` key.
        """
        tokenizer = self.get_tokenizer()
        tokenized_input = tokenizer(data['tokens'], truncation=True, is_split_into_words=True)
        labels = []
        for i, label in enumerate(data['ner_tags']):
            word_ids = tokenized_input.word_ids(batch_index=i)
            label_ids = []
            pre_ind = None
            for wi in word_ids:
                if wi is None:
                    # Special tokens ([CLS], [SEP], padding) are masked out.
                    label_ids.append(-100)
                elif wi != pre_ind:
                    # First sub-token of a word keeps the word's label.
                    label_ids.append(label[wi])
                else:
                    # Continuation sub-token: repeat or mask per label_all.
                    label_ids.append(label[wi] if label_all else -100)
                pre_ind = wi
            labels.append(label_ids)
        tokenized_input['labels'] = labels
        return tokenized_input

    def get_model(self):
        """Returns the token-classification model head (9 labels, i.e. the
        CoNLL-2003 tag set), or None if the download fails."""
        model = None
        try:
            model = AutoModelForTokenClassification.from_pretrained(self.model_name, num_labels=9)
        except Exception as ex:
            print("Unable to download the model - ", self.model_name, ex)
        return model

    def set_arguments(self, m_args: dict):
        """Builds a TrainingArguments object from the given settings dict.

        Returns None if the settings are invalid (error is printed).
        """
        args = None
        try:
            args = TrainingArguments(**m_args)
        except Exception as ex:
            print("Unable to create args object based on the provided - ", ex)
        return args

    def get_data_collator(self, tokenizer):
        """Returns a padding data collator for token classification, or None."""
        data_collator = None
        try:
            data_collator = DataCollatorForTokenClassification(tokenizer)
        except Exception as ex:
            print("Data collator operation failed - ", ex)
        return data_collator

    def get_metrics(self):
        """Loads the seqeval metric for entity-level evaluation, or None."""
        metrics = None
        try:
            metrics = datasets.load_metric("seqeval")
        except Exception as ex:
            print("Unable to load metrics from seqeval - ", ex)
        return metrics

    def compute_metrics(self, p):
        """Computes seqeval precision/recall/F1/accuracy for a Trainer
        EvalPrediction-style ``(predictions, labels)`` pair.

        Raises:
            ValueError: if ``label_list`` was not supplied to ``__init__``.
        """
        # BUG FIX: the original referenced an undefined global `dataset`
        # (NameError at runtime); tag names now come from self.label_list.
        if self.label_list is None:
            raise ValueError(
                "label_list must be provided to NERTraining(...) before "
                "compute_metrics can map label ids to tag names"
            )
        label_list = self.label_list
        metrics = self.get_metrics()
        predictions, labels = p
        # Select the predicted index with maximum logit for each token.
        predictions = np.argmax(predictions, axis=2)
        # Model predictions, dropping masked (-100) positions.
        true_predictions = [
            [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        # Gold labels, dropping masked (-100) positions.
        true_labels = [
            [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        results = metrics.compute(predictions=true_predictions, references=true_labels)
        result_dict = {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }
        return result_dict

    def model_training(self, model, args, train_dataset, eval_dataset, data_collator, tokenizer, compute_metrics):
        """Trains the model with the given components.

        Returns the Trainer instance (for later evaluation/prediction), or
        None if construction or training failed.
        """
        # BUG FIX: trainer must be pre-initialized, otherwise the return
        # statement raises UnboundLocalError when Trainer(...) itself fails.
        trainer = None
        try:
            trainer = Trainer(
                model,
                args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                tokenizer=tokenizer,
                compute_metrics=compute_metrics
            )
            trainer.train()
        except Exception as ex:
            print("Unable to train the model - ", ex)
        return trainer

    def save_model(self, model, tokenizer, loc_name, label_list):
        """Saves the model and tokenizer artifacts under ``loc_name``.

        Note: ``label_list`` is accepted for interface compatibility but is
        not persisted here.
        """
        model.save_pretrained(loc_name)
        # BUG FIX: the tokenizer was saved to a hardcoded "tokenizer" folder,
        # contradicting the documented contract of saving to loc_name.
        tokenizer.save_pretrained(loc_name)
        print("Successfully saved the model :)")