Commit 4f123b8

Merge pull request #131 from kengz/schedule
fine tune PER, fix bugs
2 parents 49be028 + d54676e commit 4f123b8

15 files changed: +286 -281 lines

rl/agent/actor_critic.py

Lines changed: 4 additions & 0 deletions
@@ -114,7 +114,11 @@ def train_critic(self, minibatch):
         actor_delta = Q_next_vals - Q_vals
         loss = self.critic.train_on_batch(minibatch['states'], Q_targets)

+        # update memory, needed for PER
         errors = abs(np.sum(Q_vals - Q_targets, axis=1))
+        # Q size is only 1, from critic
+        assert Q_targets.shape == (self.batch_size, 1)
+        assert errors.shape == (self.batch_size, )
         self.memory.update(errors)
         return loss, actor_delta

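A note on the asserts added here: np.sum(..., axis=1) collapses the (batch_size, 1) critic outputs into one absolute error per transition, which is what self.memory.update expects. A minimal numpy sketch of that shape flow, with made-up values and an assumed batch_size of 2:

import numpy as np

batch_size = 2
Q_vals = np.array([[0.3], [1.1]])      # critic predictions, shape (batch_size, 1)
Q_targets = np.array([[0.5], [0.9]])   # bootstrapped targets, same shape

errors = abs(np.sum(Q_vals - Q_targets, axis=1))
assert Q_targets.shape == (batch_size, 1)
assert errors.shape == (batch_size, )  # one scalar error per transition
print(errors)                          # [0.2 0.2]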

rl/agent/ddpg.py

Lines changed: 8 additions & 0 deletions
@@ -242,6 +242,7 @@ def train_an_epoch(self):

         # train critic
         mu_prime = self.actor.target_predict(minibatch['next_states'])
+        q_val = self.critic.target_predict(minibatch['states'], mu_prime)
         q_prime = self.critic.target_predict(
             minibatch['next_states'], mu_prime)
         # reshape for element-wise multiplication
@@ -250,6 +251,13 @@ def train_an_epoch(self):
             (1 - minibatch['terminals']) * np.reshape(q_prime, (-1))
         y = np.reshape(y, (-1, 1))

+        # update memory, needed for PER
+        errors = abs(np.sum(q_val - y, axis=1))
+        # Q size is only 1, from critic
+        assert y.shape == (self.batch_size, 1)
+        assert errors.shape == (self.batch_size, )
+        self.memory.update(errors)
+
         _, _, critic_loss = self.critic.train_tf(
             minibatch['states'], minibatch['actions'], y)

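For reference, y here is the usual DDPG bootstrap target r + gamma * (1 - terminal) * Q'(s', mu'(s')), and the new q_val is the target critic's estimate used only to score transitions for PER. A hedged numpy sketch of the arithmetic, with illustrative values that are not taken from the repo:

import numpy as np

gamma = 0.99
rewards = np.array([1.0, -1.0])
terminals = np.array([0, 1])
q_prime = np.array([0.5, 2.0])           # target critic on next states
q_val = np.array([[1.2], [-0.7]])        # critic estimate, shape (batch_size, 1)

y = rewards + gamma * (1 - terminals) * np.reshape(q_prime, (-1))
y = np.reshape(y, (-1, 1))               # (batch_size, 1) targets for train_tf
errors = abs(np.sum(q_val - y, axis=1))  # per-transition error fed to PER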

rl/agent/deep_sarsa.py

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 from rl.agent.dqn import DQN


@@ -30,4 +31,10 @@ def train_an_epoch(self):
         Q_targets = self.compute_Q_targets(
             minibatch, Q_states, Q_next_states_selected)
         loss = self.model.train_on_batch(minibatch['states'], Q_targets)
+
+        errors = abs(np.sum(Q_states - Q_targets, axis=1))
+        assert Q_targets.shape == (
+            self.batch_size, self.env_spec['action_dim'])
+        assert errors.shape == (self.batch_size, )
+        self.memory.update(errors)
         return loss

rl/agent/dqn.py

Lines changed: 3 additions & 1 deletion
@@ -190,10 +190,12 @@ def train_an_epoch(self):
             minibatch)
         Q_targets = self.compute_Q_targets(
             minibatch, Q_states, Q_next_states_max)
-
         loss = self.model.train_on_batch(minibatch['states'], Q_targets)

         errors = abs(np.sum(Q_states - Q_targets, axis=1))
+        assert Q_targets.shape == (
+            self.batch_size, self.env_spec['action_dim'])
+        assert errors.shape == (self.batch_size, )
         self.memory.update(errors)
         return loss

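An aside on why a row-sum still yields one TD error per sample in the DQN/SARSA agents: assuming the target construction only re-targets the entry of the action actually taken (as Q-learning-style targets typically do), the sum over the action axis reduces to that single correction. Sketch with assumed values, batch_size 2 and action_dim 3:

import numpy as np

batch_size, action_dim = 2, 3
Q_states = np.array([[1.0, 2.0, 0.5],
                     [0.1, 0.4, 0.3]])
Q_targets = Q_states.copy()
Q_targets[0, 1] = 2.6                  # only the taken action's entry moves
Q_targets[1, 2] = 0.05

errors = abs(np.sum(Q_states - Q_targets, axis=1))
assert Q_targets.shape == (batch_size, action_dim)
assert errors.shape == (batch_size, )  # array([0.6 , 0.25])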

rl/analytics.py

Lines changed: 4 additions & 3 deletions
@@ -317,9 +317,10 @@ def compose_data(trial):
     }

     # param variables for independent vars of trials
+    default_param = trial.experiment_spec['param']
     param_variables = {
-        pv: trial.experiment_spec['param'][pv] for
-        pv in trial.param_variables}
+        pv: default_param[pv] for
+        pv in trial.param_variables if pv in default_param}

     trial.data['metrics'].update(metrics)
     trial.data['param_variables'] = param_variables
@@ -459,7 +460,7 @@ def analyze_data(experiment_data_or_experiment_id):

     data_df.sort_values(
         ['fitness_score'], ascending=False, inplace=True)
-    data_df.reset_index(inplace=True)
+    data_df.reset_index(drop=True, inplace=True)

     trial_id = experiment_data[0]['trial_id']
     save_experiment_data(data_df, trial_id)
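The reset_index(drop=True, ...) change keeps the pre-sort index from being written into the saved experiment data as a stray 'index' column. A minimal pandas sketch with made-up scores:

import pandas as pd

data_df = pd.DataFrame({'trial_id': ['t0', 't1', 't2'],
                        'fitness_score': [1.2, 3.4, 2.1]})
data_df.sort_values(['fitness_score'], ascending=False, inplace=True)
data_df.reset_index(drop=True, inplace=True)  # fresh 0..n-1 index, no extra column
print(data_df)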

rl/memory/prioritized_exp_replay.py

Lines changed: 12 additions & 11 deletions
@@ -12,8 +12,13 @@ class PrioritizedExperienceReplay(LinearMemoryWithForgetting):
     memory unit
     '''

-    def __init__(self, env_spec, max_mem_len=10000, e=0.01, alpha=0.6,
+    def __init__(self, env_spec, max_mem_len=None, e=0.01, alpha=0.6,
                  **kwargs):
+        if max_mem_len is None:  # auto calculate mem len
+            max_timestep = env_spec['timestep_limit']
+            max_epis = env_spec['problem']['MAX_EPISODES']
+            memory_epi = np.ceil(max_epis / 3.).astype(int)
+            max_mem_len = max(10**6, max_timestep * memory_epi)
         super(PrioritizedExperienceReplay, self).__init__(
             env_spec, max_mem_len)
         self.exp_keys.append('error')
@@ -27,21 +32,18 @@ def __init__(self, env_spec, max_mem_len=10000, e=0.01, alpha=0.6,
         self.prio_tree = SumTree(self.max_mem_len)
         self.head = 0

-        # bump to account for negative terms in reward get_priority
-        # and we cannot abs(reward) cuz it's sign sensitive
-        SOLVED_MEAN_REWARD = self.env_spec['problem']['SOLVED_MEAN_REWARD'] or 10000
-        self.min_priority = abs(10 * SOLVED_MEAN_REWARD)
-
     def get_priority(self, error):
         # add min_priority to prevent root of negative = complex
-        p = (self.min_priority + error + self.e) ** self.alpha
-        assert not np.isnan(p)
+        p = (error + self.e) ** self.alpha
+        assert np.isfinite(p)
         return p

     def add_exp(self, action, reward, next_state, terminal):
         '''Round robin memory updating'''
-        # roughly the error between estimated Q and true q is the reward
-        error = reward
+        # init error to reward first, update later
+        error = abs(reward)
+        p = self.get_priority(error)
+
         if self.size() < self.max_mem_len:  # add as usual
             super(PrioritizedExperienceReplay, self).add_exp(
                 action, reward, next_state, terminal)
@@ -59,7 +61,6 @@ def add_exp(self, action, reward, next_state, terminal):
         if self.head >= self.max_mem_len:
             self.head = 0  # reset for round robin

-        p = self.get_priority(error)
         self.prio_tree.add(p)

         assert self.head == self.prio_tree.head, 'prio_tree head is wrong'
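The priority rule this settles on is the standard proportional-PER form p = (|error| + e) ** alpha; since errors are now stored as absolute values, the old min_priority offset (which kept the base of the fractional power positive) is no longer needed. A small illustrative sketch, with made-up error values:

import numpy as np

e, alpha = 0.01, 0.6
errors = np.array([0.0, 0.3, 5.0])     # |TD errors| reported by the agent
priorities = (errors + e) ** alpha
assert np.isfinite(priorities).all()
# zero-error transitions keep a small non-zero priority (~0.063 here),
# while large-error transitions dominate sampling from the SumTree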

rl/policy/actor_critic.py

Lines changed: 6 additions & 4 deletions
@@ -39,19 +39,18 @@ class SoftmaxPolicy(Policy):
     def __init__(self, env_spec,
                  **kwargs):  # absorb generic param without breaking
         super(SoftmaxPolicy, self).__init__(env_spec)
-        self.clip_val = 500
+        self.clip_val = 500.
         log_self(self)

     def select_action(self, state):
         agent = self.agent
         state = np.expand_dims(state, axis=0)
         A_score = agent.actor.predict(state)[0]  # extract from batch predict
         assert A_score.ndim == 1
-        A_score = A_score.astype('float32')  # fix precision nan issue
-        A_score = A_score - np.amax(A_score)  # prevent overflow
+        A_score = A_score.astype('float64')  # fix precision overflow
         exp_values = np.exp(
             np.clip(A_score, -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
@@ -83,6 +82,9 @@ def select_action(self, state):
         a_mean = agent.actor.predict(state)[0]  # extract from batch predict
         action = a_mean + np.random.normal(
             loc=0.0, scale=self.variance, size=a_mean.shape)
+        action = np.clip(action,
+                         self.env_spec['action_bound_low'],
+                         self.env_spec['action_bound_high'])
         return action

     def update(self, sys_vars):
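The numerical change here: casting to float64 and clipping scores to ±500 keeps np.exp finite on its own (exp(500) overflows float32 but not float64), so the earlier subtract-the-max step is dropped. A hedged sketch of the stabilized softmax, with made-up scores:

import numpy as np

clip_val = 500.
A_score = np.array([1000.0, -2000.0, 3.0]).astype('float64')  # extreme actor scores
exp_values = np.exp(np.clip(A_score, -clip_val, clip_val))
assert np.isfinite(exp_values).all()   # exp(±500) is still finite in float64
probs = exp_values / np.sum(exp_values)
probs /= probs.sum()                   # renormalize against float rounding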

rl/policy/boltzmann.py

Lines changed: 5 additions & 7 deletions
@@ -18,19 +18,18 @@ def __init__(self, env_spec,
         self.final_tau = final_tau
         self.tau = self.init_tau
         self.exploration_anneal_episodes = exploration_anneal_episodes
-        self.clip_val = 500
+        self.clip_val = 500.
         log_self(self)

     def select_action(self, state):
         agent = self.agent
         state = np.expand_dims(state, axis=0)
         Q_state = agent.model.predict(state)[0]  # extract from batch predict
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
             np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
@@ -66,11 +65,10 @@ def select_action(self, state):
         Q_state2 = agent.model_2.predict(state)[0]
         Q_state = Q_state1 + Q_state2
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
             np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
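For intuition on the temperature this policy anneals: a high tau flattens the Boltzmann distribution toward uniform exploration, a low tau concentrates it on the greedy action. A quick standalone illustration with assumed Q-values:

import numpy as np

def boltzmann_probs(Q_state, tau, clip_val=500.):
    exp_values = np.exp(np.clip(Q_state / tau, -clip_val, clip_val))
    return exp_values / exp_values.sum()

Q_state = np.array([1.0, 2.0, 4.0], dtype='float64')
print(boltzmann_probs(Q_state, tau=5.0))   # near-uniform: explore
print(boltzmann_probs(Q_state, tau=0.1))   # sharply peaked on argmax: exploit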

rl/policy/noise.py

Lines changed: 24 additions & 0 deletions
@@ -1,6 +1,7 @@
 import numpy as np
 from rl.util import log_self
 from rl.policy.base_policy import Policy
+from rl.policy.epsilon_greedy import EpsilonGreedyPolicy


 class NoNoisePolicy(Policy):
@@ -25,6 +26,9 @@ def select_action(self, state):
         state = np.expand_dims(state, axis=0)
         if self.env_spec['actions'] == 'continuous':
             action = agent.actor.predict(state)[0] + self.sample()
+            action = np.clip(action,
+                             self.env_spec['action_bound_low'],
+                             self.env_spec['action_bound_high'])
         else:
             Q_state = agent.actor.predict(state)[0]
             assert Q_state.ndim == 1
@@ -60,6 +64,26 @@ def update(self, sys_vars):
         self.n_step = sys_vars['epi']


+class EpsilonGreedyNoisePolicy(EpsilonGreedyPolicy, NoNoisePolicy):
+
+    '''
+    akin to epsilon greedy decay,
+    but return random sample instead
+    '''
+
+    def sample(self):
+        if self.e > np.random.rand():
+            noise = np.random.uniform(
+                0.5 * self.env_spec['action_bound_low'],
+                0.5 * self.env_spec['action_bound_high'])
+        else:
+            noise = 0
+        return noise
+
+    def select_action(self, state):
+        return NoNoisePolicy.select_action(self, state)
+
+
 class AnnealedGaussianPolicy(LinearNoisePolicy):

     '''
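Roughly, the new EpsilonGreedyNoisePolicy injects a uniform kick within half the action bounds with probability self.e (inherited from the epsilon-greedy annealing) and no noise otherwise, and NoNoisePolicy.select_action then clips the result into bounds. A hedged standalone sketch of that behaviour, with assumed bounds and epsilon:

import numpy as np

e = 0.5                                   # exploration rate (assumed, normally annealed)
action_bound_low, action_bound_high = -2.0, 2.0

def sample_noise():
    if e > np.random.rand():
        return np.random.uniform(0.5 * action_bound_low, 0.5 * action_bound_high)
    return 0

raw_action = 1.5                          # stand-in for actor.predict(state)[0]
action = np.clip(raw_action + sample_noise(), action_bound_low, action_bound_high)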

rl/spec/box2d_experiment_specs.json

Lines changed: 32 additions & 31 deletions
@@ -97,35 +97,6 @@
       ]
     }
   },
-  "lunar_double_dqn_per": {
-    "problem": "LunarLander-v2",
-    "Agent": "DoubleDQN",
-    "HyperOptimizer": "GridSearch",
-    "Memory": "PrioritizedExperienceReplay",
-    "Optimizer": "AdamOptimizer",
-    "Policy": "DoubleDQNBoltzmannPolicy",
-    "PreProcessor": "StackStates",
-    "param": {
-      "train_per_n_new_exp": 2,
-      "lr": 0.005,
-      "gamma": 0.99,
-      "hidden_layers": [800, 400],
-      "hidden_layers_activation": "sigmoid",
-      "output_layer_activation": "linear",
-      "exploration_anneal_episodes": 150,
-      "epi_change_lr": 200,
-      "max_mem_len": 30000
-    },
-    "param_range": {
-      "lr": [0.001, 0.005, 0.01],
-      "gamma": [0.97, 0.99, 0.999],
-      "hidden_layers": [
-        [400, 200],
-        [800, 400],
-        [400, 200, 100]
-      ]
-    }
-  },
   "lunar_double_dqn_nopreprocess": {
     "problem": "LunarLander-v2",
     "Agent": "DoubleDQN",
@@ -266,11 +237,11 @@
       ]
     }
   },
-  "lunar_ddpg_linearnoise": {
+  "lunar_cont_ddpg_per_linearnoise": {
     "problem": "LunarLanderContinuous-v2",
     "Agent": "DDPG",
     "HyperOptimizer": "GridSearch",
-    "Memory": "LinearMemoryWithForgetting",
+    "Memory": "PrioritizedExperienceReplay",
     "Optimizer": "AdamOptimizer",
     "Policy": "LinearNoisePolicy",
     "PreProcessor": "NoPreProcessor",
@@ -327,5 +298,35 @@
         [800, 400, 200]
       ]
     }
+  },
+  "walker_ddpg_per_linearnoise": {
+    "problem": "BipedalWalker-v2",
+    "Agent": "DDPG",
+    "HyperOptimizer": "GridSearch",
+    "Memory": "PrioritizedExperienceReplay",
+    "Optimizer": "AdamOptimizer",
+    "Policy": "LinearNoisePolicy",
+    "PreProcessor": "NoPreProcessor",
+    "param": {
+      "batch_size": 64,
+      "n_epoch": 1,
+      "tau": 0.005,
+      "lr": 0.0005,
+      "critic_lr": 0.001,
+      "gamma": 0.97,
+      "hidden_layers": [400, 200],
+      "hidden_layers_activation": "relu",
+      "output_layer_activation": "tanh"
+    },
+    "param_range": {
+      "lr": [0.0001, 0.0005],
+      "critic_lr": [0.001, 0.005],
+      "gamma": [0.95, 0.97, 0.99],
+      "hidden_layers": [
+        [200, 100],
+        [400, 300],
+        [800, 400]
+      ]
+    }
   }
 }
