@@ -638,8 +638,8 @@ def train(
         adaptive_update_bounds=(0.9, 0.99),
         adaptive_update_fraction=0.0,
         multi_gpu=False,
-        tensorboard_log=False,
-        tensorboard_profile=False,
+        log_tb=False,
+        export_tb=False,
     ):
         """Train the GAN model on real low res data and real high res data
@@ -703,12 +703,12 @@ def train(
             rate that the model and optimizer were initialized with.
             If true and multiple gpus are found, ``default_device`` device
             should be set to /gpu:0
-        tensorboard_log : bool
+        log_tb : bool
             Whether to write log file for use with tensorboard. Log data can
             be viewed with ``tensorboard --logdir <logdir>`` where ``<logdir>``
             is the parent directory of ``out_dir``, and pointing the browser to
             the printed address.
-        tensorboard_profile : bool
+        export_tb : bool
             Whether to export profiling information to tensorboard. This can
             then be viewed in the tensorboard dashboard under the profile tab
@@ -720,10 +720,8 @@ def train(
         (3) Would like an automatic way to exit the batch handler thread
         instead of manually calling .stop() here.
         """
-        if tensorboard_log:
+        if log_tb:
             self._init_tensorboard_writer(out_dir)
-        if tensorboard_profile:
-            self._write_tb_profile = True

         self.set_norm_stats(batch_handler.means, batch_handler.stds)
         params = self.check_batch_handler_attrs(batch_handler)
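The body of `_init_tensorboard_writer` is not shown in this diff. A plausible minimal version, assuming it simply wraps `tf.summary.create_file_writer` and stores the writer on the model (the `_tb_writer` attribute and `logs` subdirectory are assumptions, not confirmed by the source):

```python
import os
import tensorflow as tf

def _init_tensorboard_writer(self, out_dir):
    """Assumed sketch: create an event-file writer under out_dir so that
    ``tensorboard --logdir <parent of out_dir>`` picks it up."""
    log_dir = os.path.join(out_dir, 'logs')  # hypothetical subdirectory
    self._tb_writer = tf.summary.create_file_writer(log_dir)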
@@ -759,6 +757,7 @@ def train(
                 train_disc,
                 disc_loss_bounds,
                 multi_gpu=multi_gpu,
+                export_tb=export_tb,
             )
             loss_details.update(
                 self.calc_val_loss(batch_handler, weight_gen_advers)
@@ -1071,7 +1070,7 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
         disc_loss = self._train_record['train_loss_disc'].values.mean()
         gen_loss = self._train_record['train_loss_gen'].values.mean()

-        logger.debug(
+        logger.info(
             'Batch {} out of {} has (gen / disc) loss of: ({:.2e} / {:.2e}). '
             'Running mean (gen / disc): ({:.2e} / {:.2e}). Trained '
             '(gen / disc): ({} / {})'.format(
@@ -1102,6 +1101,7 @@ def _train_epoch(
         train_disc,
         disc_loss_bounds,
         multi_gpu=False,
+        export_tb=False,
     ):
         """Train the GAN for one epoch.

@@ -1129,6 +1129,9 @@ def _train_epoch(
             rate that the model and optimizer were initialized with.
             If true and multiple gpus are found, ``default_device`` device
             should be set to /gpu:0
+        export_tb : bool
+            Whether to export profiling information to tensorboard. This can
+            then be viewed in the tensorboard dashboard under the profile tab

         Returns
         -------
@@ -1151,9 +1154,10 @@ def _train_epoch(
         only_gen = train_gen and not train_disc
         only_disc = train_disc and not train_gen

-        if self._write_tb_profile:
+        if export_tb:
             tf.summary.trace_on(graph=True, profiler=True)

+        prev_time = time.time()
         for ib, batch in enumerate(batch_handler):
             start = time.time()
@@ -1163,7 +1167,7 @@ def _train_epoch(
             disc_too_bad = (loss_disc > disc_th_high) and train_disc
             gen_too_good = disc_too_bad

-            b_loss_details = self.timer(self._train_batch, log=True)(
+            b_loss_details = self._train_batch(
                 batch,
                 train_gen,
                 only_gen,
@@ -1175,17 +1179,25 @@ def _train_epoch(
                 multi_gpu,
             )

-            loss_means = self.timer(self._post_batch, log=True)(
+            loss_means = self._post_batch(
                 ib, b_loss_details, len(batch_handler), loss_means
             )

+            total_step_time = time.time() - prev_time
+            batch_step_time = time.time() - start
+            batch_load_time = total_step_time - batch_step_time
+
             logger.info(
                 f'Finished batch step {ib + 1} / {len(batch_handler)} in '
-                f'{time.time() - start:.4f} seconds'
+                f'{total_step_time:.4f} seconds. Batch load time: '
+                f'{batch_load_time:.4f} seconds. Batch train time: '
+                f'{batch_step_time:.4f} seconds.'
             )

+            prev_time = time.time()
+
         self.total_batches += len(batch_handler)
         loss_details = self._train_record.mean().to_dict()
         loss_details['total_batches'] = int(self.total_batches)
-        self.profile_to_tensorboard('training_epoch')
+        self.profile_to_tensorboard('training_epoch', export_tb=export_tb)
         return loss_details
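The new timing lines decompose each step: `total_step_time` is measured from the end of the previous step, so it includes the wait on the batch handler, while `batch_step_time` covers only the work after the batch arrives; their difference approximates data-loading time. A standalone sketch of the same pattern (the `train_batch` callable and `timed_loop` name are placeholders, not from this diff):

```python
import time

def timed_loop(batches, train_batch):
    # Same decomposition as above: total time per step includes the wait
    # for the next batch; subtracting the train-only time isolates it.
    prev_time = time.time()
    for ib, batch in enumerate(batches):
        start = time.time()  # batch is now in hand
        train_batch(batch)   # compute-only portion
        total_step_time = time.time() - prev_time
        batch_step_time = time.time() - start
        batch_load_time = total_step_time - batch_step_time
        print(f'step {ib}: load={batch_load_time:.4f}s '
              f'train={batch_step_time:.4f}s total={total_step_time:.4f}s')
        prev_time = time.time()
```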