    default_gradient_clipping_threshold, default_momentum

from .default_decorators import wrap_param_default
+import collections
+import cStringIO

__all__ = [
    'Optimizer', 'BaseSGDOptimizer', 'MomentumOptimizer', 'AdamaxOptimizer',
    'AdamOptimizer', 'AdaGradOptimizer', 'RMSPropOptimizer',
    'DecayedAdaGradOptimizer', 'AdaDeltaOptimizer', 'BaseRegularization',
-    'L2Regularization', 'settings', 'ModelAverage'
+    'L2Regularization', 'settings', 'ModelAverage', 'PolyLRS', 'ConstantLRS',
+    'ExpLRS', 'DiscreteExpLRS', 'LinearLRS', 'ManualLRS', 'PassManualLRS'
]

@@ -351,15 +354,141 @@ def __extends__(dict1, dict2):
    return dict1


+class BaseLRS(Optimizer):
+    def __init__(self, a, b, scheduler_name):
+        self.__a__ = float(a)
+        self.__b__ = float(b)
+        self.__scheduler_name__ = scheduler_name
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': self.__scheduler_name__,
+            'learning_rate_decay_a': self.__a__,
+            'learning_rate_decay_b': self.__b__
+        }
+
+
+class PolyLRS(BaseLRS):
+    """
+    Poly Learning Rate Scheduler.
+
+    lr = learning_rate * pow(1 + a * num_samples_processed, -b)
+    """
+
+    def __init__(self, a, b):
+        super(PolyLRS, self).__init__(a=a, b=b, scheduler_name='poly')
+
+
+class ConstantLRS(Optimizer):
+    """
+    Constant Learning Rate Scheduler. Learning rate will not be changed.
+    """
+
+    def to_setting_kwargs(self):
+        return {'learning_rate_schedule': 'constant'}
+
+
+class ExpLRS(BaseLRS):
+    """
+    Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, num_samples_processed / b)
+    """
+
+    def __init__(self, a, b):
+        super(ExpLRS, self).__init__(a=a, b=b, scheduler_name='exp')
+
+
+class DiscreteExpLRS(BaseLRS):
+    """
+    Discrete Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, floor(num_samples_processed / b))
+    """
+
+    def __init__(self, a, b):
+        super(DiscreteExpLRS, self).__init__(a=a, b=b, scheduler_name='discexp')
+
+
+class LinearLRS(BaseLRS):
+    """
+    Linear Learning Rate Scheduler.
+
+    lr = max(learning_rate - a, b)
+    """
+
+    def __init__(self, a, b):
+        super(LinearLRS, self).__init__(a=a, b=b, scheduler_name='linear')
+
+
+class ManualLRS(Optimizer):
+    """
+    Manual Learning Rate Scheduler. The learning rate is specified by
+    explicitly passing all learning rates.
+
+    :param learning_rates: list of learning rates. Each item contains two
+                           fields: the first is an int value used as the
+                           segment boundary, the second is the learning rate.
+
+                           The real learning rate is:
+
+                           if seg_{i-1} <= num_samples_processed <= seg_i:
+                               return lr_{i}
+
+    :type learning_rates: list of list. Each element should be (int, float).
+    """
+
+    def __init__(self, learning_rates):
+        assert isinstance(learning_rates, collections.Sequence)
+        buf = cStringIO.StringIO()
+        for i, each in enumerate(learning_rates):
+            assert isinstance(each, collections.Sequence)
+            assert len(each) == 2
+            buf.write("{0}:{1:.5f}".format(int(each[0]), float(each[1])))
+            if i + 1 != len(learning_rates):  # not at end
+                buf.write(",")
+        self.__args__ = buf.getvalue()
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'manual',
+            'learning_rate_args': self.__args__
+        }
+
+
+class PassManualLRS(ManualLRS):
+    """
+    Pass Manual Learning Rate Scheduler.
+
+    Basically the same as the manual learning rate scheduler, except that this
+    scheduler uses the pass number as the segment number.
+
+    The real learning rate is:
+
+    if seg_{i-1} <= pass_id <= seg_i:
+        return lr_{i}
+    """
+
+    def __init__(self, learning_rates):
+        super(PassManualLRS, self).__init__(learning_rates=learning_rates)
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'pass_manual',
+            'learning_rate_args': self.__args__
+        }
+
+
@wrap_param_default(
    ['learning_method'], default_factory=lambda _: MomentumOptimizer())
@wrap_param_default(
    ['regularization'], default_factory=lambda _: BaseRegularization())
+@wrap_param_default(
+    ['learning_rate_schedule'], default_factory=lambda _: ConstantLRS())
def settings(batch_size,
             learning_rate=1e-3,
             learning_rate_decay_a=0.,
             learning_rate_decay_b=0.,
-             learning_rate_schedule='poly',
+             learning_rate_schedule=None,
             learning_rate_args='',
             learning_method=None,
             regularization=None,
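
Every scheduler class added in this hunk funnels into the same trainer settings: the schedule name plus learning_rate_decay_a / learning_rate_decay_b, or a packed learning_rate_args string for the manual variants. Below is a minimal sketch of what to_setting_kwargs() yields, assuming the classes behave exactly as written above and are importable from paddle.trainer_config_helpers.optimizers (the module path is not shown in this diff, and all numeric values are purely illustrative):

    # Assumed import path; not shown in the diff itself.
    from paddle.trainer_config_helpers.optimizers import PolyLRS, ManualLRS

    # lr = learning_rate * pow(1 + 0.5 * num_samples_processed, -0.75)
    poly = PolyLRS(a=0.5, b=0.75)
    print(poly.to_setting_kwargs())
    # contains: {'learning_rate_schedule': 'poly',
    #            'learning_rate_decay_a': 0.5,
    #            'learning_rate_decay_b': 0.75}

    # Piecewise-constant schedule keyed on num_samples_processed; the
    # (segment, lr) pairs are serialized into learning_rate_args.
    manual = ManualLRS([(1000, 0.1), (2000, 0.01), (3000, 0.001)])
    print(manual.to_setting_kwargs())
    # contains: {'learning_rate_schedule': 'manual',
    #            'learning_rate_args': '1000:0.10000,2000:0.01000,3000:0.00100'}
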
@@ -396,6 +525,19 @@ def settings(batch_size,
                                        value larger than some value, will be
                                        clipped.
    :type gradient_clipping_threshold: float
+
+    :param learning_rate_schedule: A Learning Rate Scheduler object or a
+                                   basestring. It is recommended to pass an
+                                   LRS object. If you set
+                                   learning_rate_schedule as a basestring,
+                                   you should manually set
+                                   learning_rate_decay_a,
+                                   learning_rate_decay_b and
+                                   learning_rate_args.
+
+                                   Check LRS.to_setting_kwargs to figure out
+                                   how to set these arguments.
+    :type learning_rate_schedule: basestring|Optimizer
+    :param learning_rate_decay_a: See learning_rate_schedule.
+    :param learning_rate_decay_b: See learning_rate_schedule.
+    :param learning_rate_args: See learning_rate_schedule.
    """
    if isinstance(regularization, BaseRegularization):
        regularization = [regularization]
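
The docstring added here implies that the object form and the string form of learning_rate_schedule are interchangeable. A sketch of that equivalence, assuming settings() and ExpLRS come from this module and using made-up hyperparameter values:

    # Preferred: pass a scheduler object; its to_setting_kwargs() fills in
    # learning_rate_schedule, learning_rate_decay_a and learning_rate_decay_b.
    settings(batch_size=128,
             learning_rate=1e-3,
             learning_rate_schedule=ExpLRS(a=0.1, b=50000))

    # Equivalent string form: every underlying field is set by hand.
    settings(batch_size=128,
             learning_rate=1e-3,
             learning_rate_schedule='exp',
             learning_rate_decay_a=0.1,
             learning_rate_decay_b=50000.)
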
@@ -406,15 +548,24 @@ def settings(batch_size,
    else:
        algorithm = 'owlqn'

-    args = [
-        'batch_size', 'learning_rate', 'learning_rate_decay_a',
-        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args'
-    ]
+    args = ['batch_size', 'learning_rate']
    kwargs = dict()
    kwargs['algorithm'] = algorithm
+
    for arg in args:
        kwargs[arg] = locals()[arg]

+    if isinstance(learning_rate_schedule, Optimizer):
+        kwargs = __extends__(kwargs, learning_rate_schedule.to_setting_kwargs())
+    elif isinstance(learning_rate_schedule, basestring):
+        for arg in [
+                'learning_rate_decay_a', 'learning_rate_decay_b',
+                'learning_rate_schedule', 'learning_rate_args'
+        ]:
+            kwargs[arg] = locals()[arg]
+    else:
+        raise RuntimeWarning("Unexpected branch")
+
    kwargs = __extends__(kwargs, learning_method.to_setting_kwargs())
    learning_method.extra_settings()
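
With this dispatch, a scheduler object becomes the single source of the schedule-related settings, merged into kwargs through __extends__, while a plain string falls back to copying the raw arguments, and anything else trips the RuntimeWarning. A rough, hypothetical trace of the kwargs built for one call follows; the concrete values are illustrative, and the algorithm string depends on the learning_method branch that is not visible in this hunk:

    # settings(batch_size=128, learning_rate=1e-3,
    #          learning_rate_schedule=PolyLRS(a=0.5, b=0.75))
    algorithm = '...'  # chosen from learning_method; only 'owlqn' is shown above
    kwargs = {'algorithm': algorithm, 'batch_size': 128, 'learning_rate': 1e-3}
    # __extends__ then merges in PolyLRS(a=0.5, b=0.75).to_setting_kwargs(),
    # roughly equivalent to:
    kwargs.update({'learning_rate_schedule': 'poly',
                   'learning_rate_decay_a': 0.5,
                   'learning_rate_decay_b': 0.75})
    # ...followed by the learning_method's own to_setting_kwargs().
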