from __future__ import print_function

import multiprocessing as mp
from collections.abc import Iterable

import numpy as np
import scipy.stats as st

# Seed the global NumPy RNG once at import time so population sampling is
# reproducible across runs. NOTE(review): this is a module-level side effect
# that also affects any other code sharing the global RNG.
np.random.seed(0)
@@ -12,21 +14,33 @@ def worker_process(arg):
1214
class EvolutionStrategy(object):
    """Natural-evolution-strategy optimizer over a list of per-layer weight arrays."""

    def __init__(self, weights, get_reward_func, population_size=50, sigma=0.1,
                 learning_rate=0.03, decay=0.999, num_threads=1, limits=None,
                 printer=None, distributions=None):
        """Set up the optimizer state.

        Parameters
        ----------
        weights : list of np.ndarray
            Initial per-layer weights; updated in place by the optimizer.
        get_reward_func : callable
            Maps a weight list to a scalar reward to be maximized.
        population_size : int
            Number of noise samples per generation.
        sigma : float
            Std of the default zero-mean normal noise (ignored when
            `distributions` is given).
        learning_rate : float
            Step size; multiplied by `decay` after each update.
        decay : float
            Per-generation learning-rate decay factor.
        num_threads : int
            Worker processes for reward evaluation; -1 means all CPUs.
        limits : tuple or None
            NOTE(review): stored but not used in the visible code; the default
            (+inf, -inf) ordering looks reversed for a (lower, upper) pair —
            confirm against the consumer of self.limits.
        printer : callable or None
            Sink for progress messages; defaults to built-in print.
        distributions : frozen scipy distribution, iterable of them, or None
            Noise source(s); an iterable supplies one distribution per layer.
        """
        if limits is None:
            limits = (np.inf, -np.inf)
        self.weights = weights
        self.limits = limits
        self.get_reward = get_reward_func
        self.POPULATION_SIZE = population_size
        if distributions is None:
            # Default: shared zero-mean normal noise with std `sigma`.
            distributions = st.norm(loc=0., scale=sigma)
        if isinstance(distributions, Iterable):
            distributions = list(distributions)
            # Fail fast on a silent layer/distribution mismatch.
            if len(distributions) != len(weights):
                raise ValueError(
                    'expected one distribution per weight layer '
                    '(%d != %d)' % (len(distributions), len(weights)))
            # Per-layer noise scales, used by the update rule.
            self.SIGMA = np.array([d.std() for d in distributions])
        else:
            self.SIGMA = distributions.std()
        self.distributions = distributions
        self.learning_rate = learning_rate
        self.decay = decay
        self.num_threads = mp.cpu_count() if num_threads == -1 else num_threads
        self.printer = print if printer is None else printer
2439
2540 def _get_weights_try (self , w , p ):
2641 weights_try = []
2742 for index , i in enumerate (p ):
28- jittered = self .SIGMA * i
29- weights_try .append (w [index ] + jittered )
43+ weights_try .append (w [index ] + i )
3044 return weights_try
3145
3246 def get_weights (self ):
@@ -36,8 +50,13 @@ def _get_population(self):
3650 population = []
3751 for i in range (self .POPULATION_SIZE ):
3852 x = []
39- for w in self .weights :
40- x .append (np .random .randn (* w .shape ))
53+ if isinstance (self .distributions , Iterable ):
54+ for j , w in enumerate (self .weights ):
55+ x .append (self .distributions [j ].rvs (* w .shape ))
56+ else :
57+ for w in self .weights :
58+ x .append (self .distributions .rvs (* w .shape ))
59+
4160 population .append (x )
4261 return population
4362
@@ -59,10 +78,14 @@ def _update_weights(self, rewards, population):
5978 if std == 0 :
6079 return
6180 rewards = (rewards - rewards .mean ()) / std
81+ update_factor = self .learning_rate / (self .POPULATION_SIZE * self .SIGMA )
82+
6283 for index , w in enumerate (self .weights ):
6384 layer_population = np .array ([p [index ] for p in population ])
64- update_factor = self .learning_rate / (self .POPULATION_SIZE * self .SIGMA )
65- self .weights [index ] = w + update_factor * np .dot (layer_population .T , rewards ).T
85+ if not isinstance (update_factor , np .ndarray ):
86+ self .weights [index ] = w + update_factor * np .dot (layer_population .T , rewards ).T
87+ else :
88+ self .weights [index ] = w + update_factor [index ] * np .dot (layer_population .T , rewards ).T
6689 self .learning_rate *= self .decay
6790
6891 def run (self , iterations , print_step = 10 ):
@@ -75,7 +98,7 @@ def run(self, iterations, print_step=10):
7598 self ._update_weights (rewards , population )
7699
77100 if (iteration + 1 ) % print_step == 0 :
78- print ('iter %d. reward: %f' % (iteration + 1 , self .get_reward (self .weights )))
101+ self . printer ('iter %d. reward: %f' % (iteration + 1 , self .get_reward (self .weights )))
79102 if pool is not None :
80103 pool .close ()
81104 pool .join ()
0 commit comments