Skip to content

Commit 435be9b

Browse files
authored
Merge pull request #432 from cisco/vidamoda/pytorch_crf_bug_fix
Major bug fix for PyTorch CRF
2 parents 63e453b + 5b0f345 commit 435be9b

File tree

3 files changed

+71
-53
lines changed

3 files changed

+71
-53
lines changed

mindmeld/models/tagger_models.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ def _dump(self, path):
399399
self._clf.dump(path)
400400
if isinstance(self._clf, TorchCrfTagger):
401401
metadata.update({
402-
"model": self,
402+
"model_config": self.config,
403+
"feature_and_label_encoder": self._clf.get_torch_encoder(),
403404
"model_type": "torch-crf"
404405
})
405406
elif isinstance(self._clf, LstmModel):
@@ -432,25 +433,31 @@ def load(cls, path):
432433

433434
# If model is serializable, it can be loaded and used as-is. But if not serializable,
434435
# it means we need to create an instance and load necessary details for it to be used.
435-
if not is_serializable and metadata.get('model_type') == 'lstm':
436+
if not is_serializable:
436437
model = cls(metadata["model_config"])
438+
if metadata.get('model_type') == 'lstm':
439+
440+
# misc resources load
441+
try:
442+
model._current_params = metadata["current_params"]
443+
model._label_encoder = metadata["label_encoder"]
444+
model._no_entities = metadata["no_entities"]
445+
except KeyError: # backwards compatibility
446+
model_dir = metadata["model"]
447+
tagger_vars = joblib.load(model_dir, ".tagger_vars")
448+
model._current_params = tagger_vars["current_params"]
449+
model._label_encoder = tagger_vars["label_encoder"]
450+
model._no_entities = tagger_vars["no_entities"]
451+
452+
# underneath tagger load
453+
model._clf.load(model_dir)
454+
455+
# replace model dump directory with actual model
456+
elif metadata.get('model_type') == 'torch-crf':
457+
model._clf.set_params(**metadata["model_config"].params)
458+
model._clf.set_torch_encoder(metadata['feature_and_label_encoder'])
459+
model._clf.load(path)
437460

438-
# misc resources load
439-
try:
440-
model._current_params = metadata["current_params"]
441-
model._label_encoder = metadata["label_encoder"]
442-
model._no_entities = metadata["no_entities"]
443-
except KeyError: # backwards compatibility
444-
model_dir = metadata["model"]
445-
tagger_vars = joblib.load(model_dir, ".tagger_vars")
446-
model._current_params = tagger_vars["current_params"]
447-
model._label_encoder = tagger_vars["label_encoder"]
448-
model._no_entities = tagger_vars["no_entities"]
449-
450-
# underneath tagger load
451-
model._clf.load(model_dir)
452-
453-
# replace model dump directory with actual model
454461
metadata["model"] = model
455462

456463
return metadata["model"]

mindmeld/models/taggers/crf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,17 @@ def dump(self, path):
331331
best_model_save_path = os.path.join(os.path.split(path)[0], "best_crf_wts.pt")
332332
self._clf.save_best_weights_path(best_model_save_path)
333333

334+
def load(self, path):
335+
best_model_save_path = os.path.join(os.path.split(path)[0], "best_crf_wts.pt")
336+
self._clf.build_params(*self.get_torch_encoder().get_feats_and_classes())
337+
self._clf.load_best_weights_path(best_model_save_path)
338+
339+
def get_torch_encoder(self):
340+
return self._clf.get_encoder()
341+
342+
def set_torch_encoder(self, encoder):
343+
self._clf.set_encoder(encoder)
344+
334345

335346
# Feature extraction for CRF
336347

mindmeld/models/taggers/pytorch_crf.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy
77
from itertools import chain
88
from random import randint
9-
from tempfile import gettempdir
9+
from tempfile import mkdtemp
1010

1111
import numpy as np
1212
import torch
@@ -23,6 +23,8 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
TEST_BATCH_SIZE = 512
27+
2628

2729
class TaggerDataset(Dataset):
2830
"""PyTorch Dataset class used to handle tagger inputs, labels and mask"""
@@ -155,6 +157,9 @@ def __init__(self, feature_extractor="hash", num_feats=50000):
155157
self.classes = None
156158
self.num_feats = num_feats
157159

160+
def get_feats_and_classes(self):
161+
return self.num_feats, self.num_classes
162+
158163
def get_padded_transformed_tensors(self, inputs_or_labels, seq_lens, is_label):
159164
"""Returns the encoded and padded sparse tensor representations of the inputs/labels.
160165
@@ -253,7 +258,7 @@ class TorchCrfModel(nn.Module):
253258
def __init__(self):
254259
super().__init__()
255260
self.optim = None
256-
self.encoder = None
261+
self._encoder = None
257262
self.W = None
258263
self.b = None
259264
self.crf_layer = None
@@ -270,10 +275,13 @@ def __init__(self):
270275
self.optimizer = None
271276
self.random_state = None
272277

273-
self.best_model_save_path = None
274-
self.ready = False
275-
self.tmp_save_path = os.path.join(gettempdir(), "best_crf_wts.pt")
276-
# os.makedirs(os.path.dirname(self.tmp_save_path), exist_ok=True)
278+
self.tmp_save_path = os.path.join(mkdtemp(), "best_crf_wts.pt")
279+
280+
def get_encoder(self):
281+
return self._encoder
282+
283+
def set_encoder(self, encoder):
284+
self._encoder = encoder
277285

278286
def set_random_states(self):
279287
"""Sets the random seeds across all libraries used for deterministic output."""
@@ -287,10 +295,16 @@ def save_best_weights_path(self, path):
287295
Args:
288296
path (str): Path to save the best model weights.
289297
"""
290-
self.best_model_save_path = path
291-
if os.path.exists(self.tmp_save_path):
292-
best_weights = torch.load(self.tmp_save_path)
293-
torch.save(best_weights, self.best_model_save_path)
298+
torch.save(self.state_dict(), path)
299+
300+
def load_best_weights_path(self, path):
301+
"""Saves the best weights of the model to a path in the .generated folder.
302+
303+
Args:
304+
path (str): Path to load the best model weights from.
305+
"""
306+
if os.path.exists(path):
307+
self.load_state_dict(torch.load(path))
294308
else:
295309
raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")
296310

