
Commit 11b505a

Add LRS wrapper to optimizer
1 parent 1c61338 commit 11b505a

File tree

1 file changed: +157 -6 lines changed


python/paddle/trainer_config_helpers/optimizers.py

Lines changed: 157 additions & 6 deletions
@@ -16,12 +16,15 @@
     default_gradient_clipping_threshold, default_momentum

 from .default_decorators import wrap_param_default
+import collections
+import cStringIO

 __all__ = [
     'Optimizer', 'BaseSGDOptimizer', 'MomentumOptimizer', 'AdamaxOptimizer',
     'AdamOptimizer', 'AdaGradOptimizer', 'RMSPropOptimizer',
     'DecayedAdaGradOptimizer', 'AdaDeltaOptimizer', 'BaseRegularization',
-    'L2Regularization', 'settings', 'ModelAverage'
+    'L2Regularization', 'settings', 'ModelAverage', 'PolyLRS', 'ConstantLRS',
+    'ExpLRS', 'DiscreteExpLRS', 'LinearLRS', 'ManualLRS', 'PassManualLRS'
 ]


@@ -351,15 +354,141 @@ def __extends__(dict1, dict2):
     return dict1


+class BaseLRS(Optimizer):
+    def __init__(self, a, b, scheduler_name):
+        self.__a__ = float(a)
+        self.__b__ = float(b)
+        self.__scheduler_name__ = scheduler_name
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': self.__scheduler_name__,
+            'learning_rate_decay_a': self.__a__,
+            'learning_rate_decay_b': self.__b__
+        }
+
+
+class PolyLRS(BaseLRS):
+    """
+    Poly Learning Rate Scheduler.
+
+    lr = learning_rate * pow(1 + a * num_samples_processed, -b)
+    """
+
+    def __init__(self, a, b):
+        super(PolyLRS, self).__init__(a=a, b=b, scheduler_name='poly')
+
+
+class ConstantLRS(Optimizer):
+    """
+    Constant Learning Rate Scheduler. Learning rate will not be changed.
+    """
+
+    def to_setting_kwargs(self):
+        return {'learning_rate_schedule': 'constant'}
+
+
+class ExpLRS(BaseLRS):
+    """
+    Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, num_samples_processed / b)
+    """
+
+    def __init__(self, a, b):
+        super(ExpLRS, self).__init__(a=a, b=b, scheduler_name='exp')
+
+
+class DiscreteExpLRS(BaseLRS):
+    """
+    Discrete Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, floor(num_samples_processed / b))
+    """
+
+    def __init__(self, a, b):
+        super(DiscreteExpLRS, self).__init__(a=a, b=b, scheduler_name='discexp')
+
+
+class LinearLRS(BaseLRS):
+    """
+    Linear Learning Rate Scheduler.
+
+    lr = max(learning_rate - a, b)
+    """
+
+    def __init__(self, a, b):
+        super(LinearLRS, self).__init__(a=a, b=b, scheduler_name='linear')
+
+
+class ManualLRS(Optimizer):
+    """
+    Specify the learning rate by explicitly passing all learning_rates.
+
+    :param learning_rates: list of learning rates. Each item contains two
+                           fields: the first is an int value used as a segment
+                           boundary, the second is the learning rate.
+
+    The real learning rate is:
+
+        if seg_{i-1} <= numSamples <= seg_i:
+            return lr_{i}
+
+    :type learning_rates: list of list. Each element should be (int, float)
+    """
+
+    def __init__(self, learning_rates):
+        assert isinstance(learning_rates, collections.Sequence)
+        buf = cStringIO.StringIO()
+        for i, each in enumerate(learning_rates):
+            assert isinstance(each, collections.Sequence)
+            assert len(each) == 2
+            buf.write("{0}:{1:.5f}".format(int(each[0]), float(each[1])))
+            if i + 1 != len(learning_rates):  # not at end
+                buf.write(",")
+        self.__args__ = buf.getvalue()
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'manual',
+            'learning_rate_args': self.__args__
+        }
+
+
+class PassManualLRS(ManualLRS):
+    """
+    Pass Manual Learning Rate Scheduler.
+
+    Basically the same as the manual learning rate scheduler, except that the
+    pass manual LRS uses the pass number as the segment number.
+
+    The real learning rate is:
+
+        if seg_{i-1} <= pass_id <= seg_i:
+            return lr_{i}
+    """
+
+    def __init__(self, learning_rates):
+        super(PassManualLRS, self).__init__(learning_rates=learning_rates)
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'pass_manual',
+            'learning_rate_args': self.__args__
+        }
+
+
 @wrap_param_default(
     ['learning_method'], default_factory=lambda _: MomentumOptimizer())
 @wrap_param_default(
     ['regularization'], default_factory=lambda _: BaseRegularization())
+@wrap_param_default(
+    ['learning_rate_schedule'], default_factory=lambda _: ConstantLRS())
 def settings(batch_size,
              learning_rate=1e-3,
              learning_rate_decay_a=0.,
              learning_rate_decay_b=0.,
-             learning_rate_schedule='poly',
+             learning_rate_schedule=None,
              learning_rate_args='',
              learning_method=None,
              regularization=None,
@@ -396,6 +525,19 @@ def settings(batch_size,
                                          value larger than some value, will be
                                          clipped.
     :type gradient_clipping_threshold: float
+
+    :param learning_rate_schedule: A Learning Rate Scheduler object or a
+                                   basestring. It is recommended to pass an
+                                   LRS object. If you pass a basestring, you
+                                   should manually set learning_rate_decay_a,
+                                   learning_rate_decay_b and learning_rate_args.
+
+                                   Check LRS.to_setting_kwargs to figure out
+                                   how to set these arguments.
+    :type learning_rate_schedule: basestring|Optimizer
+    :param learning_rate_decay_a: See learning_rate_schedule.
+    :param learning_rate_decay_b: See learning_rate_schedule.
+    :param learning_rate_args: See learning_rate_schedule.
     """
     if isinstance(regularization, BaseRegularization):
         regularization = [regularization]
@@ -406,15 +548,24 @@ def settings(batch_size,
     else:
         algorithm = 'owlqn'

-    args = [
-        'batch_size', 'learning_rate', 'learning_rate_decay_a',
-        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args'
-    ]
+    args = ['batch_size', 'learning_rate']
     kwargs = dict()
     kwargs['algorithm'] = algorithm
+
     for arg in args:
         kwargs[arg] = locals()[arg]

+    if isinstance(learning_rate_schedule, Optimizer):
+        kwargs = __extends__(kwargs, learning_rate_schedule.to_setting_kwargs())
+    elif isinstance(learning_rate_schedule, basestring):
+        for arg in [
+                'learning_rate_decay_a', 'learning_rate_decay_b',
+                'learning_rate_schedule', 'learning_rate_args'
+        ]:
+            kwargs[arg] = locals()[arg]
+    else:
+        raise RuntimeWarning("Unexpected branch")
+
     kwargs = __extends__(kwargs, learning_method.to_setting_kwargs())
     learning_method.extra_settings()

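For reference, a usage sketch of the new scheduler wrappers follows. It is not part of the diff: the direct imports from paddle.trainer_config_helpers.optimizers, the batch size, learning rate and the a/b values are illustrative assumptions (trainer configs typically pull these helpers in via the package's wildcard import instead).

# Usage sketch (assumed values; import path taken from the file shown above).
from paddle.trainer_config_helpers.optimizers import settings, MomentumOptimizer
from paddle.trainer_config_helpers.optimizers import PolyLRS, PassManualLRS

# Polynomial decay: lr = learning_rate * pow(1 + a * num_samples_processed, -b).
# PolyLRS(a=0.1, b=0.5).to_setting_kwargs() yields
#     {'learning_rate_schedule': 'poly',
#      'learning_rate_decay_a': 0.1,
#      'learning_rate_decay_b': 0.5},
# which settings() merges into its config kwargs via __extends__.
settings(
    batch_size=128,
    learning_rate=1e-3,
    learning_method=MomentumOptimizer(),
    learning_rate_schedule=PolyLRS(a=0.1, b=0.5))

# Alternatively, a pass-based manual schedule: the (segment, lr) pairs are
# serialized into learning_rate_args as "10:0.00100,20:0.00010" and the
# schedule name becomes 'pass_manual'.
settings(
    batch_size=128,
    learning_rate=1.,
    learning_rate_schedule=PassManualLRS(learning_rates=[(10, 1e-3), (20, 1e-4)]))

Passing a plain string for learning_rate_schedule still works, but then learning_rate_decay_a, learning_rate_decay_b and learning_rate_args have to be set by hand to match whatever the corresponding scheduler's to_setting_kwargs() would produce.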

0 commit comments
