@@ -49,7 +49,7 @@ def start(self):
4949 def log (self , * args , ** kwargs ):
5050 logger .info (* args , ** kwargs )
5151
52- def report_training (self , step , num_steps , learning_rate ,
52+ def report_training (self , step , num_steps , learning_rate , patience ,
5353 report_stats , multigpu = False ):
5454 """
5555 This is the user-defined batch-level traing progress
@@ -72,7 +72,7 @@ def report_training(self, step, num_steps, learning_rate,
7272 report_stats = \
7373 onmt .utils .Statistics .all_gather_stats (report_stats )
7474 self ._report_training (
75- step , num_steps , learning_rate , report_stats )
75+ step , num_steps , learning_rate , patience , report_stats )
7676 return onmt .utils .Statistics ()
7777 else :
7878 return report_stats
@@ -81,17 +81,22 @@ def _report_training(self, *args, **kwargs):
8181 """ To be overridden """
8282 raise NotImplementedError ()
8383
84- def report_step (self , lr , step , train_stats = None , valid_stats = None ):
84+ def report_step (self , lr , patience , step , train_stats = None ,
85+ valid_stats = None ):
8586 """
8687 Report stats of a step
8788
8889 Args:
90+ lr(float): current learning rate
91+ patience(int): current patience
92+ step(int): current step
8993 train_stats(Statistics): training stats
9094 valid_stats(Statistics): validation stats
91- lr(float): current learning rate
9295 """
9396 self ._report_step (
94- lr , step , train_stats = train_stats , valid_stats = valid_stats )
97+ lr , patience , step ,
98+ train_stats = train_stats ,
99+ valid_stats = valid_stats )
95100
96101 def _report_step (self , * args , ** kwargs ):
97102 raise NotImplementedError ()
@@ -111,12 +116,13 @@ def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
111116 super (ReportMgr , self ).__init__ (report_every , start_time )
112117 self .tensorboard_writer = tensorboard_writer
113118
114- def maybe_log_tensorboard (self , stats , prefix , learning_rate , step ):
119+ def maybe_log_tensorboard (self , stats , prefix , learning_rate ,
120+ patience , step ):
115121 if self .tensorboard_writer is not None :
116122 stats .log_tensorboard (
117- prefix , self .tensorboard_writer , learning_rate , step )
123+ prefix , self .tensorboard_writer , learning_rate , patience , step )
118124
119- def _report_training (self , step , num_steps , learning_rate ,
125+ def _report_training (self , step , num_steps , learning_rate , patience ,
120126 report_stats ):
121127 """
122128 See base class method `ReportMgrBase.report_training`.
@@ -127,12 +133,15 @@ def _report_training(self, step, num_steps, learning_rate,
127133 self .maybe_log_tensorboard (report_stats ,
128134 "progress" ,
129135 learning_rate ,
136+ patience ,
130137 step )
131138 report_stats = onmt .utils .Statistics ()
132139
133140 return report_stats
134141
135- def _report_step (self , lr , step , train_stats = None , valid_stats = None ):
142+ def _report_step (self , lr , patience , step ,
143+ train_stats = None ,
144+ valid_stats = None ):
136145 """
137146 See base class method `ReportMgrBase.report_step`.
138147 """
@@ -143,6 +152,7 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
143152 self .maybe_log_tensorboard (train_stats ,
144153 "train" ,
145154 lr ,
155+ patience ,
146156 step )
147157
148158 if valid_stats is not None :
@@ -152,4 +162,5 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
152162 self .maybe_log_tensorboard (valid_stats ,
153163 "valid" ,
154164 lr ,
165+ patience ,
155166 step )
0 commit comments