
Commit f8bd4d5

Tested with TF 2.3.1
1 parent 4016a9c commit f8bd4d5

1 file changed: 323 additions, 0 deletions
@@ -0,0 +1,323 @@
# Tutorial by www.pylessons.com
# Tutorial written for - Tensorflow 2.3.1

import os
import random
import gym
import pylab
import numpy as np
from collections import deque
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Lambda, Add, Conv2D, Flatten
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras import backend as K
from PER import *
import cv2

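# Note: the PER module imported above is assumed to be the SumTree-based prioritized
# experience replay helper from the same tutorial series; it should provide a Memory
# class with store(), sample() and batch_update() methods as used below.
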
def OurModel(input_shape, action_space, dueling):
    X_input = Input(input_shape)
    X = X_input

    X = Conv2D(64, 5, strides=(3, 3), padding="valid", input_shape=input_shape, activation="relu", data_format="channels_first")(X)
    X = Conv2D(64, 4, strides=(2, 2), padding="valid", activation="relu", data_format="channels_first")(X)
    X = Conv2D(64, 3, strides=(1, 1), padding="valid", activation="relu", data_format="channels_first")(X)
    X = Flatten()(X)
    # 'Dense' is the basic form of a neural network layer
    # Hidden Layer with 512 nodes, fed by the flattened convolutional features
    X = Dense(512, activation="relu", kernel_initializer='he_uniform')(X)

    # Hidden layer with 256 nodes
    X = Dense(256, activation="relu", kernel_initializer='he_uniform')(X)

    # Hidden layer with 64 nodes
    X = Dense(64, activation="relu", kernel_initializer='he_uniform')(X)

    if dueling:
        state_value = Dense(1, kernel_initializer='he_uniform')(X)
        state_value = Lambda(lambda s: K.expand_dims(s[:, 0], -1), output_shape=(action_space,))(state_value)

        action_advantage = Dense(action_space, kernel_initializer='he_uniform')(X)
        action_advantage = Lambda(lambda a: a[:, :] - K.mean(a[:, :], keepdims=True), output_shape=(action_space,))(action_advantage)

        X = Add()([state_value, action_advantage])
    else:
        # Output Layer with # of actions: 2 nodes (left, right)
        X = Dense(action_space, activation="linear", kernel_initializer='he_uniform')(X)

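    # The dueling head above combines a scalar state value V(s) with mean-centred advantages:
    #   Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))
    # Subtracting the mean advantage keeps the V/A decomposition identifiable.
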
    model = Model(inputs = X_input, outputs = X)
    model.compile(loss="mean_squared_error", optimizer=RMSprop(lr=0.00025, rho=0.95, epsilon=0.01), metrics=["accuracy"])

    model.summary()
    return model

class DQNAgent:
    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.env.seed(0)
        # by default, CartPole-v1 has max episode steps = 500
        # we can use this to experiment beyond 500
        self.env._max_episode_steps = 4000
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n
        self.EPISODES = 1000

        # Instantiate memory
        memory_size = 10000
        self.MEMORY = Memory(memory_size)
        self.memory = deque(maxlen=2000)

        self.gamma = 0.95  # discount rate

        # EXPLORATION HYPERPARAMETERS for epsilon and epsilon greedy strategy
        self.epsilon = 1.0  # exploration probability at start
        self.epsilon_min = 0.01  # minimum exploration probability
        self.epsilon_decay = 0.0005  # exponential decay rate for exploration prob

        self.batch_size = 32

        # defining model parameters
        self.ddqn = True  # use double deep Q network
        self.Soft_Update = False  # use soft parameter update
        self.dueling = True  # use dueling network
        self.epsilon_greedy = False  # use epsilon greedy strategy
        self.USE_PER = True  # use prioritized experience replay

        self.TAU = 0.1  # target network soft update hyperparameter

        self.Save_Path = 'Models'
        if not os.path.exists(self.Save_Path): os.makedirs(self.Save_Path)
        self.scores, self.episodes, self.average = [], [], []

        self.Model_name = os.path.join(self.Save_Path, self.env_name+"_PER_D3QN_CNN.h5")

        self.ROWS = 160
        self.COLS = 240
        self.REM_STEP = 4

        self.image_memory = np.zeros((self.REM_STEP, self.ROWS, self.COLS))
        self.state_size = (self.REM_STEP, self.ROWS, self.COLS)

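        # The CNN state is therefore a channels_first stack of the REM_STEP most recent
        # grayscale frames, shape (4, 160, 240), replacing the 4-dimensional vector state
        # that self.state_size was first set to above.
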
        # create main model and target model
        self.model = OurModel(input_shape=self.state_size, action_space = self.action_size, dueling = self.dueling)
        self.target_model = OurModel(input_shape=self.state_size, action_space = self.action_size, dueling = self.dueling)

    # after some time interval update the target model to be same with model
    def update_target_model(self):
        if not self.Soft_Update and self.ddqn:
            self.target_model.set_weights(self.model.get_weights())
            return
        if self.Soft_Update and self.ddqn:
            q_model_theta = self.model.get_weights()
            target_model_theta = self.target_model.get_weights()
            counter = 0
            for q_weight, target_weight in zip(q_model_theta, target_model_theta):
                target_weight = target_weight * (1-self.TAU) + q_weight * self.TAU
                target_model_theta[counter] = target_weight
                counter += 1
            self.target_model.set_weights(target_model_theta)

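    # With Soft_Update enabled, the target network tracks the online network by Polyak
    # averaging, theta_target <- TAU * theta_online + (1 - TAU) * theta_target, applied
    # weight tensor by weight tensor; the hard update simply copies all weights at once.
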
    def remember(self, state, action, reward, next_state, done):
        experience = state, action, reward, next_state, done
        if self.USE_PER:
            self.MEMORY.store(experience)
        else:
            self.memory.append((experience))

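    # Assumption about the PER Memory: store() is expected to insert new experiences with
    # the current maximum priority, so every transition is sampled at least once before its
    # priority is corrected by batch_update() in replay().
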
    def act(self, state, decay_step):
        # EPSILON GREEDY STRATEGY
        if self.epsilon_greedy:
            # Here we'll use an improved version of our epsilon greedy strategy for Q-learning
            explore_probability = self.epsilon_min + (self.epsilon - self.epsilon_min) * np.exp(-self.epsilon_decay * decay_step)
        # OLD EPSILON STRATEGY
        else:
            if self.epsilon > self.epsilon_min:
                self.epsilon *= (1-self.epsilon_decay)
            explore_probability = self.epsilon

        if explore_probability > np.random.rand():
            # Make a random action (exploration)
            return random.randrange(self.action_size), explore_probability
        else:
            # Get action from Q-network (exploitation)
            # Estimate the Q values of the current state
            # Take the biggest Q value (= the best action)
            return np.argmax(self.model.predict(state)), explore_probability

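    # Both schedules decay exploration from epsilon = 1.0 towards epsilon_min = 0.01:
    # the epsilon_greedy variant decays exponentially with the global decay_step, while
    # the default multiplies epsilon by (1 - epsilon_decay) after every action taken.
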
    def replay(self):
        if self.USE_PER:
            # Sample minibatch from the PER memory
            tree_idx, minibatch = self.MEMORY.sample(self.batch_size)
        else:
            # Randomly sample minibatch from the deque memory
            minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))

        state = np.zeros((self.batch_size,) + self.state_size)
        next_state = np.zeros((self.batch_size,) + self.state_size)
        action, reward, done = [], [], []

        # do this before prediction
        # for speedup, this could be done on the tensor level
        # but easier to understand using a loop
        for i in range(len(minibatch)):
            state[i] = minibatch[i][0]
            action.append(minibatch[i][1])
            reward.append(minibatch[i][2])
            next_state[i] = minibatch[i][3]
            done.append(minibatch[i][4])

        # do batch prediction for speed
        # predict Q-values for the starting state using the main network
        target = self.model.predict(state)
        target_old = np.array(target)
        # predict Q-values for the ending state using the main network (used for action selection)
        target_next = self.model.predict(next_state)
        # predict Q-values for the ending state using the target network (used for evaluation)
        target_val = self.target_model.predict(next_state)

        for i in range(len(minibatch)):
            # correction on the Q value for the action used
            if done[i]:
                target[i][action[i]] = reward[i]
            else:
                # the key point of Double DQN
                # selection of action is from the model
                # update is from the target model
                if self.ddqn: # Double - DQN
                    # current Q Network selects the action
                    # a'_max = argmax_a' Q(s', a')
                    a = np.argmax(target_next[i])
                    # target Q Network evaluates the action
                    # Q_max = Q_target(s', a'_max)
                    target[i][action[i]] = reward[i] + self.gamma * (target_val[i][a])
                else: # Standard - DQN
                    # DQN chooses the max Q value among next actions
                    # selection and evaluation of the action are both on the target Q Network
                    # Q_max = max_a' Q_target(s', a')
                    target[i][action[i]] = reward[i] + self.gamma * (np.amax(target_val[i]))

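        # Double DQN target used above: y = r + gamma * Q_target(s', argmax_a Q_online(s', a)).
        # Selecting the action with the online network but evaluating it with the target
        # network reduces the overestimation bias of standard DQN.
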
        if self.USE_PER:
            indices = np.arange(self.batch_size, dtype=np.int32)
            absolute_errors = np.abs(target_old[indices, np.array(action)]-target[indices, np.array(action)])
            # Update priority: the absolute TD error at the taken action becomes
            # the new priority of each sampled transition
            self.MEMORY.batch_update(tree_idx, absolute_errors)

        # Train the Neural Network with batches
        self.model.fit(state, target, batch_size=self.batch_size, verbose=0)

    def load(self, name):
        self.model = load_model(name)

    def save(self, name):
        self.model.save(name)

    pylab.figure(figsize=(18, 9))  # one shared figure, created when the class is defined
    def PlotModel(self, score, episode):
        self.scores.append(score)
        self.episodes.append(episode)
        self.average.append(sum(self.scores[-50:]) / len(self.scores[-50:]))
        pylab.plot(self.episodes, self.average, 'r')
        pylab.plot(self.episodes, self.scores, 'b')
        pylab.ylabel('Score', fontsize=18)
        pylab.xlabel('Episodes', fontsize=18)
        dqn = 'DQN_'
        softupdate = ''
        dueling = ''
        greedy = ''
        PER = ''
        if self.ddqn: dqn = 'DDQN_'
        if self.Soft_Update: softupdate = '_soft'
        if self.dueling: dueling = '_Dueling'
        if self.epsilon_greedy: greedy = '_Greedy'
        if self.USE_PER: PER = '_PER'
        try:
            pylab.savefig(dqn+self.env_name+softupdate+dueling+greedy+PER+"_CNN.png")
        except OSError:
            pass

        return str(self.average[-1])[:5]

    def imshow(self, image, rem_step=0):
        cv2.imshow("cartpole"+str(rem_step), image[rem_step,...])
        if cv2.waitKey(25) & 0xFF == ord("q"):
            cv2.destroyAllWindows()
            return

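    # GetImage() below renders an RGB frame, converts it to grayscale, resizes it to
    # COLS x ROWS pixels, zeroes every pixel that is not pure white, scales to [0, 1],
    # and rolls it into the front of the REM_STEP frame buffer so index 0 is always
    # the newest frame.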
    def GetImage(self):
        img = self.env.render(mode='rgb_array')

        img_rgb = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_rgb_resized = cv2.resize(img_rgb, (self.COLS, self.ROWS), interpolation=cv2.INTER_CUBIC)
        img_rgb_resized[img_rgb_resized < 255] = 0
        img_rgb_resized = img_rgb_resized / 255

        self.image_memory = np.roll(self.image_memory, 1, axis = 0)
        self.image_memory[0,:,:] = img_rgb_resized

        #self.imshow(self.image_memory,0)

        return np.expand_dims(self.image_memory, axis=0)

    def reset(self):
        self.env.reset()
        # fill the frame buffer with REM_STEP copies of the initial frame
        for i in range(self.REM_STEP):
            state = self.GetImage()
        return state

    def step(self, action):
        next_state, reward, done, info = self.env.step(action)
        next_state = self.GetImage()
        return next_state, reward, done, info

    def run(self):
        decay_step = 0
        for e in range(self.EPISODES):
            state = self.reset()
            done = False
            i = 0
            while not done:
                decay_step += 1
                action, explore_probability = self.act(state, decay_step)
                next_state, reward, done, _ = self.step(action)
                if not done or i == self.env._max_episode_steps-1:
                    reward = reward
                else:
                    # penalize the terminal step when the pole falls before the step limit
                    reward = -100
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1
                if done:
                    # every REM_STEP episodes, update the target model
                    if e % self.REM_STEP == 0:
                        self.update_target_model()

                    # every episode, plot the result
                    average = self.PlotModel(i, e)

                    print("episode: {}/{}, score: {}, e: {:.2}, average: {}".format(e, self.EPISODES, i, explore_probability, average))
                    if i == self.env._max_episode_steps:
                        print("Saving trained model to", self.Model_name)
                        #self.save(self.Model_name)
                        break
                self.replay()
        self.env.close()

    def test(self):
        self.load(self.Model_name)
        for e in range(self.EPISODES):
            state = self.reset()
            done = False
            i = 0
            while not done:
                action = np.argmax(self.model.predict(state))
                # use self.step() so the stacked-frame state is refreshed each step
                next_state, reward, done, _ = self.step(action)
                state = next_state
                i += 1
                if done:
                    print("episode: {}/{}, score: {}".format(e, self.EPISODES, i))
                    break

if __name__ == "__main__":
    env_name = 'CartPole-v1'
    agent = DQNAgent(env_name)
    agent.run()
    #agent.test()
