Commit d7f5cec

boltzmann: fix overflow via np.float64; remove max-offset subtraction
1 parent: 668729a

3 files changed: +11 -14 lines

rl/memory/prioritized_exp_replay.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def __init__(self, env_spec, max_mem_len=10000, e=0.01, alpha=0.6,
     def get_priority(self, error):
         # add min_priority to prevent root of negative = complex
         p = (error + self.e) ** self.alpha
-        assert not np.isnan(p)
+        assert np.isfinite(p)
         return p

     def add_exp(self, action, reward, next_state, terminal):
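Note that `np.isfinite` is a strictly stronger check than `not np.isnan`: it also rejects ±inf, which is what an overflowing power or exp actually produces. A minimal sketch of the difference, using a hypothetical overflowed value (not from the repo):

    import numpy as np

    p = np.float64(1e308) ** 2   # overflows to inf, not NaN
    print(np.isnan(p))           # False -> old check would pass silently
    print(np.isfinite(p))        # False -> new check flags the overflow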

rl/policy/actor_critic.py

Lines changed: 3 additions & 4 deletions
@@ -39,19 +39,18 @@ class SoftmaxPolicy(Policy):
     def __init__(self, env_spec,
                  **kwargs):  # absorb generic param without breaking
         super(SoftmaxPolicy, self).__init__(env_spec)
-        self.clip_val = 500
+        self.clip_val = 500.
         log_self(self)

     def select_action(self, state):
         agent = self.agent
         state = np.expand_dims(state, axis=0)
         A_score = agent.actor.predict(state)[0]  # extract from batch predict
         assert A_score.ndim == 1
-        A_score = A_score.astype('float32')  # fix precision nan issue
-        A_score = A_score - np.amax(A_score)  # prevent overflow
+        A_score = A_score.astype('float64')  # fix precision overflow
         exp_values = np.exp(
             np.clip(A_score, -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
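The float64 cast and the ±500 clip work as a pair: in float32, np.exp overflows to inf once its argument exceeds roughly 88 (float32 max ≈ 3.4e38 ≈ e^88.7), so clipping at 500 alone would not have helped, while float64 stays finite up to roughly e^709. That pairing is what makes dropping the np.amax subtraction safe. A standalone sketch of the same clipped softmax, with hypothetical scores:

    import numpy as np

    clip_val = 500.
    A_score = np.array([123.4, -56.7, 499.2], dtype='float64')  # hypothetical scores
    exp_values = np.exp(np.clip(A_score, -clip_val, clip_val))
    assert np.isfinite(exp_values).all()  # exp(500) ~ 1.4e217 < float64 max ~ 1.8e308
    probs = exp_values / exp_values.sum()
    probs /= probs.sum()                  # renormalize to absorb rounding error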

rl/policy/boltzmann.py

Lines changed: 7 additions & 9 deletions
@@ -18,19 +18,18 @@ def __init__(self, env_spec,
         self.final_tau = final_tau
         self.tau = self.init_tau
         self.exploration_anneal_episodes = exploration_anneal_episodes
-        self.clip_val = 200
+        self.clip_val = 500.
         log_self(self)

     def select_action(self, state):
         agent = self.agent
         state = np.expand_dims(state, axis=0)
         Q_state = agent.model.predict(state)[0]  # extract from batch predict
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
-            np.clip(Q_state / float(self.tau), -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
@@ -66,11 +65,10 @@ def select_action(self, state):
         Q_state2 = agent.model_2.predict(state)[0]
         Q_state = Q_state1 + Q_state2
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
-            np.clip(Q_state / float(self.tau), -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
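In the Boltzmann policy the Q-values are divided by the temperature tau before the clip, so a low tau sharpens the distribution toward the greedy action and a high tau flattens it toward uniform; the removed float(self.tau) cast is redundant once Q_state is float64, since the division upcasts anyway. A minimal sketch of the selection step, with a hypothetical three-action space and Q-values:

    import numpy as np

    tau, clip_val = 1.0, 500.                # tau anneals from init_tau to final_tau
    Q_state = np.array([10., 12., 8.], dtype='float64')  # hypothetical Q-values
    exp_values = np.exp(np.clip(Q_state / tau, -clip_val, clip_val))
    assert np.isfinite(exp_values).all()
    probs = exp_values / exp_values.sum()
    probs /= probs.sum()                     # guard against floating point drift
    action = np.random.choice(3, p=probs)    # small tau -> near-greedy; large tau -> near-uniform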
