11from __future__ import print_function
22import numpy as np
3+ import scipy .stats as st
34import multiprocessing as mp
5+ from collections .abc import Iterable
46
57np .random .seed (0 )
68
def worker_process(arg):
    """Pool helper: unpack a (reward_function, weights) pair and evaluate it.

    Must be a module-level function so it can be pickled by multiprocessing.
    """
    reward_fn, candidate = arg
    return reward_fn(candidate)
1113
class WeightUpdateStrategy:
    """Base class for weight-update rules used by EvolutionStrategy.

    Parameters
    ----------
    dim : int
        Number of weight layers; stateless rules ignore it, stateful ones
        (e.g. Adam) use it to size their per-layer moment buffers.
    learning_rate : float
        Step size applied to the estimated gradient.
    """
    __slots__ = ("learning_rate",)

    def __init__(self, dim, learning_rate):
        self.learning_rate = learning_rate


class strategies:
    """Namespace collecting the available update strategies."""

    class GD(WeightUpdateStrategy):
        """Plain gradient step: step = learning_rate * g."""

        def update(self, i, g):
            return self.learning_rate * g

    class Adam(WeightUpdateStrategy):
        """Adam optimizer (Kingma & Ba, 2015) with state per weight layer.

        Fixes to the previous revision:
        - the denominator used np.sqrt(np.sqrt(v)), i.e. v ** 0.25, instead
          of the standard sqrt(v);
        - bias correction was frozen at the t=1 approximation
          sqrt(1 - beta2) / (1 - beta1) instead of tracking the step count;
        - moments were held in np.zeros(dim), so a per-layer gradient that
          is itself an array could not be assigned into a scalar slot.
          Plain lists let each entry match the layer's shape.
        """
        __slots__ = ("eps", "beta1", "beta2", "m", "v", "t")

        def __init__(self, dim, learning_rate, eps=1e-8, beta1=0.9, beta2=0.999):
            super().__init__(dim, learning_rate)
            self.eps = eps
            self.beta1 = beta1
            self.beta2 = beta2
            self.m = [0.0] * dim  # first-moment estimate, one slot per layer
            self.v = [0.0] * dim  # second-moment estimate, one slot per layer
            self.t = [0] * dim    # per-layer update count for bias correction

        def update(self, i, g):
            """Return the Adam step for layer ``i`` given gradient ``g``
            (scalar or array of the layer's shape)."""
            self.t[i] += 1
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (g ** 2)
            # Bias-corrected moment estimates, then the canonical Adam step.
            m_hat = self.m[i] / (1 - self.beta1 ** self.t[i])
            v_hat = self.v[i] / (1 - self.beta2 ** self.t[i])
            return self.learning_rate * m_hat / (np.sqrt(v_hat) + self.eps)
40+
1241
1342class EvolutionStrategy (object ):
1443 def __init__ (self , weights , get_reward_func , population_size = 50 , sigma = 0.1 , learning_rate = 0.03 , decay = 0.999 ,
15- num_threads = 1 ):
16-
44+ num_threads = 1 , limits = None , printer = None , distributions = None , strategy = None ):
45+ if limits is None :
46+ limits = (np .inf , - np .inf )
1747 self .weights = weights
48+ self .limits = limits
1849 self .get_reward = get_reward_func
1950 self .POPULATION_SIZE = population_size
20- self .SIGMA = sigma
51+ if distributions is None :
52+ distributions = st .norm (loc = 0. , scale = sigma )
53+ if isinstance (distributions , Iterable ):
54+ distributions = list (distributions )
55+ self .SIGMA = np .array ([d .std () for d in distributions ])
56+ else :
57+ self .SIGMA = distributions .std ()
58+
59+ self .distributions = distributions
2160 self .learning_rate = learning_rate
2261 self .decay = decay
2362 self .num_threads = mp .cpu_count () if num_threads == - 1 else num_threads
63+ if printer is None :
64+ printer = print
65+ self .printer = printer
66+ if strategy is None :
67+ strategy = strategies .GD
68+ self .strategy = strategy (len (weights ), self .learning_rate )
2469
2570 def _get_weights_try (self , w , p ):
2671 weights_try = []
2772 for index , i in enumerate (p ):
28- jittered = self .SIGMA * i
29- weights_try .append (w [index ] + jittered )
73+ weights_try .append (w [index ] + i )
3074 return weights_try
3175
3276 def get_weights (self ):
@@ -36,8 +80,13 @@ def _get_population(self):
3680 population = []
3781 for i in range (self .POPULATION_SIZE ):
3882 x = []
39- for w in self .weights :
40- x .append (np .random .randn (* w .shape ))
83+ if isinstance (self .distributions , Iterable ):
84+ for j , w in enumerate (self .weights ):
85+ x .append (self .distributions [j ].rvs (* w .shape ))
86+ else :
87+ for w in self .weights :
88+ x .append (self .distributions .rvs (* w .shape ))
89+
4190 population .append (x )
4291 return population
4392
@@ -59,10 +108,17 @@ def _update_weights(self, rewards, population):
59108 if std == 0 :
60109 return
61110 rewards = (rewards - rewards .mean ()) / std
111+ grad_factor = 1. / (self .POPULATION_SIZE * (self .SIGMA ** 2 ))
112+
62113 for index , w in enumerate (self .weights ):
63114 layer_population = np .array ([p [index ] for p in population ])
64- update_factor = self .learning_rate / (self .POPULATION_SIZE * self .SIGMA )
65- self .weights [index ] = w + update_factor * np .dot (layer_population .T , rewards ).T
115+ corr = np .dot (layer_population .T , rewards ).T
116+
117+ if not isinstance (grad_factor , np .ndarray ):
118+ g = grad_factor * corr
119+ else :
120+ g = grad_factor [index ] * corr
121+ self .weights [index ] = w + self .strategy .update (index , g )
66122 self .learning_rate *= self .decay
67123
68124 def run (self , iterations , print_step = 10 ):
@@ -75,7 +131,7 @@ def run(self, iterations, print_step=10):
75131 self ._update_weights (rewards , population )
76132
77133 if (iteration + 1 ) % print_step == 0 :
78- print ('iter %d. reward: %f' % (iteration + 1 , self .get_reward (self .weights )))
134+ self . printer ('iter %d. reward: %f' % (iteration + 1 , self .get_reward (self .weights )), ( self . weights if self . weights . shape [ 0 ] <= 10 else None ) )
79135 if pool is not None :
80136 pool .close ()
81137 pool .join ()
0 commit comments