66from copy import copy
77from itertools import chain
88from random import randint
9- from tempfile import gettempdir
9+ from tempfile import mkdtemp
1010
1111import numpy as np
1212import torch
2323
2424logger = logging .getLogger (__name__ )
2525
26+ TEST_BATCH_SIZE = 512
27+
2628
2729class TaggerDataset (Dataset ):
2830 """PyTorch Dataset class used to handle tagger inputs, labels and mask"""
@@ -155,6 +157,9 @@ def __init__(self, feature_extractor="hash", num_feats=50000):
155157 self .classes = None
156158 self .num_feats = num_feats
157159
def get_feats_and_classes(self):
    """Return the ``(num_feats, num_classes)`` pair recorded on this encoder.

    Returns:
        tuple: Number of input features and number of target classes.
    """
    feature_count = self.num_feats
    class_count = self.num_classes
    return feature_count, class_count
158163 def get_padded_transformed_tensors (self , inputs_or_labels , seq_lens , is_label ):
159164 """Returns the encoded and padded sparse tensor representations of the inputs/labels.
160165
@@ -253,7 +258,7 @@ class TorchCrfModel(nn.Module):
253258 def __init__ (self ):
254259 super ().__init__ ()
255260 self .optim = None
256- self .encoder = None
261+ self ._encoder = None
257262 self .W = None
258263 self .b = None
259264 self .crf_layer = None
@@ -270,10 +275,13 @@ def __init__(self):
270275 self .optimizer = None
271276 self .random_state = None
272277
273- self .best_model_save_path = None
274- self .ready = False
275- self .tmp_save_path = os .path .join (gettempdir (), "best_crf_wts.pt" )
276- # os.makedirs(os.path.dirname(self.tmp_save_path), exist_ok=True)
278+ self .tmp_save_path = os .path .join (mkdtemp (), "best_crf_wts.pt" )
279+
def get_encoder(self):
    """Return the encoder currently attached to this model.

    Returns:
        Encoder: The fitted encoder instance, or ``None`` if the model
        has not been trained yet (``__init__`` sets ``_encoder = None``).
    """
    return self._encoder
282+
def set_encoder(self, encoder):
    """Attach an encoder instance to this model.

    Args:
        encoder (Encoder): The encoder to store on the model; used later
        to build tensor data for training and prediction.
    """
    self._encoder = encoder
277285
278286 def set_random_states (self ):
279287 """Sets the random seeds across all libraries used for deterministic output."""
@@ -287,10 +295,16 @@ def save_best_weights_path(self, path):
287295 Args:
288296 path (str): Path to save the best model weights.
289297 """
290- self .best_model_save_path = path
291- if os .path .exists (self .tmp_save_path ):
292- best_weights = torch .load (self .tmp_save_path )
293- torch .save (best_weights , self .best_model_save_path )
298+ torch .save (self .state_dict (), path )
299+
def load_best_weights_path(self, path):
    """Loads the best model weights from a checkpoint file into this model.

    Note: the previous docstring ("Saves the best weights...") was copied
    from ``save_best_weights_path``; this method restores weights, it does
    not save them.

    Args:
        path (str): Path to a saved state dict (e.g. one produced by
            ``save_best_weights_path``).

    Raises:
        MindMeldError: If no checkpoint file exists at ``path``.
    """
    if os.path.exists(path):
        self.load_state_dict(torch.load(path))
    else:
        raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")
