Commit 3576709

Add files via upload
1 parent 2402f74 commit 3576709

1 file changed: 234 additions, 0 deletions

# Tutorial by www.pylessons.com
# Tutorial written for - Tensorflow 2.3.1

import os
import random
import gym
import pylab
import numpy as np
from collections import deque
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam, RMSprop

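# Added caveat: this script assumes the pre-0.26 gym API -- env.seed(),
# env.reset() returning only the observation, and env.step() returning a
# 4-tuple (next_state, reward, done, info). The exact gym version is not
# stated in the original file; newer gym/gymnasium releases changed these
# signatures.
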
def OurModel(input_shape, action_space):
    X_input = Input(input_shape)
    X = X_input

    # 'Dense' is the basic form of a neural network layer
    # Input Layer of state size(4) and Hidden Layer with 512 nodes
    X = Dense(512, input_shape=input_shape, activation="relu", kernel_initializer='he_uniform')(X)

    # Hidden layer with 256 nodes
    X = Dense(256, activation="relu", kernel_initializer='he_uniform')(X)

    # Hidden layer with 64 nodes
    X = Dense(64, activation="relu", kernel_initializer='he_uniform')(X)

    # Output Layer with # of actions: 2 nodes (left, right)
    X = Dense(action_space, activation="linear", kernel_initializer='he_uniform')(X)

    model = Model(inputs=X_input, outputs=X)
    model.compile(loss="mean_squared_error", optimizer=RMSprop(lr=0.00025, rho=0.95, epsilon=0.01), metrics=["accuracy"])

    model.summary()
    return model

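# Architecture recap (from the code above): state (4 values for CartPole)
# -> 512 -> 256 -> 64 -> action_space linear outputs, one estimated Q-value
# per action. MSE is the natural loss for this regression target; the
# "accuracy" metric carries little meaning for Q-value regression and can
# be ignored.
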
class DQNAgent:
    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.env.seed(0)
        # by default, CartPole-v1 has max episode steps = 500
        self.env._max_episode_steps = 4000
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        self.EPISODES = 1000
        self.memory = deque(maxlen=2000)

        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0   # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.999
        self.batch_size = 32
        self.train_start = 1000

        # defining model parameters
        self.ddqn = True
        self.Soft_Update = False

        self.TAU = 0.1 # target network soft update hyperparameter

        self.Save_Path = 'Models'
        self.scores, self.episodes, self.average = [], [], []

        if self.ddqn:
            print("----------Double DQN--------")
            self.Model_name = os.path.join(self.Save_Path, "DDQN_" + self.env_name + ".h5")
        else:
            print("-------------DQN------------")
            self.Model_name = os.path.join(self.Save_Path, "DQN_" + self.env_name + ".h5")

        # create main model and target model
        self.model = OurModel(input_shape=(self.state_size,), action_space=self.action_size)
        self.target_model = OurModel(input_shape=(self.state_size,), action_space=self.action_size)

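    # Hyperparameter notes (derived from the values above): the replay buffer
    # keeps only the 2000 most recent transitions, and no training happens
    # until train_start (1000) transitions have been collected; until then the
    # agent simply acts with the epsilon-greedy policy.
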
    # after some time interval update the target model to be same with model
    def update_target_model(self):
        if not self.Soft_Update and self.ddqn:
            self.target_model.set_weights(self.model.get_weights())
            return
        if self.Soft_Update and self.ddqn:
            q_model_theta = self.model.get_weights()
            target_model_theta = self.target_model.get_weights()
            counter = 0
            for q_weight, target_weight in zip(q_model_theta, target_model_theta):
                target_weight = target_weight * (1 - self.TAU) + q_weight * self.TAU
                target_model_theta[counter] = target_weight
                counter += 1
            self.target_model.set_weights(target_model_theta)

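    # Update rule recap: with Soft_Update=False the target network is a hard
    # copy of the online network; with Soft_Update=True each weight tensor is
    # blended by Polyak averaging,
    #     theta_target <- TAU * theta_online + (1 - TAU) * theta_target,
    # which is exactly what the loop above computes with TAU = 0.1.
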
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        if len(self.memory) > self.train_start:
            if self.epsilon > self.epsilon_min:
                self.epsilon *= self.epsilon_decay

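    # Exploration schedule note: epsilon is multiplied by 0.999 once per
    # stored transition after the buffer holds more than train_start samples,
    # so it decays from 1.0 to the 0.01 floor in roughly 4600 such steps
    # (0.999 ** 4600 is approximately 0.01).
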
    def act(self, state):
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_size)
        else:
            return np.argmax(self.model.predict(state))

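    # act() is the standard epsilon-greedy policy: with probability epsilon a
    # uniformly random action, otherwise the greedy action. model.predict on a
    # single (1, state_size) state returns a (1, action_size) array, and
    # np.argmax over it yields the index of the highest-valued action.
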
    def replay(self):
        if len(self.memory) < self.train_start:
            return
        # Randomly sample a minibatch from the memory
        minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))

        state = np.zeros((self.batch_size, self.state_size))
        next_state = np.zeros((self.batch_size, self.state_size))
        action, reward, done = [], [], []

        # unpack the minibatch before prediction;
        # for speedup this could be done on the tensor level,
        # but it is easier to understand as a loop
        for i in range(self.batch_size):
            state[i] = minibatch[i][0]
            action.append(minibatch[i][1])
            reward.append(minibatch[i][2])
            next_state[i] = minibatch[i][3]
            done.append(minibatch[i][4])

        # do batch prediction for speed
        target = self.model.predict(state)
        target_next = self.model.predict(next_state)
        target_val = self.target_model.predict(next_state)

        for i in range(len(minibatch)):
            # correction on the Q value for the action used
            if done[i]:
                target[i][action[i]] = reward[i]
            else:
                if self.ddqn: # Double - DQN
                    # current Q Network selects the action
                    # a'_max = argmax_a' Q(s', a')
                    a = np.argmax(target_next[i])
                    # target Q Network evaluates the action
                    # Q_max = Q_target(s', a'_max)
                    target[i][action[i]] = reward[i] + self.gamma * (target_val[i][a])
                else: # Standard - DQN
                    # DQN chooses the max Q value among next actions
                    # selection and evaluation of action is on the target Q Network
                    # Q_max = max_a' Q_target(s', a')
                    target[i][action[i]] = reward[i] + self.gamma * (np.amax(target_val[i]))

        # Train the Neural Network with batches
        self.model.fit(state, target, batch_size=self.batch_size, verbose=0)

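    # Target recap (as computed in replay above), written out as equations:
    #     DQN : y = r + gamma * max_a' Q_target(s', a')
    #     DDQN: y = r + gamma * Q_target(s', argmax_a' Q_online(s', a'))
    # Double DQN decouples action selection (online network) from action
    # evaluation (target network), which reduces the overestimation bias of
    # the plain DQN target.
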
    def load(self, name):
        self.model = load_model(name)

    def save(self, name):
        self.model.save(name)

    pylab.figure(figsize=(18, 9))
    def PlotModel(self, score, episode):
        self.scores.append(score)
        self.episodes.append(episode)
        self.average.append(sum(self.scores) / len(self.scores))
        pylab.plot(self.episodes, self.average, 'r')
        pylab.plot(self.episodes, self.scores, 'b')
        pylab.ylabel('Score', fontsize=18)
        pylab.xlabel('Episode', fontsize=18)
        dqn = 'DQN_'
        softupdate = ''
        if self.ddqn:
            dqn = 'DDQN_'
        if self.Soft_Update:
            softupdate = '_soft'
        try:
            pylab.savefig(dqn + self.env_name + softupdate + ".png")
        except OSError:
            pass

        return str(self.average[-1])[:5]

    def run(self):
        for e in range(self.EPISODES):
            state = self.env.reset()
            state = np.reshape(state, [1, self.state_size])
            done = False
            i = 0
            while not done:
                #self.env.render()
                action = self.act(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = np.reshape(next_state, [1, self.state_size])
                if not done or i == self.env._max_episode_steps - 1:
                    reward = reward
                else:
                    # penalize an early failure so the agent learns to balance longer
                    reward = -100
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1
                if done:
                    # at the end of every episode, update the target model
                    self.update_target_model()

                    # every episode, plot the result
                    average = self.PlotModel(i, e)

                    print("episode: {}/{}, score: {}, e: {:.2}, average: {}".format(e, self.EPISODES, i, self.epsilon, average))
                    if i == self.env._max_episode_steps:
                        print("Saving trained model as cartpole-ddqn.h5")
                        #self.save("cartpole-ddqn.h5")
                        break
                self.replay()

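    # Training cadence note: replay() is called after every environment step
    # (one gradient update on a 32-sample minibatch per step), while the
    # target network and the score plot are refreshed once per episode.
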
    def test(self):
        self.load("cartpole-ddqn.h5")
        for e in range(self.EPISODES):
            state = self.env.reset()
            state = np.reshape(state, [1, self.state_size])
            done = False
            i = 0
            while not done:
                self.env.render()
                action = np.argmax(self.model.predict(state))
                next_state, reward, done, _ = self.env.step(action)
                state = np.reshape(next_state, [1, self.state_size])
                i += 1
                if done:
                    print("episode: {}/{}, score: {}".format(e, self.EPISODES, i))
                    break

if __name__ == "__main__":
    env_name = 'CartPole-v1'
    agent = DQNAgent(env_name)
    agent.run()
    #agent.test()
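
# Usage note: running the script trains with agent.run(); to evaluate instead,
# comment out agent.run() and uncomment agent.test(). test() loads
# "cartpole-ddqn.h5", so the self.save(...) call inside run() has to be
# uncommented first (or a model saved some other way), otherwise there is no
# checkpoint file to load.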
