|
23 | 23 |
|
24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 |
|
| 26 | +TEST_BATCH_SIZE = 512 |
| 27 | + |
26 | 28 |
|
27 | 29 | class TaggerDataset(Dataset): |
28 | 30 | """PyTorch Dataset class used to handle tagger inputs, labels and mask""" |
@@ -301,7 +303,6 @@ def load_best_weights_path(self, path): |
301 | 303 | Args: |
302 | 304 | path (str): Path to save the best model weights. |
303 | 305 | """ |
304 | | - # self.best_model_save_path = os.path.abspath(path) |
305 | 306 | if os.path.exists(path): |
306 | 307 | self.load_state_dict(torch.load(path)) |
307 | 308 | else: |
@@ -500,8 +501,8 @@ def get_dataloader(self, X, y, is_train): |
500 | 501 | """ |
501 | 502 | tensor_inputs, input_seq_lens, tensor_labels = self._encoder.get_tensor_data(X, y, fit=is_train) |
502 | 503 | tensor_dataset = TaggerDataset(tensor_inputs, input_seq_lens, tensor_labels) |
503 | | - torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else 512, shuffle=is_train, |
504 | | - collate_fn=collate_tensors_and_masks) |
| 504 | + torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else TEST_BATCH_SIZE, |
| 505 | + shuffle=is_train, collate_fn=collate_tensors_and_masks) |
505 | 506 | return torch_dataloader |
506 | 507 |
|
507 | 508 | def fit(self, X, y): |
@@ -534,7 +535,7 @@ def fit(self, X, y): |
534 | 535 | if self.optimizer == "sgd": |
535 | 536 | self.optim = optim.SGD(self.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-5) |
536 | 537 | if self.optimizer == "adam": |
537 | | - self.optim = optim.Adam(self.parameters(), weight_decay=1e-5) |
| 538 | + self.optim = optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-5) |
538 | 539 |
|
539 | 540 | self.training_loop(train_dataloader, dev_dataloader) |
540 | 541 | self.load_state_dict(torch.load(self.tmp_save_path)) |
|
0 commit comments