+from copy import deepcopy
+
 import torch
 from fastai.basic_train import LearnerCallback, Any, OptimWrapper, ifnone, F
 import numpy as np
 from fastai.metrics import RMSE
+from torch import nn
 from torch.nn import MSELoss
 from torch.optim import Adam
 
@@ -30,7 +33,7 @@ def on_loss_begin(self, **kwargs: Any):
3033 """Performs memory updates, exploration updates, and model optimization."""
3134 if self .learn .model .training :
3235 self .learn .model .memory .update (item = self .learn .data .x .items [- 1 ])
33- self .learn .model .exploration_strategy .update (self .episode , self .max_episodes ,
36+ self .learn .model .exploration_strategy .update (episode = self .episode , max_episodes = self .max_episodes ,
3437 do_exploration = self .learn .model .training )
3538 post_optimize = self .learn .model .optimize ()
3639 if self .learn .model .training :
@@ -44,10 +47,31 @@ def on_loss_begin(self, **kwargs: Any):
             # self.learn.model.target_copy_over()
 
 
+class Critic(nn.Module):
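+    """Q(s, a) critic: takes a concatenated (state, action) vector and outputs a single expected return."""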
+    def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_embed=True,
+                 activation_function=None):
+        super().__init__()
+        self.action_size = action_size[0]
+        self.state_size = state_size[0]
+
+        self.fc1 = nn.Linear(self.state_size, layer_list[0])
+        self.fc2 = nn.Linear(layer_list[0] + self.action_size, layer_list[1])
+        self.fc3 = nn.Linear(layer_list[1], 1)
+
+    def forward(self, x):
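+        # The input arrives as the concatenation [state, action]; split it back into its two parts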
+        action, x = x[:, self.state_size:], x[:, :self.state_size]
+
+        x = nn.LeakyReLU()(self.fc1(x))
+        x = nn.LeakyReLU()(self.fc2(torch.cat((x, action), 1)))
+        x = nn.LeakyReLU()(self.fc3(x))
+
+        return x
+
+
 class DDPG(BaseAgent):
 
-    def __init__(self, data: MDPDataBunch, memory=None, tau=0.001, batch=64, discount=0.99,
-                 lr=0.005, exploration_strategy=None, env_was_discrete=False):
+    def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount=0.99,
+                 lr=1e-3, actor_lr=1e-4, exploration_strategy=None, env_was_discrete=False):
         """
         Implementation of a continuous control algorithm using an actor/critic architecture.
 
@@ -74,42 +98,45 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=0.001, batch=64, discoun
         self.lr = lr
         self.discount = discount
         self.batch = batch
-        self.tao = tau
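+        # tau starts at 1 so the initial target_copy_over() below acts as a full hard copy; the real rate is restored afterwards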
+        self.tau = 1
         self.memory = ifnone(memory, ExperienceReplay(10000))
 
-        self.action_model = self.initialize_action_model([30, 30], data)
-        self.critic_model = self.initialize_critic_model([30, 30], data)
+        self.action_model = self.initialize_action_model([400, 300], data)
+        self.critic_model = self.initialize_critic_model([400, 300], data)
 
-        self.opt = OptimWrapper.create(Adam, lr=lr, layer_groups=[self.action_model])
+        self.opt = OptimWrapper.create(Adam, lr=actor_lr, layer_groups=[self.action_model])
         self.critic_optimizer = OptimWrapper.create(Adam, lr=lr, layer_groups=[self.critic_model])
 
-        self.t_action_model = self.initialize_action_model([30, 30], data)
-        self.t_critic_model = self.initialize_critic_model([30, 30], data)
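+        # Target networks begin as exact copies of the online actor and critic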
+        self.t_action_model = deepcopy(self.action_model)
+        self.t_critic_model = deepcopy(self.critic_model)
 
         self.target_copy_over()
+        self.tau = tau
 
         self.learner_callbacks = [BaseDDPGCallback]
 
-        self.loss_func = F.smooth_l1_loss  # MSELoss()
-        # TODO Move to Ornstein-Uhlenbeck process
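+        # Plain mean-squared error between predicted and target Q-values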
+        self.loss_func = MSELoss()
+
         self.exploration_strategy = ifnone(exploration_strategy, GreedyEpsilon(epsilon_start=1, epsilon_end=0.1,
                                                                                decay=0.001,
                                                                                do_exploration=self.training))
 
     def initialize_action_model(self, layers, data):
-        return create_nn_model(layers, *data.get_action_state_size(), True, use_embed=data.train_ds.embeddable)
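+        # The final tanh keeps continuous actions bounded to [-1, 1]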
+        return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
+                               final_activation_function=nn.Tanh)
 
     def initialize_critic_model(self, layers, data):
103129 """ Instead of state -> action, we are going state + action -> single expected reward. """
-        return create_nn_model(layers, (1, 0), (sum([_[0] for _ in data.get_action_state_size()]), 0), True,
-                               use_embed=data.train_ds.embeddable)
+        return Critic(layers, *data.get_action_state_size())
 
     def pick_action(self, x):
         if self.training: self.action_model.eval()
         with torch.no_grad():
-            action = super(DDPG, self).pick_action(x)
+            action, x = super(DDPG, self).pick_action(x)
         if self.training: self.action_model.train()
-        return action
+
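+        # Keep continuous actions inside the valid [-1, 1] range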
+        if not self.env_was_discrete: action = np.clip(action, -1, 1)
+        return action, np.clip(x, -1, 1)
 
     def optimize(self):
         """
@@ -140,16 +167,12 @@ def optimize(self):
 
         y_hat = self.critic_model(torch.cat((s, a), 1))
 
-        critic_loss = self.loss_func(y, y_hat)
-
-        print(f'{y[0][:15]}, {y_hat[0][:15]}')
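+        # MSELoss expects (input, target) ordering: prediction first, target second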
+        critic_loss = self.loss_func(y_hat, y)
 
         if self.training:
             # Optimize critic network
             self.critic_optimizer.zero_grad()
             critic_loss.backward()
-            for param in self.critic_model.parameters():
-                param.grad.data.clamp_(-1, 1)
             self.critic_optimizer.step()
 
         actor_loss = -self.critic_model(torch.cat((s, self.action_model(s)), 1)).mean()
@@ -160,8 +183,6 @@ def optimize(self):
             # Optimize actor network
             self.opt.zero_grad()
             actor_loss.backward()
-            for param in self.action_model.parameters():
-                param.grad.data.clamp_(-1, 1)
             self.opt.step()
 
         with torch.no_grad():
@@ -174,8 +195,8 @@ def forward(self, x):
 
     def target_copy_over(self):
176197 """ Soft target updates the actor and critic models.."""
-        self.soft_target_copy_over(self.t_action_model, self.action_model, self.tao)
-        self.soft_target_copy_over(self.t_critic_model, self.critic_model, self.tao)
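+        # Soft (Polyak-style) update: nudge each target network toward its online counterpart by a factor of tau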
+        self.soft_target_copy_over(self.t_action_model, self.action_model, self.tau)
+        self.soft_target_copy_over(self.t_critic_model, self.critic_model, self.tau)
 
     def soft_target_copy_over(self, t_m, f_m, tau):
         for target_param, local_param in zip(t_m.parameters(), f_m.parameters()):