
Commit 01ce5cb

fix error in critic constructor call
1 parent aad48f1 commit 01ce5cb

File tree

  • ReinforcementLearning/PolicyGradient/TD3/tf2

1 file changed (+43, -44 lines)

ReinforcementLearning/PolicyGradient/TD3/tf2/td3_tf2.py

Lines changed: 43 additions & 44 deletions
@@ -5,7 +5,8 @@
 from tensorflow.keras.optimizers import Adam
 import os
 
-class ReplayBuffer():
+
+class ReplayBuffer:
     def __init__(self, max_size, input_shape, n_actions):
         self.mem_size = max_size
         self.mem_cntr = 0
@@ -60,8 +61,10 @@ def call(self, state, action):
 
         return q
 
+
 class ActorNetwork(keras.Model):
-    def __init__(self, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/td3'):
+    def __init__(self, fc1_dims, fc2_dims, n_actions, name,
+                 chkpt_dir='tmp/td3'):
         super(ActorNetwork, self).__init__()
         self.fc1_dims = fc1_dims
         self.fc2_dims = fc2_dims
@@ -74,7 +77,6 @@ def __init__(self, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/td3'):
         self.fc2 = Dense(self.fc2_dims, activation='relu')
         self.mu = Dense(self.n_actions, activation='tanh')
 
-
     def call(self, state):
         prob = self.fc1(state)
         prob = self.fc2(prob)
@@ -83,11 +85,12 @@ def call(self, state):
 
         return mu
 
-class Agent():
+
+class Agent:
     def __init__(self, alpha, beta, input_dims, tau, env,
-            gamma=0.99, update_actor_interval=2, warmup=1000,
-            n_actions=2, max_size=1000000, layer1_size=400,
-            layer2_size=300, batch_size=100, noise=0.1):
+                 gamma=0.99, update_actor_interval=2, warmup=1000,
+                 n_actions=2, max_size=1000000, layer1_size=400,
+                 layer2_size=300, batch_size=100, noise=0.1):
         self.gamma = gamma
         self.tau = tau
         self.max_action = env.action_space.high[0]
@@ -100,33 +103,34 @@ def __init__(self, alpha, beta, input_dims, tau, env,
         self.n_actions = n_actions
         self.update_actor_iter = update_actor_interval
 
-        self.actor = ActorNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='actor')
+        self.actor = ActorNetwork(layer1_size, layer2_size,
+                                  n_actions=n_actions, name='actor')
 
-        self.critic_1 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='critic_1')
+        self.critic_1 = CriticNetwork(layer1_size, layer2_size,
+                                      name='critic_1')
         self.critic_2 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='critic_2')
+                                      name='critic_2')
 
-        self.target_actor = ActorNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_actor')
-        self.target_critic_1 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_critic_1')
-        self.target_critic_2 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_critic_2')
+        self.target_actor = ActorNetwork(layer1_size, layer2_size,
+                                         n_actions=n_actions,
+                                         name='target_actor')
+        self.target_critic_1 = CriticNetwork(layer1_size, layer2_size,
+                                             name='target_critic_1')
+        self.target_critic_2 = CriticNetwork(layer1_size, layer2_size,
+                                             name='target_critic_2')
 
         self.actor.compile(optimizer=Adam(learning_rate=alpha), loss='mean')
-        self.critic_1.compile(optimizer=Adam(learning_rate=beta),
+        self.critic_1.compile(optimizer=Adam(learning_rate=beta),
                               loss='mean_squared_error')
-        self.critic_2.compile(optimizer=Adam(learning_rate=beta),
+        self.critic_2.compile(optimizer=Adam(learning_rate=beta),
                               loss='mean_squared_error')
 
-        self.target_actor.compile(optimizer=Adam(learning_rate=alpha),
+        self.target_actor.compile(optimizer=Adam(learning_rate=alpha),
                                   loss='mean')
-        self.target_critic_1.compile(optimizer=Adam(learning_rate=beta),
-                loss='mean_squared_error')
-        self.target_critic_2.compile(optimizer=Adam(learning_rate=beta),
-                loss='mean_squared_error')
+        self.target_critic_1.compile(optimizer=Adam(learning_rate=beta),
+                                     loss='mean_squared_error')
+        self.target_critic_2.compile(optimizer=Adam(learning_rate=beta),
+                                     loss='mean_squared_error')
 
         self.noise = noise
         self.update_network_parameters(tau=1)
@@ -136,7 +140,8 @@ def choose_action(self, observation):
             mu = np.random.normal(scale=self.noise, size=(self.n_actions,))
         else:
             state = tf.convert_to_tensor([observation], dtype=tf.float32)
-            mu = self.actor(state)[0]  # returns a batch size of 1, want a scalar array
+            # returns a batch size of 1, want a scalar array
+            mu = self.actor(state)[0]
         mu_prime = mu + np.random.normal(scale=self.noise)
 
         mu_prime = tf.clip_by_value(mu_prime, self.min_action, self.max_action)
@@ -149,10 +154,10 @@ def remember(self, state, action, reward, new_state, done):
 
     def learn(self):
         if self.memory.mem_cntr < self.batch_size:
-           return
+            return
 
         states, actions, rewards, new_states, dones = \
-                self.memory.sample_buffer(self.batch_size)
+            self.memory.sample_buffer(self.batch_size)
 
         states = tf.convert_to_tensor(states, dtype=tf.float32)
         actions = tf.convert_to_tensor(actions, dtype=tf.float32)
@@ -162,11 +167,11 @@ def learn(self):
         with tf.GradientTape(persistent=True) as tape:
             target_actions = self.target_actor(states_)
             target_actions = target_actions + \
-                    tf.clip_by_value(np.random.normal(scale=0.2), -0.5, 0.5)
+                tf.clip_by_value(np.random.normal(scale=0.2), -0.5, 0.5)
+
+            target_actions = tf.clip_by_value(target_actions, self.min_action,
+                                              self.max_action)
 
-            target_actions = tf.clip_by_value(target_actions, self.min_action,
-                    self.max_action)
-
             q1_ = self.target_critic_1(states_, target_actions)
             q2_ = self.target_critic_2(states_, target_actions)
 
@@ -182,23 +187,19 @@ def learn(self):
             # and eager exection doesn't support assignment, so we can't do
             # q1_[dones] = 0.0
             target = rewards + self.gamma*critic_value_*(1-dones)
-            #critic_1_loss = tf.math.reduce_mean(tf.math.square(target - q1))
-            #critic_2_loss = tf.math.reduce_mean(tf.math.square(target - q2))
             critic_1_loss = keras.losses.MSE(target, q1)
             critic_2_loss = keras.losses.MSE(target, q2)
 
-
-        critic_1_gradient = tape.gradient(critic_1_loss,
+        critic_1_gradient = tape.gradient(critic_1_loss,
                                           self.critic_1.trainable_variables)
-        critic_2_gradient = tape.gradient(critic_2_loss,
+        critic_2_gradient = tape.gradient(critic_2_loss,
                                           self.critic_2.trainable_variables)
 
         self.critic_1.optimizer.apply_gradients(
-                zip(critic_1_gradient, self.critic_1.trainable_variables))
+            zip(critic_1_gradient, self.critic_1.trainable_variables))
         self.critic_2.optimizer.apply_gradients(
-                zip(critic_2_gradient, self.critic_2.trainable_variables))
+            zip(critic_2_gradient, self.critic_2.trainable_variables))
 
-
         self.learn_step_cntr += 1
 
         if self.learn_step_cntr % self.update_actor_iter != 0:
@@ -209,7 +210,8 @@ def learn(self):
             critic_1_value = self.critic_1(states, new_actions)
             actor_loss = -tf.math.reduce_mean(critic_1_value)
 
-        actor_gradient = tape.gradient(actor_loss, self.actor.trainable_variables)
+        actor_gradient = tape.gradient(actor_loss,
+                                       self.actor.trainable_variables)
         self.actor.optimizer.apply_gradients(
             zip(actor_gradient, self.actor.trainable_variables))
 
@@ -250,13 +252,10 @@ def save_models(self):
         self.target_critic_2.save_weights(self.target_critic_2.checkpoint_file)
 
     def load_models(self):
-
        print('... loading models ...')
        self.actor.load_weights(self.actor.checkpoint_file)
        self.critic_1.load_weights(self.critic_1.checkpoint_file)
        self.critic_2.load_weights(self.critic_2.checkpoint_file)
        self.target_actor.load_weights(self.target_actor.checkpoint_file)
        self.target_critic_1.load_weights(self.target_critic_1.checkpoint_file)
        self.target_critic_2.load_weights(self.target_critic_2.checkpoint_file)
-
-
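Note on the fix itself: the substantive change is in the Agent constructor, where the critic networks were being built with an n_actions keyword that CriticNetwork does not accept (only ActorNetwork keeps it) — the constructor-call error named in the commit message. Below is a minimal sketch of the mismatch, assuming a CriticNetwork signature inferred from the "def call(self, state, action)" context above; its exact parameter list is not shown in this diff.

# Hedged sketch: this CriticNetwork definition is an assumption, not the
# file's actual code. The critic's input width comes from the concatenated
# (state, action) pair at call time, so it needs no n_actions argument.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense


class CriticNetwork(keras.Model):
    def __init__(self, fc1_dims, fc2_dims, name, chkpt_dir='tmp/td3'):
        super(CriticNetwork, self).__init__()
        self.fc1 = Dense(fc1_dims, activation='relu')
        self.fc2 = Dense(fc2_dims, activation='relu')
        self.q = Dense(1, activation=None)
        self.model_name = name

    def call(self, state, action):
        # score the concatenated (state, action) pair
        q = self.fc1(tf.concat([state, action], axis=1))
        q = self.fc2(q)
        return self.q(q)


# Before the fix the Agent did the equivalent of
#     CriticNetwork(400, 300, n_actions=2, name='critic_1')
# which raises TypeError: __init__() got an unexpected keyword argument 'n_actions'.
# After the fix:
critic_1 = CriticNetwork(400, 300, name='critic_1')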

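For orientation, here is a minimal usage sketch of the Agent interface visible in this diff (choose_action, remember, learn), assuming the classic Gym step/reset API; the environment id and hyperparameter values are illustrative assumptions, not part of the commit.

# Hedged usage sketch; 'Pendulum-v1' and the hyperparameters are assumptions.
import gym

env = gym.make('Pendulum-v1')
agent = Agent(alpha=0.001, beta=0.001,
              input_dims=env.observation_space.shape, tau=0.005, env=env,
              batch_size=100, layer1_size=400, layer2_size=300,
              n_actions=env.action_space.shape[0])

observation = env.reset()
action = agent.choose_action(observation)
observation_, reward, done, info = env.step(action)
agent.remember(observation, action, reward, observation_, done)
agent.learn()  # returns early until the replay buffer holds batch_size samples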