Commit 5d59680

LR scheduler + train refactor (#103)
* split __train up for clarity
* added lr scheduler after epoch completes

1 parent 309e45e · commit 5d59680
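
In outline: the per-epoch work (batch loop, validation checks, logging, hooks) moves out of __train into a new run_tng_epoch method, and the lr_scheduler.step() calls move from the start of each epoch to after the epoch completes. Below is a minimal, self-contained sketch of the resulting control flow; ToyTrainer and its attributes are illustrative stand-ins for this note, not the real Trainer API.

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR


class ToyTrainer:
    """Stand-in for the real Trainer: shows only the new loop ordering."""

    def __init__(self, model, optimizer, lr_schedulers, max_nb_epochs=3):
        self.model = model
        self.optimizer = optimizer
        self.lr_schedulers = lr_schedulers
        self.max_nb_epochs = max_nb_epochs

    def run_tng_epoch(self):
        # stand-in for the extracted per-epoch work: batch loop,
        # val checks, logging, and the epoch start/end hooks
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        loss = nn.functional.mse_loss(self.model(x), y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self):
        for epoch_nb in range(self.max_nb_epochs):
            # run the epoch first ...
            self.run_tng_epoch()

            # ... then decay, so epoch 0 trains at the initial LR
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step()
            print(epoch_nb, self.optimizer.param_groups[0]['lr'])


model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
ToyTrainer(model, opt, lr_schedulers=[StepLR(opt, step_size=1, gamma=0.5)]).train()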

File tree

1 file changed: +88 −86 lines changed


pytorch_lightning/models/trainer.py

Lines changed: 88 additions & 86 deletions
@@ -758,25 +758,16 @@ def __run_pretrain_routine(self, model):
         # ---------------------------
         # CORE TRAINING LOOP
         # ---------------------------
-
         self.__train()
 
     def __train(self):
         # run all epochs
        for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
-            # update the lr scheduler
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step()
-
+            # get model
             model = self.__get_model()
-            model.current_epoch = epoch_nb
-
-            # hook
-            if self.__is_function_implemented('on_epoch_start'):
-                model = self.__get_model()
-                model.on_epoch_start()
 
+            # update training progress in trainer and model
+            model.current_epoch = epoch_nb
             self.current_epoch = epoch_nb
             self.total_batches = self.nb_tng_batches + self.nb_val_batches
             self.batch_loss_value = 0  # accumulated grads
@@ -786,92 +777,103 @@ def __train(self):
             self.prog_bar = tqdm.tqdm(range(self.total_batches),
                                       position=self.process_position)
 
-            for batch_nb, data_batch in enumerate(self.tng_dataloader):
-                self.batch_nb = batch_nb
-                self.global_step += 1
-
-                model = self.__get_model()
-                model.global_step = self.global_step
-
-                # stop when the flag is changed or we've gone past the amount
-                # requested in the batches
-                self.total_batch_nb += 1
-                met_batch_limit = batch_nb > self.nb_tng_batches
-                if met_batch_limit:
-                    break
-
-                # ---------------
-                # RUN TRAIN STEP
-                # ---------------
-                batch_result = self.__run_tng_batch(data_batch, batch_nb)
-                early_stop_epoch = batch_result == -1
-
-                # ---------------
-                # RUN VAL STEP
-                # ---------------
-                is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
-                if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
-                    self.__run_validation()
-
-                # when batch should be saved
-                if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-                    if self.proc_rank == 0 and self.experiment is not None:
-                        self.experiment.save()
-
-                # when metrics should be logged
-                if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
-                    # count items in memory
-                    # nb_params, nb_tensors = count_mem_items()
-
-                    model = self.__get_model()
-                    metrics = self.__tng_tqdm_dic
-
-                    # add gpu memory
-                    if self.on_gpu:
-                        mem_map = get_gpu_memory_map()
-                        metrics.update(mem_map)
-
-                    # add norms
-                    if self.track_grad_norm > 0:
-                        model = self.__get_model()
-                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
-                        metrics.update(grad_norm_dic)
-
-                    if self.__is_function_implemented('on_tng_metrics'):
-                        model.on_tng_metrics(metrics)
-
-                    # log metrics
-                    scalar_metrics = self.__metrics_to_scalars(
-                        metrics, blacklist=self.__log_vals_blacklist())
-                    if self.proc_rank == 0 and self.experiment is not None:
-                        self.experiment.log(scalar_metrics, global_step=self.global_step)
-                        self.experiment.save()
-
-                # hook
-                if self.__is_function_implemented('on_batch_end'):
-                    model = self.__get_model()
-                    model.on_batch_end()
-
-                # end epoch early
-                if early_stop_epoch:
-                    break
+            # -----------------
+            # RUN TNG EPOCH
+            # -----------------
+            self.run_tng_epoch()
 
-            # hook
-            if self.__is_function_implemented('on_epoch_end'):
-                model = self.__get_model()
-                model.on_epoch_end()
+            # update LR schedulers
+            if self.lr_schedulers is not None:
+                for lr_scheduler in self.lr_schedulers:
+                    lr_scheduler.step()
 
             # early stopping
             met_min_epochs = epoch_nb > self.min_nb_epochs
             if self.enable_early_stop and met_min_epochs:
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
                                                                     logs=self.__tng_tqdm_dic)
-
             # stop training
             stop = should_stop and met_min_epochs
             if stop:
                 return
 
+    def run_tng_epoch(self):
+        # before epoch hook
+        if self.__is_function_implemented('on_epoch_start'):
+            model = self.__get_model()
+            model.on_epoch_start()
+
+        # run epoch
+        for batch_nb, data_batch in enumerate(self.tng_dataloader):
+            self.batch_nb = batch_nb
+            self.global_step += 1
+
+            model = self.__get_model()
+            model.global_step = self.global_step
+
+            # stop when the flag is changed or we've gone past the amount
+            # requested in the batches
+            self.total_batch_nb += 1
+            met_batch_limit = batch_nb > self.nb_tng_batches
+            if met_batch_limit:
+                break
+
+            # ---------------
+            # RUN TRAIN STEP
+            # ---------------
+            batch_result = self.__run_tng_batch(data_batch, batch_nb)
+            early_stop_epoch = batch_result == -1
+
+            # ---------------
+            # RUN VAL STEP
+            # ---------------
+            is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
+            if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
+                self.__run_validation()
+
+            # when batch should be saved
+            if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
+                if self.proc_rank == 0 and self.experiment is not None:
+                    self.experiment.save()
+
+            # when metrics should be logged
+            if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
+                # count items in memory
+                # nb_params, nb_tensors = count_mem_items()
+
+                model = self.__get_model()
+                metrics = self.__tng_tqdm_dic
+
+                # add gpu memory
+                if self.on_gpu:
+                    mem_map = get_gpu_memory_map()
+                    metrics.update(mem_map)
+
+                # add norms
+                if self.track_grad_norm > 0:
+                    model = self.__get_model()
+                    grad_norm_dic = model.grad_norm(self.track_grad_norm)
+                    metrics.update(grad_norm_dic)
+
+                if self.__is_function_implemented('on_tng_metrics'):
+                    model.on_tng_metrics(metrics)
+
+                # log metrics
+                scalar_metrics = self.__metrics_to_scalars(
+                    metrics, blacklist=self.__log_vals_blacklist())
+                if self.proc_rank == 0 and self.experiment is not None:
+                    self.experiment.log(scalar_metrics, global_step=self.global_step)
+                    self.experiment.save()
+
+            # end epoch early
+            if early_stop_epoch:
+                break
+
+        # epoch end hook
+        if self.__is_function_implemented('on_epoch_end'):
+            model = self.__get_model()
+            model.on_epoch_end()
+
     def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
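
The observable effect of moving the scheduler step is the learning rate the first epoch trains with: stepping at the top of each epoch (the old ordering) decays the rate before any batch has run, while stepping after the epoch (the new ordering) lets epoch 0 train at the optimizer's initial rate. A small sketch with a toy StepLR setup (not the trainer itself) comparing the two orderings:

import torch
from torch.optim.lr_scheduler import StepLR


def lrs_per_epoch(step_before_epoch, nb_epochs=3):
    """Return the LR each epoch would train with under the given ordering."""
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched = StepLR(opt, step_size=1, gamma=0.5)
    seen = []
    for _ in range(nb_epochs):
        if step_before_epoch:
            # old ordering: decay first, then train
            # (note: PyTorch >= 1.1 warns when a scheduler steps
            # before its optimizer ever has)
            sched.step()
        seen.append(opt.param_groups[0]['lr'])
        if not step_before_epoch:
            # new ordering: train first, then decay
            sched.step()
    return seen


print(lrs_per_epoch(step_before_epoch=True))   # [0.05, 0.025, 0.0125]
print(lrs_per_epoch(step_before_epoch=False))  # [0.1, 0.05, 0.025]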
