diff --git a/DDQN/ddqn_agent.py b/DDQN/ddqn_agent.py
index 66a15fe..9cea127 100644
--- a/DDQN/ddqn_agent.py
+++ b/DDQN/ddqn_agent.py
@@ -42,7 +42,7 @@ def sample_memory(self):
         states = T.tensor(state).to(self.q_eval.device)
         rewards = T.tensor(reward).to(self.q_eval.device)
-        dones = T.tensor(done).to(self.q_eval.device)
+        dones = T.tensor(done, dtype=T.bool).to(self.q_eval.device)
         actions = T.tensor(action).to(self.q_eval.device)
         states_ = T.tensor(new_state).to(self.q_eval.device)
@@ -83,7 +83,7 @@ def learn(self):
         q_next = self.q_next.forward(states_)
         q_eval = self.q_eval.forward(states_)
-        max_actions = T.argmax(q_eval, dim=1)
+        max_actions = T.argmax(q_eval, dim=1).detach()
         q_next[dones] = 0.0
         q_target = rewards + self.gamma*q_next[indices, max_actions]
diff --git a/DQN/dqn_agent.py b/DQN/dqn_agent.py
index aeff431..7e81185 100644
--- a/DQN/dqn_agent.py
+++ b/DQN/dqn_agent.py
@@ -53,7 +53,7 @@ def sample_memory(self):
         states = T.tensor(state).to(self.q_eval.device)
         rewards = T.tensor(reward).to(self.q_eval.device)
-        dones = T.tensor(done).to(self.q_eval.device)
+        dones = T.tensor(done, dtype=T.bool).to(self.q_eval.device)
         actions = T.tensor(action).to(self.q_eval.device)
         states_ = T.tensor(new_state).to(self.q_eval.device)
diff --git a/DuelingDDQN/dueling_ddqn_agent.py b/DuelingDDQN/dueling_ddqn_agent.py
index 17193f2..d7bd893 100644
--- a/DuelingDDQN/dueling_ddqn_agent.py
+++ b/DuelingDDQN/dueling_ddqn_agent.py
@@ -42,7 +42,7 @@ def sample_memory(self):
         states = T.tensor(state).to(self.q_eval.device)
         rewards = T.tensor(reward).to(self.q_eval.device)
-        dones = T.tensor(done).to(self.q_eval.device)
+        dones = T.tensor(done, dtype=T.bool).to(self.q_eval.device)
         actions = T.tensor(action).to(self.q_eval.device)
         states_ = T.tensor(new_state).to(self.q_eval.device)
@@ -92,7 +92,7 @@ def learn(self):
         q_eval = T.add(V_s_eval, (A_s_eval - A_s_eval.mean(dim=1,keepdim=True)))
-        max_actions = T.argmax(q_eval, dim=1)
+        max_actions = T.argmax(q_eval, dim=1).detach()
         q_next[dones] = 0.0
         q_target = rewards + self.gamma*q_next[indices, max_actions]
diff --git a/DuelingDQN/dueling_dqn_agent.py b/DuelingDQN/dueling_dqn_agent.py
index 05fb63c..eab5829 100644
--- a/DuelingDQN/dueling_dqn_agent.py
+++ b/DuelingDQN/dueling_dqn_agent.py
@@ -42,7 +42,7 @@ def sample_memory(self):
         states = T.tensor(state).to(self.q_eval.device)
         rewards = T.tensor(reward).to(self.q_eval.device)
-        dones = T.tensor(done).to(self.q_eval.device)
+        dones = T.tensor(done, dtype=T.bool).to(self.q_eval.device)
         actions = T.tensor(action).to(self.q_eval.device)
         states_ = T.tensor(new_state).to(self.q_eval.device)
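
Note: the dtype=T.bool change matters because q_next[dones] = 0.0 relies on boolean masking; an integer (long) done tensor would be treated as advanced indexing and assign to rows 0 and 1 instead of masking terminal transitions, and uint8 masks are deprecated in recent PyTorch. The .detach() on T.argmax is only making intent explicit, since argmax already returns an integer tensor that carries no gradient. Below is a minimal standalone sketch of that target computation, not taken from the repo; the tensor names, shapes, and random values are assumed purely for illustration.

import torch as T

batch_size, n_actions, gamma = 4, 3, 0.99

# Hypothetical batch sampled from replay memory (values are placeholders).
rewards = T.rand(batch_size)
dones = T.tensor([False, True, False, False], dtype=T.bool)  # bool, not int/uint8
q_next = T.rand(batch_size, n_actions)  # target network:  Q'(s', .)
q_eval = T.rand(batch_size, n_actions)  # online network:  Q(s', .)
indices = T.arange(batch_size)

# Double DQN: the online network picks the next action, the target network scores it.
# argmax yields a LongTensor with no grad history, so .detach() here is a no-op
# that simply documents the intent.
max_actions = T.argmax(q_eval, dim=1).detach()

# Boolean masking zeroes the bootstrap value for terminal transitions.
q_next[dones] = 0.0

q_target = rewards + gamma * q_next[indices, max_actions]
print(q_target.shape)  # torch.Size([4])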