@@ -18,7 +18,8 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """

-    def __init__(self):
+    def __init__(self, global_step=None):
+        self._global_step = global_step
         # Dictionary of accumulators. Some optimizer subclasses need to
         # allocate and manage extra variables associated with the parameters
         # to train. These variables are called accumulators.
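This hunk gives the base class an optional, user-supplied step counter; every subclass changed below simply forwards it here. For context, a minimal sketch of how a caller might wire this up; the `Program`/`create_var` calls and the variable's dtype/shape are assumptions about the surrounding framework API, not part of this diff:

```python
# Hypothetical wiring sketch; the Program()/create_var() usage here is an
# assumption, not something introduced by this patch.
program = framework.Program()
block = program.global_block()
global_step = block.create_var(
    dtype="float32", shape=[1], lod_level=0, name="step")  # one-element counter
opt = SGDOptimizer(learning_rate=0.01, global_step=global_step)
# opt.minimize(loss) will now also append the increment op defined below.
```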
@@ -109,6 +110,26 @@ def _get_accumulator(self, name, param):
                             format(name, param.name))
         return self._accumulators[name][param.name]

+    def _increment_global_step(self, block):
+        """Increment the global step by 1 after every iteration
+
+        Args:
+            block: the block in which the loss variable is present
+
+        Returns:
+            the global_step increment op appended to the block
+        """
+        assert isinstance(block, framework.Block)
+        assert self._global_step is not None
+        # create the increment op
+        increment_op = block.append_op(
+            type="increment",
+            inputs={"X": self._global_step},
+            outputs={"Out": self._global_step},
+            attrs={"step": 1.0})
+
+        return increment_op
+
     def create_optimization_pass(self, parameters_and_grads, loss):
         """Add optimization operators to update gradients to variables.

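The helper appends a single `increment` operator that reads and writes the same global step variable. A self-contained NumPy stand-in for that operator's semantics (Out = X + step), shown only to illustrate what happens to the counter each pass:

```python
# Illustrative stand-in for the "increment" op: Out = X + step, applied to the
# one-element global step tensor once per optimization pass. Pure NumPy.
import numpy as np

def increment(x, step=1.0):
    return x + np.asarray(step, dtype=x.dtype)

global_step = np.zeros([1], dtype="float32")
for _ in range(3):                 # three optimization passes
    global_step = increment(global_step)
print(global_step)                 # [3.]
```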
@@ -152,6 +173,8 @@ def create_optimization_pass(self, parameters_and_grads, loss):
         if finish_ops is not None:
             return_ops += finish_ops

+        if self._global_step is not None:
+            return_ops.append(self._increment_global_step(loss.block))
         return return_ops

     def minimize(self, loss, parameter_list=None, no_grad_set=None):
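The increment op is appended after any subclass finish ops, so the counter advances exactly once per optimization pass and the increment is always the last op returned. A plain-Python sketch of that bookkeeping, with strings standing in for operator objects:

```python
# Sketch of the return_ops bookkeeping above; strings stand in for operators.
optimize_ops = ["sgd_w", "sgd_b"]            # one update op per parameter
finish_ops = None                            # most subclasses append none
return_ops = list(optimize_ops)
if finish_ops is not None:
    return_ops += finish_ops
return_ops.append("increment_global_step")   # appended by this patch
assert return_ops[-1] == "increment_global_step"
```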
@@ -172,9 +195,9 @@ class SGDOptimizer(Optimizer):
     """ Simple SGD optimizer without any state.
     """

-    def __init__(self, learning_rate):
+    def __init__(self, learning_rate, global_step=None):
         assert learning_rate is not None
-        super(SGDOptimizer, self).__init__()
+        super(SGDOptimizer, self).__init__(global_step)
         self.type = "sgd"
         self._learning_rate = learning_rate

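SGD's update itself is untouched; only the counter plumbing changes. For reference, an illustrative NumPy version of one pass, with the step bumped alongside the parameter update:

```python
# Illustrative NumPy sketch of one SGD pass plus the new step bookkeeping.
import numpy as np

lr = 0.01
param = np.array([1.0, -2.0], dtype="float32")
grad = np.array([0.5, 0.5], dtype="float32")
global_step = np.zeros([1], dtype="float32")

param -= lr * grad        # what the sgd op computes
global_step += 1.0        # what the appended increment op computes
```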
@@ -215,10 +238,14 @@ class MomentumOptimizer(Optimizer):
     """
     _velocity_acc_str = "velocity"

-    def __init__(self, learning_rate, momentum, use_nesterov=False):
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 use_nesterov=False,
+                 global_step=None):
         assert learning_rate is not None
         assert momentum is not None
-        super(MomentumOptimizer, self).__init__()
+        super(MomentumOptimizer, self).__init__(global_step)
         self.type = "momentum"
         self._learning_rate = learning_rate
         self._momentum = momentum
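The constructor change is mechanical; the `velocity` accumulator and update rule are unchanged. For readers skimming the diff, a NumPy sketch of the common momentum formulation this optimizer uses, including the Nesterov variant (an illustration, not the op's implementation):

```python
# Common momentum update using a per-parameter "velocity" accumulator
# (illustrative NumPy sketch, not the momentum op implementation).
import numpy as np

def momentum_step(param, grad, velocity, lr=0.01, mu=0.9, use_nesterov=False):
    velocity = mu * velocity + grad
    if use_nesterov:
        param = param - lr * (grad + mu * velocity)
    else:
        param = param - lr * velocity
    return param, velocity

p = np.array([1.0], dtype="float32")
v = np.zeros([1], dtype="float32")
g = np.array([0.2], dtype="float32")
p, v = momentum_step(p, g, v)
```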
@@ -275,10 +302,10 @@ class AdagradOptimizer(Optimizer):
     """
     _moment_acc_str = "moment"

-    def __init__(self, learning_rate, epsilon=1.0e-6):
+    def __init__(self, learning_rate, epsilon=1.0e-6, global_step=None):
         assert learning_rate is not None
         assert epsilon is not None
-        super(AdagradOptimizer, self).__init__()
+        super(AdagradOptimizer, self).__init__(global_step)
         self.type = "adagrad"
         self._learning_rate = learning_rate
         self._epsilon = epsilon
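Again only the constructor changes. For context, the common Adagrad formulation backed by the `moment` accumulator (illustrative NumPy sketch):

```python
# Common Adagrad update: accumulate squared gradients in "moment", then scale
# the step by their square root (illustrative NumPy sketch).
import numpy as np

def adagrad_step(param, grad, moment, lr=0.01, epsilon=1.0e-6):
    moment = moment + grad * grad
    param = param - lr * grad / (np.sqrt(moment) + epsilon)
    return param, moment
```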
@@ -337,12 +364,13 @@ def __init__(self,
                  learning_rate=0.001,
                  beta1=0.9,
                  beta2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8,
+                 global_step=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        super(AdamOptimizer, self).__init__()
+        super(AdamOptimizer, self).__init__(global_step)
         self.type = "adam"
         self._learning_rate = learning_rate
         self._beta1 = beta1
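As above, only the constructor grows a `global_step` argument. For reference, the standard Adam update that the beta1/beta2/epsilon hyperparameters feed into (illustrative NumPy sketch with bias correction):

```python
# Standard Adam update with bias correction (illustrative NumPy sketch).
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad              # first-moment estimate
    v = b2 * v + (1 - b2) * grad * grad       # second-moment estimate
    m_hat = m / (1 - b1 ** t)                 # bias correction, t >= 1
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v
```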
@@ -458,7 +486,8 @@ def __init__(self,
                  learning_rate=0.001,
                  beta1=0.9,
                  beta2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8,
+                 global_step=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
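Every optimizer touched by this patch follows the same pattern: accept an optional `global_step` and forward it to `Optimizer.__init__`. A hypothetical subclass skeleton summarizing that contract:

```python
# Hypothetical subclass skeleton (not part of this diff) showing the pattern
# each optimizer above follows to stay compatible with global step tracking.
class MyOptimizer(Optimizer):
    def __init__(self, learning_rate, global_step=None):
        assert learning_rate is not None
        super(MyOptimizer, self).__init__(global_step)
        self.type = "my_op"
        self._learning_rate = learning_rate
```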