296310
@@ -347,8 +361,8 @@ def forward(self, inputs, targets, mask, drop_input=0.0):
347361 if drop_input :
348362 dp_mask = (torch .FloatTensor (inputs .values ().size ()).uniform_ () > drop_input )
349363 inputs .values ()[:] = inputs .values () * dp_mask
350- dense_W = torch .tile (self .W , dims = (mask .shape [0 ], 1 ))
351- out_1 = torch .addmm (self .b , inputs , dense_W )
364+ dense_w = torch .tile (self .W , dims = (mask .shape [0 ], 1 ))
365+ out_1 = torch .addmm (self .b , inputs , dense_w )
352366 crf_input = out_1 .reshape ((mask .shape [0 ], - 1 , self .num_classes ))
353367 if targets is None :
354368 return self .crf_layer .decode (crf_input , mask = mask )
@@ -485,10 +499,10 @@ def get_dataloader(self, X, y, is_train):
485499 torch_dataloader (torch.utils.data.dataloader.DataLoader): returns PyTorch dataloader object that can be
486500 used to iterate across the data.
487501 """
488- tensor_inputs , input_seq_lens , tensor_labels = self .encoder .get_tensor_data (X , y , fit = is_train )
502+ tensor_inputs , input_seq_lens , tensor_labels = self ._encoder .get_tensor_data (X , y , fit = is_train )
489503 tensor_dataset = TaggerDataset (tensor_inputs , input_seq_lens , tensor_labels )
490- torch_dataloader = DataLoader (tensor_dataset , batch_size = self .batch_size if is_train else 512 , shuffle = is_train ,
491- collate_fn = collate_tensors_and_masks )
504+ torch_dataloader = DataLoader (tensor_dataset , batch_size = self .batch_size if is_train else TEST_BATCH_SIZE ,
505+ shuffle = is_train , collate_fn = collate_tensors_and_masks )
492506 return torch_dataloader
493507
494508 def fit (self , X , y ):
@@ -500,7 +514,7 @@ def fit(self, X, y):
500514 entity objects)
501515 """
502516 self .set_random_states ()
503- self .encoder = Encoder (feature_extractor = self .feat_type , num_feats = self .feat_num )
517+ self ._encoder = Encoder (feature_extractor = self .feat_type , num_feats = self .feat_num )
504518 stratify_tuples = None
505519 if self .stratify_train_val_split :
506520 X , y , stratify_tuples = stratify_input (X , y )
@@ -516,15 +530,15 @@ def fit(self, X, y):
516530 del X , y , train_X , train_y , dev_X , dev_y , stratify_tuples
517531 gc .collect ()
518532
519- self .build_params (self .encoder . num_feats , self . encoder . num_classes )
533+ self .build_params (* self ._encoder . get_feats_and_classes () )
520534
521535 if self .optimizer == "sgd" :
522536 self .optim = optim .SGD (self .parameters (), lr = 0.01 , momentum = 0.9 , nesterov = True , weight_decay = 1e-5 )
523537 if self .optimizer == "adam" :
524- self .optim = optim .Adam (self .parameters (), weight_decay = 1e-5 )
538+ self .optim = optim .Adam (self .parameters (), lr = 0.001 , weight_decay = 1e-5 )
525539
526540 self .training_loop (train_dataloader , dev_dataloader )
527- self .ready = True
541+ self .load_state_dict ( torch . load ( self . tmp_save_path ))
528542
529543 def training_loop (self , train_dataloader , dev_dataloader ):
530544 """Contains the training loop process where we train the model for specified number of epochs.
@@ -605,13 +619,6 @@ def predict_marginals(self, X):
605619 Returns:
606620 marginals_dict (list of list of dicts): Returns the probability of every tag for each token in a sequence.
607621 """
608- if self .ready :
609- if self .best_model_save_path :
610- self .load_state_dict (torch .load (self .best_model_save_path ))
611- else :
612- self .load_state_dict (torch .load (self .tmp_save_path ))
613- else :
614- raise MindMeldError ("PyTorch-CRF Model does not seem to be trained. Train before running predictions." )
615622 dataloader = self .get_dataloader (X , None , is_train = False )
616623 marginals_dict = []
617624 self .eval ()
@@ -626,7 +633,7 @@ def predict_marginals(self, X):
626633 one_seq_list = []
627634 for (token_probs , valid_token ) in zip (seq , mask_seq ):
628635 if valid_token :
629- one_seq_list .append (dict (zip (self .encoder .classes , token_probs )))
636+ one_seq_list .append (dict (zip (self ._encoder .classes , token_probs )))
630637 marginals_dict .append (one_seq_list )
631638
632639 return marginals_dict
@@ -639,14 +646,7 @@ def predict(self, X):
639646 Returns:
640647 preds (list of lists): Predictions for each token in each sequence.
641648 """
642- if self .ready :
643- if self .best_model_save_path :
644- self .load_state_dict (torch .load (self .best_model_save_path ))
645- else :
646- self .load_state_dict (torch .load (self .tmp_save_path ))
647- else :
648- raise MindMeldError ("PyTorch-CRF Model does not seem to be trained. Train before running predictions." )
649649 dataloader = self .get_dataloader (X , None , is_train = False )
650650
651651 preds = self .run_predictions (dataloader , calc_f1 = False )
652- return [self .encoder .label_encoder .inverse_transform (x ).tolist () for x in preds ]
652+ return [self ._encoder .label_encoder .inverse_transform (x ).tolist () for x in preds ]
0 commit comments