@@ -18,19 +18,18 @@ def __init__(self, env_spec,
         self.final_tau = final_tau
         self.tau = self.init_tau
         self.exploration_anneal_episodes = exploration_anneal_episodes
-        self.clip_val = 200
+        self.clip_val = 500.
         log_self(self)
 
     def select_action(self, state):
         agent = self.agent
         state = np.expand_dims(state, axis=0)
         Q_state = agent.model.predict(state)[0]  # extract from batch predict
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
-            np.clip(Q_state / float(self.tau), -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
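For reference, a minimal standalone sketch of the Boltzmann selection path as it reads after this change (assuming NumPy only; `boltzmann_select` and its defaults are illustrative names, not part of the repo):

```python
import numpy as np

def boltzmann_select(Q_state, tau, clip_val=500., rng=np.random):
    """Sketch of the selection path after this commit: a float64 cast plus
    clipping of Q/tau keeps np.exp finite, replacing the removed
    max-subtraction. Returns a sampled action index."""
    Q_state = np.asarray(Q_state, dtype='float64')  # float32 exp overflows past ~88
    exp_values = np.exp(np.clip(Q_state / tau, -clip_val, clip_val))
    assert np.isfinite(exp_values).all()
    probs = exp_values / exp_values.sum()
    probs /= probs.sum()  # renormalize to absorb floating point error
    return rng.choice(len(probs), p=probs)

# Low tau sharpens the distribution toward the greedy action;
# high tau flattens it toward uniform exploration.
print(boltzmann_select([1.0, 2.0, 3.0], tau=0.1))
```

One trade-off worth noting: if several entries of `Q_state / tau` all exceed `clip_val`, clipping collapses them to equal probability; the removed `Q_state - np.amax(Q_state)` shift is the more common overflow guard and could in principle be kept alongside the float64 cast.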
@@ -66,11 +65,10 @@ def select_action(self, state):
         Q_state2 = agent.model_2.predict(state)[0]
         Q_state = Q_state1 + Q_state2
         assert Q_state.ndim == 1
-        Q_state = Q_state.astype('float32')  # fix precision nan issue
-        Q_state = Q_state - np.amax(Q_state)  # prevent overflow
+        Q_state = Q_state.astype('float64')  # fix precision overflow
         exp_values = np.exp(
-            np.clip(Q_state / float(self.tau), -self.clip_val, self.clip_val))
-        assert not np.isnan(exp_values).any()
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
+        assert np.isfinite(exp_values).all()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
         action = np.random.choice(agent.env_spec['actions'], p=probs)
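A quick sanity check on why the dtype change matters for the double-model variant too, where the summed Q-values can be twice as large (plain NumPy; the printed values are standard IEEE 754 limits):

```python
import numpy as np

# exp(500) overflows float32 (max ~3.4e38) but fits comfortably in
# float64 (max ~1.8e308), so the new clip bound of 500. is only safe
# after the cast to 'float64'.
print(np.exp(np.float32(500.)))  # inf, with a RuntimeWarning
print(np.exp(np.float64(500.)))  # ~7.0e+217, finite
```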