 from tensorflow.keras.optimizers import Adam
 import os
 
-class ReplayBuffer():
+
+class ReplayBuffer:
     def __init__(self, max_size, input_shape, n_actions):
         self.mem_size = max_size
         self.mem_cntr = 0
@@ -60,8 +61,10 @@ def call(self, state, action):
 
         return q
 
+
 class ActorNetwork(keras.Model):
-    def __init__(self, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/td3'):
+    def __init__(self, fc1_dims, fc2_dims, n_actions, name,
+                 chkpt_dir='tmp/td3'):
         super(ActorNetwork, self).__init__()
         self.fc1_dims = fc1_dims
         self.fc2_dims = fc2_dims
@@ -74,7 +77,6 @@ def __init__(self, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/td3'):
         self.fc2 = Dense(self.fc2_dims, activation='relu')
         self.mu = Dense(self.n_actions, activation='tanh')
 
-
     def call(self, state):
         prob = self.fc1(state)
         prob = self.fc2(prob)
@@ -83,11 +85,12 @@ def call(self, state):
 
         return mu
 
-class Agent():
+
+class Agent:
     def __init__(self, alpha, beta, input_dims, tau, env,
-            gamma=0.99, update_actor_interval=2, warmup=1000,
-            n_actions=2, max_size=1000000, layer1_size=400,
-            layer2_size=300, batch_size=100, noise=0.1):
+                 gamma=0.99, update_actor_interval=2, warmup=1000,
+                 n_actions=2, max_size=1000000, layer1_size=400,
+                 layer2_size=300, batch_size=100, noise=0.1):
         self.gamma = gamma
         self.tau = tau
         self.max_action = env.action_space.high[0]
@@ -100,33 +103,34 @@ def __init__(self, alpha, beta, input_dims, tau, env,
         self.n_actions = n_actions
         self.update_actor_iter = update_actor_interval
 
-        self.actor = ActorNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='actor')
+        self.actor = ActorNetwork(layer1_size, layer2_size,
+                                  n_actions=n_actions, name='actor')
 
-        self.critic_1 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='critic_1')
+        self.critic_1 = CriticNetwork(layer1_size, layer2_size,
+                                      name='critic_1')
         self.critic_2 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='critic_2')
+                                      name='critic_2')
 
-        self.target_actor = ActorNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_actor')
-        self.target_critic_1 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_critic_1')
-        self.target_critic_2 = CriticNetwork(layer1_size, layer2_size,
-                n_actions=n_actions, name='target_critic_2')
+        self.target_actor = ActorNetwork(layer1_size, layer2_size,
+                                         n_actions=n_actions,
+                                         name='target_actor')
+        self.target_critic_1 = CriticNetwork(layer1_size, layer2_size,
+                                             name='target_critic_1')
+        self.target_critic_2 = CriticNetwork(layer1_size, layer2_size,
+                                             name='target_critic_2')
 
         self.actor.compile(optimizer=Adam(learning_rate=alpha), loss='mean')
-        self.critic_1.compile(optimizer=Adam(learning_rate=beta),
+        self.critic_1.compile(optimizer=Adam(learning_rate=beta),
                 loss='mean_squared_error')
-        self.critic_2.compile(optimizer=Adam(learning_rate=beta),
+        self.critic_2.compile(optimizer=Adam(learning_rate=beta),
                 loss='mean_squared_error')
 
-        self.target_actor.compile(optimizer=Adam(learning_rate=alpha),
+        self.target_actor.compile(optimizer=Adam(learning_rate=alpha),
                 loss='mean')
-        self.target_critic_1.compile(optimizer=Adam(learning_rate=beta),
-                loss='mean_squared_error')
-        self.target_critic_2.compile(optimizer=Adam(learning_rate=beta),
-                loss='mean_squared_error')
+        self.target_critic_1.compile(optimizer=Adam(learning_rate=beta),
+                                     loss='mean_squared_error')
+        self.target_critic_2.compile(optimizer=Adam(learning_rate=beta),
+                                     loss='mean_squared_error')
 
         self.noise = noise
         self.update_network_parameters(tau=1)
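Note (not part of the commit): the constructor ends with update_network_parameters(tau=1), which hard-copies the online weights into the three target networks; the method itself lies outside the hunks shown here. A minimal sketch of the Polyak soft update such a method is assumed to perform, using the standard Keras get_weights/set_weights API:

def update_network_parameters(self, tau=None):
    # Default to the constructor's tau; tau=1 copies the online weights outright.
    if tau is None:
        tau = self.tau

    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    pairs = [(self.target_actor, self.actor),
             (self.target_critic_1, self.critic_1),
             (self.target_critic_2, self.critic_2)]
    for target_net, online_net in pairs:
        weights = [tau * w + (1 - tau) * w_t
                   for w, w_t in zip(online_net.get_weights(),
                                     target_net.get_weights())]
        target_net.set_weights(weights)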
@@ -136,7 +140,8 @@ def choose_action(self, observation):
             mu = np.random.normal(scale=self.noise, size=(self.n_actions,))
         else:
             state = tf.convert_to_tensor([observation], dtype=tf.float32)
-            mu = self.actor(state)[0]  # returns a batch size of 1, want a scalar array
+            # returns a batch size of 1, want a scalar array
+            mu = self.actor(state)[0]
         mu_prime = mu + np.random.normal(scale=self.noise)
 
         mu_prime = tf.clip_by_value(mu_prime, self.min_action, self.max_action)
@@ -149,10 +154,10 @@ def remember(self, state, action, reward, new_state, done):
 
     def learn(self):
         if self.memory.mem_cntr < self.batch_size:
-            return
+            return
 
         states, actions, rewards, new_states, dones = \
-                self.memory.sample_buffer(self.batch_size)
+            self.memory.sample_buffer(self.batch_size)
 
         states = tf.convert_to_tensor(states, dtype=tf.float32)
         actions = tf.convert_to_tensor(actions, dtype=tf.float32)
@@ -162,11 +167,11 @@ def learn(self):
         with tf.GradientTape(persistent=True) as tape:
             target_actions = self.target_actor(states_)
             target_actions = target_actions + \
-                    tf.clip_by_value(np.random.normal(scale=0.2), -0.5, 0.5)
+                tf.clip_by_value(np.random.normal(scale=0.2), -0.5, 0.5)
+
+            target_actions = tf.clip_by_value(target_actions, self.min_action,
+                                              self.max_action)
 
-            target_actions = tf.clip_by_value(target_actions, self.min_action,
-                    self.max_action)
-
             q1_ = self.target_critic_1(states_, target_actions)
             q2_ = self.target_critic_2(states_, target_actions)
 
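Side note, not part of the commit: TD3's target-policy smoothing samples clipped noise per action dimension, whereas the line above draws a single scalar sample from np.random.normal that is broadcast across all actions. A minimal sketch of the per-dimension variant, keeping the same sigma=0.2 and clip bound of 0.5 used above:

# Per-dimension smoothing noise; the commit's version uses one scalar sample instead.
noise = tf.clip_by_value(tf.random.normal(shape=tf.shape(target_actions),
                                          stddev=0.2), -0.5, 0.5)
target_actions = tf.clip_by_value(target_actions + noise,
                                  self.min_action, self.max_action)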
@@ -182,23 +187,19 @@ def learn(self):
             # and eager execution doesn't support assignment, so we can't do
             # q1_[dones] = 0.0
             target = rewards + self.gamma * critic_value_ * (1 - dones)
-            #critic_1_loss = tf.math.reduce_mean(tf.math.square(target - q1))
-            #critic_2_loss = tf.math.reduce_mean(tf.math.square(target - q2))
             critic_1_loss = keras.losses.MSE(target, q1)
             critic_2_loss = keras.losses.MSE(target, q2)
 
-
-        critic_1_gradient = tape.gradient(critic_1_loss,
+        critic_1_gradient = tape.gradient(critic_1_loss,
                                           self.critic_1.trainable_variables)
-        critic_2_gradient = tape.gradient(critic_2_loss,
+        critic_2_gradient = tape.gradient(critic_2_loss,
                                           self.critic_2.trainable_variables)
 
         self.critic_1.optimizer.apply_gradients(
-                zip(critic_1_gradient, self.critic_1.trainable_variables))
+            zip(critic_1_gradient, self.critic_1.trainable_variables))
         self.critic_2.optimizer.apply_gradients(
-                zip(critic_2_gradient, self.critic_2.trainable_variables))
+            zip(critic_2_gradient, self.critic_2.trainable_variables))
 
-
         self.learn_step_cntr += 1
 
         if self.learn_step_cntr % self.update_actor_iter != 0:
@@ -209,7 +210,8 @@ def learn(self):
             critic_1_value = self.critic_1(states, new_actions)
             actor_loss = -tf.math.reduce_mean(critic_1_value)
 
-        actor_gradient = tape.gradient(actor_loss, self.actor.trainable_variables)
+        actor_gradient = tape.gradient(actor_loss,
+                                       self.actor.trainable_variables)
         self.actor.optimizer.apply_gradients(
             zip(actor_gradient, self.actor.trainable_variables))
 
@@ -250,13 +252,10 @@ def save_models(self):
         self.target_critic_2.save_weights(self.target_critic_2.checkpoint_file)
 
     def load_models(self):
-
         print('... loading models ...')
         self.actor.load_weights(self.actor.checkpoint_file)
         self.critic_1.load_weights(self.critic_1.checkpoint_file)
         self.critic_2.load_weights(self.critic_2.checkpoint_file)
         self.target_actor.load_weights(self.target_actor.checkpoint_file)
         self.target_critic_1.load_weights(self.target_critic_1.checkpoint_file)
         self.target_critic_2.load_weights(self.target_critic_2.checkpoint_file)
-
-
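Usage sketch (illustration only, not part of the commit): a typical loop driving this Agent against a continuous-action Gym environment, using the pre-0.26 gym reset/step API to match the era of this code. The environment name, learning rates, tau, and episode count are placeholders, not values taken from this repository:

import gym

env = gym.make('BipedalWalker-v3')  # placeholder environment
agent = Agent(alpha=0.001, beta=0.001,
              input_dims=env.observation_space.shape, tau=0.005, env=env,
              n_actions=env.action_space.shape[0])

for episode in range(1000):  # placeholder episode count
    observation = env.reset()
    done = False
    score = 0
    while not done:
        action = agent.choose_action(observation)
        observation_, reward, done, info = env.step(action)
        agent.remember(observation, action, reward, observation_, done)
        agent.learn()
        score += reward
        observation = observation_
    print('episode', episode, 'score %.1f' % score)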