Skip to content

Commit 5b0f345

Browse files
committed
Minor nits
1 parent b01eed2 commit 5b0f345

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

mindmeld/models/taggers/pytorch_crf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
TEST_BATCH_SIZE = 512
27+
2628

2729
class TaggerDataset(Dataset):
2830
"""PyTorch Dataset class used to handle tagger inputs, labels and mask"""
@@ -301,7 +303,6 @@ def load_best_weights_path(self, path):
301303
Args:
302304
path (str): Path to save the best model weights.
303305
"""
304-
# self.best_model_save_path = os.path.abspath(path)
305306
if os.path.exists(path):
306307
self.load_state_dict(torch.load(path))
307308
else:
@@ -500,8 +501,8 @@ def get_dataloader(self, X, y, is_train):
500501
"""
501502
tensor_inputs, input_seq_lens, tensor_labels = self._encoder.get_tensor_data(X, y, fit=is_train)
502503
tensor_dataset = TaggerDataset(tensor_inputs, input_seq_lens, tensor_labels)
503-
torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else 512, shuffle=is_train,
504-
collate_fn=collate_tensors_and_masks)
504+
torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else TEST_BATCH_SIZE,
505+
shuffle=is_train, collate_fn=collate_tensors_and_masks)
505506
return torch_dataloader
506507

507508
def fit(self, X, y):
@@ -534,7 +535,7 @@ def fit(self, X, y):
534535
if self.optimizer == "sgd":
535536
self.optim = optim.SGD(self.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-5)
536537
if self.optimizer == "adam":
537-
self.optim = optim.Adam(self.parameters(), weight_decay=1e-5)
538+
self.optim = optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-5)
538539

539540
self.training_loop(train_dataloader, dev_dataloader)
540541
self.load_state_dict(torch.load(self.tmp_save_path))

0 commit comments

Comments
 (0)