@@ -347,8 +361,8 @@ def forward(self, inputs, targets, mask, drop_input=0.0):
347361
if drop_input:
348362
dp_mask = (torch.FloatTensor(inputs.values().size()).uniform_() > drop_input)
349363
inputs.values()[:] = inputs.values() * dp_mask
350-
dense_W = torch.tile(self.W, dims=(mask.shape[0], 1))
351-
out_1 = torch.addmm(self.b, inputs, dense_W)
364+
dense_w = torch.tile(self.W, dims=(mask.shape[0], 1))
365+
out_1 = torch.addmm(self.b, inputs, dense_w)
352366
crf_input = out_1.reshape((mask.shape[0], -1, self.num_classes))
353367
if targets is None:
354368
return self.crf_layer.decode(crf_input, mask=mask)
@@ -485,10 +499,10 @@ def get_dataloader(self, X, y, is_train):
485499
torch_dataloader (torch.utils.data.dataloader.DataLoader): returns PyTorch dataloader object that can be
486500
used to iterate across the data.
487501
"""
488-
tensor_inputs, input_seq_lens, tensor_labels = self.encoder.get_tensor_data(X, y, fit=is_train)
502+
tensor_inputs, input_seq_lens, tensor_labels = self._encoder.get_tensor_data(X, y, fit=is_train)
489503
tensor_dataset = TaggerDataset(tensor_inputs, input_seq_lens, tensor_labels)
490-
torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else 512, shuffle=is_train,
491-
collate_fn=collate_tensors_and_masks)
504+
torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else TEST_BATCH_SIZE,
505+
shuffle=is_train, collate_fn=collate_tensors_and_masks)
492506
return torch_dataloader
493507

494508
def fit(self, X, y):
@@ -500,7 +514,7 @@ def fit(self, X, y):
500514
entity objects)
501515
"""
502516
self.set_random_states()
503-
self.encoder = Encoder(feature_extractor=self.feat_type, num_feats=self.feat_num)
517+
self._encoder = Encoder(feature_extractor=self.feat_type, num_feats=self.feat_num)
504518
stratify_tuples = None
505519
if self.stratify_train_val_split:
506520
X, y, stratify_tuples = stratify_input(X, y)
@@ -516,15 +530,15 @@ def fit(self, X, y):
516530
del X, y, train_X, train_y, dev_X, dev_y, stratify_tuples
517531
gc.collect()
518532

519-
self.build_params(self.encoder.num_feats, self.encoder.num_classes)
533+
self.build_params(*self._encoder.get_feats_and_classes())
520534

521535
if self.optimizer == "sgd":
522536
self.optim = optim.SGD(self.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-5)
523537
if self.optimizer == "adam":
524-
self.optim = optim.Adam(self.parameters(), weight_decay=1e-5)
538+
self.optim = optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-5)
525539

526540
self.training_loop(train_dataloader, dev_dataloader)
527-
self.ready = True
541+
self.load_state_dict(torch.load(self.tmp_save_path))
528542

529543
def training_loop(self, train_dataloader, dev_dataloader):
530544
"""Contains the training loop process where we train the model for specified number of epochs.
@@ -605,13 +619,6 @@ def predict_marginals(self, X):
605619
Returns:
606620
marginals_dict (list of list of dicts): Returns the probability of every tag for each token in a sequence.
607621
"""
608-
if self.ready:
609-
if self.best_model_save_path:
610-
self.load_state_dict(torch.load(self.best_model_save_path))
611-
else:
612-
self.load_state_dict(torch.load(self.tmp_save_path))
613-
else:
614-
raise MindMeldError("PyTorch-CRF Model does not seem to be trained. Train before running predictions.")
615622
dataloader = self.get_dataloader(X, None, is_train=False)
616623
marginals_dict = []
617624
self.eval()
@@ -626,7 +633,7 @@ def predict_marginals(self, X):
626633
one_seq_list = []
627634
for (token_probs, valid_token) in zip(seq, mask_seq):
628635
if valid_token:
629-
one_seq_list.append(dict(zip(self.encoder.classes, token_probs)))
636+
one_seq_list.append(dict(zip(self._encoder.classes, token_probs)))
630637
marginals_dict.append(one_seq_list)
631638

632639
return marginals_dict
@@ -639,14 +646,7 @@ def predict(self, X):
639646
Returns:
640647
preds (list of lists): Predictions for each token in each sequence.
641648
"""
642-
if self.ready:
643-
if self.best_model_save_path:
644-
self.load_state_dict(torch.load(self.best_model_save_path))
645-
else:
646-
self.load_state_dict(torch.load(self.tmp_save_path))
647-
else:
648-
raise MindMeldError("PyTorch-CRF Model does not seem to be trained. Train before running predictions.")
649649
dataloader = self.get_dataloader(X, None, is_train=False)
650650

651651
preds = self.run_predictions(dataloader, calc_f1=False)
652-
return [self.encoder.label_encoder.inverse_transform(x).tolist() for x in preds]
652+
return [self._encoder.label_encoder.inverse_transform(x).tolist() for x in preds]

0 commit comments

Comments
 (0)