This repository was archived by the owner on Nov 8, 2022. It is now read-only.

Commit 060f0cd

shira-g authored and Peter Izsak committed
Pseudo Labeling Distillation for token classification (#111)
* add pseudo-labeling procedure
1 parent 4df6568 commit 060f0cd


22 files changed: 454 additions (+) and 87 deletions (−)

docs-source/source/transformers_distillation.rst

Lines changed: 15 additions & 1 deletion

@@ -39,7 +39,7 @@ One approach is similar to the method in Hinton 2015 [#]_. The loss function is
 modified to include a measure of distributions divergence, which can be measured
 using KL divergence or MSE between the logits of the student and the teacher network.
 
-:math:`loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T || logits_{teacher} / T)`
+:math:`loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T || logits_{teacher} / T)`
 
 where *T* is a value representing temperature for softening the logits prior to
 applying softmax. `loss_{student}` is the original loss of the student network
@@ -73,5 +73,19 @@ Usage:
 .. note::
     More models supporting distillation will be added in next releases
 
+Pseudo Labeling
+================
+
+This method can be used to produce pseudo-labels when training the student on unlabeled examples.
+The pseudo-label is produced by applying arg max to the logits of the teacher model, resulting in the following loss:
+
+.. math::
+
+    loss = \begin{cases} CE(\hat{y}, y) & \text{labeled example} \\ CE(\hat{y}, \hat{y}_t) & \text{unlabeled example} \end{cases}
+
+where CE is the Cross Entropy loss, :math:`\hat{y}` is the entity label class predicted by the student model and
+:math:`\hat{y}_t` is the label predicted by the teacher model.
+
 
 .. [#] Distilling the Knowledge in a Neural Network: Geoffrey Hinton, Oriol Vinyals, Jeff Dean, https://arxiv.org/abs/1503.02531
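To make the piecewise loss in the new docs section concrete, here is a minimal PyTorch sketch of a pseudo-label loss for token classification. The function name pseudo_label_loss, the tensor shapes, and the ignore_index=0 convention are illustrative assumptions; the library's actual implementation is the train_pseudo method added to nlp_architect/models/tagging.py further down in this commit.

# Minimal sketch (not the library API): CE against gold labels when they exist,
# otherwise CE against the teacher's per-token argmax pseudo-labels.
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

def pseudo_label_loss(student_logits, teacher_logits, gold_labels=None, ignore_index=0):
    """CE(y_hat, y) for a labeled batch, CE(y_hat, y_hat_t) for an unlabeled one."""
    loss_fn = CrossEntropyLoss(ignore_index=ignore_index)
    num_labels = student_logits.size(-1)
    if gold_labels is None:
        # Unlabeled example: the pseudo-label is the teacher's argmax per token.
        gold_labels = torch.argmax(F.log_softmax(teacher_logits, dim=2), dim=2)
    return loss_fn(student_logits.view(-1, num_labels), gold_labels.view(-1))

# Example with random tensors: (batch, seq_len, num_labels) logits, (batch, seq_len) labels.
student_logits = torch.randn(2, 5, 9)
teacher_logits = torch.randn(2, 5, 9)
loss_unlabeled = pseudo_label_loss(student_logits, teacher_logits)            # teacher argmax
loss_labeled = pseudo_label_loss(student_logits, teacher_logits,
                                 gold_labels=torch.randint(1, 9, (2, 5)))     # gold labels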

nlp_architect/common/config.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 # limitations under the License.
 # ******************************************************************************
 """
-Generic config object:
+Generic config object:
 load config from json file
 load config from ordinary python dict
 export config as dictionaty or json string

nlp_architect/data/sequential_tagging.py

Lines changed: 6 additions & 6 deletions

@@ -326,14 +326,14 @@ def _read_examples(self, data_dir, file_name, set_name):
         return self._create_examples(read_column_tagged_file(os.path.join(data_dir, file_name),
                                                               tag_col=self.tag_col), set_name)
 
-    def get_train_examples(self):
-        return self._read_examples(self.data_dir, "train.txt", "train")
+    def get_train_examples(self, filename="train.txt"):
+        return self._read_examples(self.data_dir, filename, "train")
 
-    def get_dev_examples(self):
-        return self._read_examples(self.data_dir, "dev.txt", "dev")
+    def get_dev_examples(self, filename="dev.txt"):
+        return self._read_examples(self.data_dir, filename, "dev")
 
-    def get_test_examples(self):
-        return self._read_examples(self.data_dir, "test.txt", "test")
+    def get_test_examples(self, filename="test.txt"):
+        return self._read_examples(self.data_dir, filename, "test")
 
     # pylint: disable=arguments-differ
     def get_labels(self):
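The new filename argument lets the processor read an arbitrary split file, for example the labeled/unlabeled splits used for pseudo-labeling, instead of the hard-coded train.txt. A hedged usage sketch follows; the TokenClsProcessor class name, its constructor arguments, and the file names are assumptions for illustration, only the get_*_examples(filename=...) signatures come from the diff above.

# Hedged usage sketch -- class name and constructor are assumptions.
from nlp_architect.data.sequential_tagging import TokenClsProcessor

processor = TokenClsProcessor("/path/to/tagged_data")                     # hypothetical data dir
labeled = processor.get_train_examples(filename="train_labeled.txt")      # custom labeled split
unlabeled = processor.get_train_examples(filename="train_unlabeled.txt")  # custom unlabeled split
dev = processor.get_dev_examples()                                        # still defaults to dev.txt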

nlp_architect/data/utils.py

Lines changed: 26 additions & 9 deletions

@@ -35,15 +35,15 @@ def __init__(self, guid: str, text, label=None):
 class DataProcessor(object):
     """Base class for data converters for sequence/token classification data sets."""
 
-    def get_train_examples(self, data_dir):
+    def get_train_examples(self):
         """Gets a collection of `InputExample`s for the train set."""
         raise NotImplementedError()
 
-    def get_dev_examples(self, data_dir):
+    def get_dev_examples(self):
         """Gets a collection of `InputExample`s for the dev set."""
         raise NotImplementedError()
 
-    def get_test_examples(self, data_dir):
+    def get_test_examples(self):
         """Gets a collection of `InputExample`s for the test set."""
         raise NotImplementedError()
 
@@ -66,12 +66,6 @@ def __init__(self, name: str, processor: DataProcessor, data_dir: str, task_type
         self.data_dir = data_dir
         self.task_type = task_type
 
-    def get_split_train_examples(self, labeled: int, unlabeled: int):
-        """split the train set into 2 sub sets (given by input size) to be
-        used as labelled and unlabeled sets for semi-supervision tasks
-        """
-        return self.processor.get_split_train_examples(self.data_dir, labeled, unlabeled)
-
     def get_train_examples(self):
         return self.processor.get_train_examples(self.data_dir)
 
@@ -154,3 +148,26 @@ def sample_label_unlabeled(samples: List[InputExample], no_labeled: int, no_unla
     label_samples = [samples[i] for i in labeled_indices]
     unlabel_samples = [samples[i] for i in unlabeled_indices]
     return label_samples, unlabel_samples
+
+
+def split_column_dataset(
+        first_count: int, second_count: int, out_folder, dataset, first_filename, second_filename, tag_col=-1):
+    """
+    Splits a single column tagged dataset into two files according to the amount of examples
+    requested to be included in each file.
+    first_count (int) : the number of examples to include in the first split file
+    second_count (int) : the number of examples to include in the second split file
+    out_folder (str) : the folder in which the result files will be stored
+    dataset (str) : the path to the original data file
+    first_filename (str) : the name of the first split file
+    second_filename (str) : the name of the second split file
+    tag_col (int) : the index of the tag column
+    """
+    lines = read_column_tagged_file(dataset, tag_col=tag_col)
+    num_of_examples = len(lines)
+    assert first_count + second_count <= num_of_examples and first_count > 0 and second_count > 0
+    selected_lines = random.sample(lines, first_count + second_count)
+    first_data = selected_lines[:first_count]
+    second_data = selected_lines[first_count:]
+    write_column_tagged_file(out_folder + os.sep + first_filename, first_data)
+    write_column_tagged_file(out_folder + os.sep + second_filename, second_data)
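A hedged usage sketch of the split_column_dataset helper added above, which carves a column-tagged training file into a small labeled split and a larger split to be treated as unlabeled. The paths, counts, and file names below are illustrative only; the keyword names come from the signature in the diff.

from nlp_architect.data.utils import split_column_dataset

split_column_dataset(
    first_count=500,                   # examples written to the first (labeled) file
    second_count=5000,                 # examples written to the second (unlabeled) file
    out_folder="/path/to/out",         # hypothetical output directory
    dataset="/path/to/train.txt",      # original column-tagged file
    first_filename="train_labeled.txt",
    second_filename="train_unlabeled.txt",
    tag_col=-1)                        # tag is in the last column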

nlp_architect/models/absa/train/acquire_terms.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 from tqdm import tqdm
 
 from nlp_architect.models.absa import TRAIN_LEXICONS, LEXICONS_OUT
-from nlp_architect.models.absa import TRAIN_OUT, GENERIC_OP_LEX
+from nlp_architect.models.absa import GENERIC_OP_LEX
 from nlp_architect.models.absa.inference.data_types import Polarity
 from nlp_architect.models.absa.train.data_types import AspectTerm, \
     DepRelation, DepRelationTerm, LoadOpinionStopLists, LoadAspectStopLists, OpinionTerm, \

nlp_architect/models/matchlstm_ansptr.py

Lines changed: 4 additions & 2 deletions

@@ -526,12 +526,14 @@ def inference_mode(self, session, valid, vocab_tuple, num_examples, dropout=1.0,
             # Print Paragraph
             print("\n")
             print("Paragraph Number AA:", idx)
-            test_paragraph = [vocab_forward[ele].replace(" ", "") for ele in valid[idx][0] if ele != 0]
+            test_paragraph = [vocab_forward[ele].replace(
+                " ", "") for ele in valid[idx][0] if ele != 0]
             para_string = " ".join(map(str, test_paragraph))
             print(para_string)
 
             # Print corresponding Question
-            test_question = [vocab_forward[ele].replace(" ", "") for ele in valid[idx][1] if ele != 0]
+            test_question = [vocab_forward[ele].replace(
+                " ", "") for ele in valid[idx][1] if ele != 0]
             ques_string = " ".join(map(str, test_question))
             print("Question:", ques_string)
             question_ids = valid[idx][1]

nlp_architect/models/tagging.py

Lines changed: 143 additions & 3 deletions

@@ -221,11 +221,12 @@ def train(self, train_data_set: DataLoader,
         for _ in epoch_it:
             step_it = tqdm(train_data_set, desc="Train iteration")
             avg_loss = 0
-            for s_idx, batch in enumerate(step_it):
+            for step, batch in enumerate(step_it):
                 self.model.train()
                 if distiller:
                     batch, t_batch = batch[:2]
                     t_batch = tuple(t.to(self.device) for t in t_batch)
+                    t_logits = distiller.get_teacher_logits(t_batch)
                 batch = tuple(t.to(self.device) for t in batch)
                 inputs = self.batch_mapper(batch)
                 logits = self.model(**inputs)
@@ -239,7 +240,6 @@ def train(self, train_data_set: DataLoader,
 
                 # add distillation loss if activated
                 if distiller:
-                    t_logits = distiller.get_teacher_logits(t_batch)
                     loss = distiller.distill_loss(loss, logits, t_logits)
 
                 loss.backward()
@@ -251,17 +251,157 @@ def train(self, train_data_set: DataLoader,
                 global_step += 1
                 avg_loss += loss.item()
                 if global_step % logging_steps == 0:
-                    logger.info(" global_step = %s, average loss = %s", global_step, avg_loss / s_idx)
+                    if step != 0:
+                        logger.info(
+                            " global_step = %s, average loss = %s", global_step, avg_loss / step)
                     self._get_eval(dev_data_set, "dev")
                     self._get_eval(test_data_set, "test")
                     if save_path is not None and global_step % save_steps == 0:
                         self.save_model(save_path)
 
+    def train_pseudo(
+            self, labeled_data_set: DataLoader,
+            unlabeled_data_set: DataLoader,
+            distiller: TeacherStudentDistill,
+            dev_data_set: DataLoader = None,
+            test_data_set: DataLoader = None,
+            batch_size_l: int = 8,
+            batch_size_ul: int = 8,
+            epochs: int = 100,
+            optimizer=None,
+            max_grad_norm: float = 5.0,
+            logging_steps: int = 50,
+            save_steps: int = 100,
+            save_path: str = None,
+            save_best: bool = False):
+        """
+        Train a tagging model with pseudo-labeling distillation
+
+        Args:
+            labeled_data_set (DataLoader): labeled train examples dataloader. Examples should
+                contain a tuple of student/teacher data examples.
+            unlabeled_data_set (DataLoader): unlabeled train examples dataloader. Examples should
+                contain a tuple of student/teacher data examples.
+            distiller (TeacherStudentDistill): KD model for training the model using
+                a teacher model.
+            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
+            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
+            batch_size_l (int, optional): batch size for the labeled dataset. Defaults to 8.
+            batch_size_ul (int, optional): batch size for the unlabeled dataset. Defaults to 8.
+            epochs (int, optional): num of epochs to train. Defaults to 100.
+            optimizer (fn, optional): optimizer function. Defaults to default model optimizer.
+            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
+            logging_steps (int, optional): number of steps between logging. Defaults to 50.
+            save_steps (int, optional): number of steps between model saves. Defaults to 100.
+            save_path (str, optional): model output path. Defaults to None.
+            save_best (bool, optional): whether to save the model when the dev-set result improves.
+        """
+        if optimizer is None:
+            optimizer = self.get_optimizer()
+        train_batch_size_l = batch_size_l * max(1, self.n_gpus)
+        train_batch_size_ul = batch_size_ul * max(1, self.n_gpus)
+        logger.info("***** Running training *****")
+        logger.info("  Num labeled examples = %d", len(labeled_data_set.dataset))
+        logger.info("  Num unlabeled examples = %d", len(unlabeled_data_set.dataset))
+        logger.info("  Instantaneous labeled batch size per GPU/CPU = %d",
+                    batch_size_l)
+        logger.info("  Instantaneous unlabeled batch size per GPU/CPU = %d",
+                    batch_size_ul)
+        logger.info("  Total batch size labeled= %d", train_batch_size_l)
+        logger.info("  Total batch size unlabeled= %d", train_batch_size_ul)
+        global_step = 0
+        self.model.zero_grad()
+        avg_loss = 0
+        iter_l = iter(labeled_data_set)
+        iter_ul = iter(unlabeled_data_set)
+        epoch_l = 0
+        epoch_ul = 0
+        s_idx = -1
+        best_dev = 0
+        best_test = 0
+        while(True):
+            logger.info("labeled epoch=%d, unlabeled epoch=%d", epoch_l, epoch_ul)
+            loss_labeled = 0
+            loss_unlabeled = 0
+            try:
+                batch_l = next(iter_l)
+                s_idx += 1
+            except StopIteration:
+                iter_l = iter(labeled_data_set)
+                epoch_l += 1
+                batch_l = next(iter_l)
+                s_idx = 0
+                avg_loss = 0
+            try:
+                batch_ul = next(iter_ul)
+            except StopIteration:
+                iter_ul = iter(unlabeled_data_set)
+                epoch_ul += 1
+                batch_ul = next(iter_ul)
+            if epoch_ul > epochs:
+                logger.info("Done")
+                return
+            self.model.train()
+            batch_l, t_batch_l = batch_l[:2]
+            batch_ul, t_batch_ul = batch_ul[:2]
+            t_batch_l = tuple(t.to(self.device) for t in t_batch_l)
+            t_batch_ul = tuple(t.to(self.device) for t in t_batch_ul)
+            t_logits = distiller.get_teacher_logits(t_batch_l)
+            t_logits_ul = distiller.get_teacher_logits(t_batch_ul)
+            batch_l = tuple(t.to(self.device) for t in batch_l)
+            batch_ul = tuple(t.to(self.device) for t in batch_ul)
+            inputs = self.batch_mapper(batch_l)
+            inputs_ul = self.batch_mapper(batch_ul)
+            logits = self.model(**inputs)
+            logits_ul = self.model(**inputs_ul)
+            t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)
+            if self.use_crf:
+                loss_labeled = -1.0 * self.crf(
+                    logits, inputs['labels'], mask=inputs['mask'] != 0.0)
+                loss_unlabeled = -1.0 * self.crf(
+                    logits_ul, t_labels, mask=inputs_ul['mask'] != 0.0)
+            else:
+                loss_fn = CrossEntropyLoss(ignore_index=0)
+                loss_labeled = loss_fn(logits.view(-1, self.num_labels), inputs['labels'].view(-1))
+                loss_unlabeled = loss_fn(logits_ul.view(-1, self.num_labels), t_labels.view(-1))
+
+            if self.n_gpus > 1:
+                loss_labeled = loss_labeled.mean()
+                loss_unlabeled = loss_unlabeled.mean()
+
+            # add distillation loss
+            loss_labeled = distiller.distill_loss(loss_labeled, logits, t_logits)
+            loss_unlabeled = distiller.distill_loss(loss_unlabeled, logits_ul, t_logits_ul)
+
+            # sum labeled and unlabeled losses
+            loss = loss_labeled + loss_unlabeled
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+            optimizer.step()
+            # self.model.zero_grad()
+            optimizer.zero_grad()
+            global_step += 1
+            avg_loss += loss.item()
+            if global_step % logging_steps == 0:
+                if s_idx != 0:
+                    logger.info(
+                        " global_step = %s, average loss = %s", global_step, avg_loss / s_idx)
+                dev = self._get_eval(dev_data_set, "dev")
+                test = self._get_eval(test_data_set, "test")
+                if dev > best_dev:
+                    best_dev = dev
+                    best_test = test
+                    if save_path is not None and save_best:
+                        self.save_model(save_path)
+                logger.info("Best result: dev= %s, test= %s", str(best_dev), str(best_test))
+                if save_path is not None and global_step % save_steps == 0:
+                    self.save_model(save_path)
+
     def _get_eval(self, ds, set_name):
         if ds is not None:
             logits, out_label_ids = self.evaluate(ds)
             res = self.evaluate_predictions(logits, out_label_ids)
             logger.info(" {} set F1 = {}".format(set_name, res['f1']))
+            return res['f1']
+        return None
 
     def to(self, device='cpu', n_gpus=0):
         """

nlp_architect/models/transformers/base_model.py

Lines changed: 2 additions & 1 deletion

@@ -199,7 +199,8 @@ def load_model(cls, model_path: str, model_type: str, *args, **kwargs):
             raise FileNotFoundError
         with io.open(model_path + os.sep + 'labels.txt') as fp:
             labels = [l.strip() for l in fp.readlines()]
-        return cls(model_type=model_type, model_name_or_path=model_path, labels=labels, *args, **kwargs)
+        return cls(
+            model_type=model_type, model_name_or_path=model_path, labels=labels, *args, **kwargs)
 
     @staticmethod
     def get_train_steps_epochs(max_steps: int,

nlp_architect/models/transformers/quantized_bert.py

Lines changed: 8 additions & 4 deletions

@@ -213,7 +213,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, from_8bit=False,
 
         # Instantiate model.
         model = cls(config)
-        # Set model to initialize variables to be loaded from quantized checkpoint which are None by Default
+        # Set model to initialize variables to be loaded from quantized
+        # checkpoint which are None by Default
         model.eval()
         # Get state dict of model
         state_dict = torch.load(model_file, map_location='cpu')
@@ -232,17 +233,20 @@
         def load(module, prefix=''):
             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
             module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+                state_dict, prefix, local_metadata, True, missing_keys,
+                unexpected_keys, error_msgs)
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, prefix + name + '.')
 
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ''
         model_to_load = model
-        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+        if not hasattr(model, cls.base_model_prefix) and any(
+                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
             start_prefix = cls.base_model_prefix + '.'
-        if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+        if hasattr(model, cls.base_model_prefix) and not any(
+                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
             model_to_load = getattr(model, cls.base_model_prefix)
 
         load(model_to_load, prefix=start_prefix)
