
Commit 1ef7605: ppo in tensorflow 2
1 parent a484628
File tree: 5 files changed, 248 additions, 0 deletions

5 files changed

+248
-0
lines changed
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.optimizers import Adam
import tensorflow_probability as tfp
from memory import PPOMemory
from networks import ActorNetwork, CriticNetwork


class Agent:
    def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003,
                 gae_lambda=0.95, policy_clip=0.2, batch_size=64,
                 n_epochs=10, chkpt_dir='models/'):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda
        self.chkpt_dir = chkpt_dir

        self.actor = ActorNetwork(n_actions)
        self.actor.compile(optimizer=Adam(learning_rate=alpha))
        self.critic = CriticNetwork()
        self.critic.compile(optimizer=Adam(learning_rate=alpha))
        self.memory = PPOMemory(batch_size)

    def store_transition(self, state, action, probs, vals, reward, done):
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self):
        print('... saving models ...')
        self.actor.save(self.chkpt_dir + 'actor')
        self.critic.save(self.chkpt_dir + 'critic')

    def load_models(self):
        print('... loading models ...')
        self.actor = keras.models.load_model(self.chkpt_dir + 'actor')
        self.critic = keras.models.load_model(self.chkpt_dir + 'critic')

    def choose_action(self, observation):
        state = tf.convert_to_tensor([observation])

        # The actor outputs softmax probabilities, so pass them via the
        # probs keyword; Categorical's first positional argument is logits.
        probs = self.actor(state)
        dist = tfp.distributions.Categorical(probs=probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        value = self.critic(state)

        action = action.numpy()[0]
        value = value.numpy()[0]
        log_prob = log_prob.numpy()[0]

        return action, log_prob, value

    def learn(self):
        for _ in range(self.n_epochs):
            state_arr, action_arr, old_prob_arr, vals_arr,\
                reward_arr, dones_arr, batches = \
                self.memory.generate_batches()

            values = vals_arr
            advantage = np.zeros(len(reward_arr), dtype=np.float32)

            # Generalized Advantage Estimation: for each time step, sum the
            # discounted TD errors over the remainder of the buffer.
            for t in range(len(reward_arr)-1):
                discount = 1
                a_t = 0
                for k in range(t, len(reward_arr)-1):
                    a_t += discount*(reward_arr[k] + self.gamma*values[k+1] * (
                        1-int(dones_arr[k])) - values[k])
                    discount *= self.gamma*self.gae_lambda
                advantage[t] = a_t

            for batch in batches:
                with tf.GradientTape(persistent=True) as tape:
                    states = tf.convert_to_tensor(state_arr[batch])
                    old_probs = tf.convert_to_tensor(old_prob_arr[batch])
                    actions = tf.convert_to_tensor(action_arr[batch])

                    probs = self.actor(states)
                    dist = tfp.distributions.Categorical(probs=probs)
                    new_probs = dist.log_prob(actions)

                    critic_value = self.critic(states)
                    critic_value = tf.squeeze(critic_value, 1)

                    # Clipped surrogate objective from the PPO paper.
                    prob_ratio = tf.math.exp(new_probs - old_probs)
                    weighted_probs = advantage[batch] * prob_ratio
                    clipped_probs = tf.clip_by_value(prob_ratio,
                                                     1-self.policy_clip,
                                                     1+self.policy_clip)
                    weighted_clipped_probs = clipped_probs * advantage[batch]
                    actor_loss = -tf.math.minimum(weighted_probs,
                                                  weighted_clipped_probs)
                    actor_loss = tf.math.reduce_mean(actor_loss)

                    returns = advantage[batch] + values[batch]
                    # critic_loss = tf.math.reduce_mean(tf.math.pow(
                    #       returns-critic_value, 2))
                    critic_loss = keras.losses.MSE(critic_value, returns)

                actor_params = self.actor.trainable_variables
                actor_grads = tape.gradient(actor_loss, actor_params)
                critic_params = self.critic.trainable_variables
                critic_grads = tape.gradient(critic_loss, critic_params)
                self.actor.optimizer.apply_gradients(
                        zip(actor_grads, actor_params))
                self.critic.optimizer.apply_gradients(
                        zip(critic_grads, critic_params))

        self.memory.clear_memory()
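
For reference, the nested loop in Agent.learn() computes the Generalized Advantage Estimate in O(T^2) time. Below is a minimal sketch of an equivalent single backward pass; the standalone helper name compute_gae is an assumption for illustration and is not part of the commit.

import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # Accumulate discounted TD errors from the end of the buffer backwards.
    # This reproduces the nested loop above: like that loop, only the
    # bootstrap term is masked at terminal steps.
    advantage = np.zeros(len(rewards), dtype=np.float32)
    for t in reversed(range(len(rewards) - 1)):
        delta = rewards[t] + gamma * values[t + 1] * (1 - int(dones[t])) - values[t]
        advantage[t] = delta + gamma * gae_lambda * advantage[t + 1]
    return advantage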
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@

import gym
import numpy as np
from agent import Agent
from utils import plot_learning_curve

if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    N = 20              # run a learning update every N environment steps
    batch_size = 5
    n_epochs = 4
    alpha = 0.0003
    agent = Agent(n_actions=env.action_space.n, batch_size=batch_size,
                  alpha=alpha, n_epochs=n_epochs,
                  input_dims=env.observation_space.shape)
    n_games = 300

    figure_file = 'plots/cartpole.png'

    best_score = env.reward_range[0]
    score_history = []

    learn_iters = 0
    avg_score = 0
    n_steps = 0

    for i in range(n_games):
        observation = env.reset()
        done = False
        score = 0
        while not done:
            action, prob, val = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            n_steps += 1
            score += reward
            agent.store_transition(observation, action,
                                   prob, val, reward, done)
            if n_steps % N == 0:
                agent.learn()
                learn_iters += 1
            observation = observation_
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])

        if avg_score > best_score:
            best_score = avg_score
            agent.save_models()

        print('episode', i, 'score %.1f' % score, 'avg score %.1f' % avg_score,
              'time_steps', n_steps, 'learning_steps', learn_iters)
    x = [i+1 for i in range(len(score_history))]
    plot_learning_curve(x, score_history, figure_file)
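
The training script writes checkpoints under models/ and the learning curve to plots/cartpole.png, and matplotlib's savefig does not create missing directories. A small setup step like the sketch below avoids a FileNotFoundError; the os.makedirs calls are an assumption about the intended layout, not part of the commit.

import os

# Create the output directories the training script expects.
os.makedirs('models', exist_ok=True)   # actor/critic checkpoints
os.makedirs('plots', exist_ok=True)    # cartpole.png learning curve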
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@

import numpy as np


class PPOMemory:
    def __init__(self, batch_size):
        self.states = []
        self.probs = []
        self.vals = []
        self.actions = []
        self.rewards = []
        self.dones = []

        self.batch_size = batch_size

    def generate_batches(self):
        n_states = len(self.states)
        batch_start = np.arange(0, n_states, self.batch_size)
        indices = np.arange(n_states, dtype=np.int64)
        np.random.shuffle(indices)
        batches = [indices[i:i+self.batch_size] for i in batch_start]

        return np.array(self.states),\
            np.array(self.actions),\
            np.array(self.probs),\
            np.array(self.vals),\
            np.array(self.rewards),\
            np.array(self.dones),\
            batches

    def store_memory(self, state, action, probs, vals, reward, done):
        self.states.append(state)
        self.actions.append(action)
        self.probs.append(probs)
        self.vals.append(vals)
        self.rewards.append(reward)
        self.dones.append(done)

    def clear_memory(self):
        self.states = []
        self.probs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.vals = []
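
A quick usage sketch of PPOMemory: store a few transitions, then draw shuffled mini-batch index sets. The dummy transition values and the 4-dimensional state are illustrative assumptions, not taken from the commit.

import numpy as np

memory = PPOMemory(batch_size=5)
for step in range(20):
    # state, action, log_prob, value, reward, done -- dummy values
    memory.store_memory(np.random.rand(4), 0, -0.69, 0.5, 1.0, False)

states, actions, probs, vals, rewards, dones, batches = memory.generate_batches()
print(states.shape)   # (20, 4)
print(len(batches))   # 4 arrays of 5 shuffled indices each
memory.clear_memory()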
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@

import tensorflow.keras as keras
from tensorflow.keras.layers import Dense


class ActorNetwork(keras.Model):
    def __init__(self, n_actions, fc1_dims=256, fc2_dims=256):
        super(ActorNetwork, self).__init__()

        self.fc1 = Dense(fc1_dims, activation='relu')
        self.fc2 = Dense(fc2_dims, activation='relu')
        self.fc3 = Dense(n_actions, activation='softmax')

    def call(self, state):
        x = self.fc1(state)
        x = self.fc2(x)
        x = self.fc3(x)

        return x


class CriticNetwork(keras.Model):
    def __init__(self, fc1_dims=256, fc2_dims=256):
        super(CriticNetwork, self).__init__()
        self.fc1 = Dense(fc1_dims, activation='relu')
        self.fc2 = Dense(fc2_dims, activation='relu')
        self.q = Dense(1, activation=None)

    def call(self, state):
        x = self.fc1(state)
        x = self.fc2(x)
        q = self.q(x)

        return q
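
As a sanity check, both networks can be called directly on a dummy batch. The 4-dimensional observation matches CartPole, but the random input itself is an illustrative assumption.

import tensorflow as tf

actor = ActorNetwork(n_actions=2)
critic = CriticNetwork()

state = tf.random.uniform((1, 4))   # batch of one CartPole-sized observation
action_probs = actor(state)         # shape (1, 2), rows sum to 1 (softmax)
state_value = critic(state)         # shape (1, 1), unbounded value estimate
print(action_probs.numpy(), state_value.numpy())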
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@

import numpy as np
import matplotlib.pyplot as plt

def plot_learning_curve(x, scores, figure_file):
    running_avg = np.zeros(len(scores))
    for i in range(len(running_avg)):
        running_avg[i] = np.mean(scores[max(0, i-100):(i+1)])
    plt.plot(x, running_avg)
    plt.title('Running average of previous 100 scores')
    plt.savefig(figure_file)
