 from torch.nn import MSELoss
 from torch.optim import Adam
 
-from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model
+from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model, create_cnn_model, get_next_conv_shape, get_conv, \
+    Flatten
 from fast_rl.core.Learner import AgentLearner
 from fast_rl.core.MarkovDecisionProcess import MDPDataBunch
 from fast_rl.core.agent_core import GreedyEpsilon, ExperienceReplay
@@ -27,6 +28,8 @@ def on_train_begin(self, n_epochs, **kwargs: Any):
 
     def on_epoch_begin(self, epoch, **kwargs: Any):
         self.episode = epoch
+        # if self.learn.model.training and self.iteration != 0:
+        #     self.learn.model.memory.update(item=self.learn.data.x.items[-1])
         self.iteration = 0
 
     def on_loss_begin(self, **kwargs: Any):
@@ -47,7 +50,7 @@ def on_loss_begin(self, **kwargs: Any):
         # self.learn.model.target_copy_over()
 
 
-class Critic(nn.Module):
+class NNCritic(nn.Module):
     def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_embed=True,
                  activation_function=None):
         super().__init__()
@@ -59,7 +62,7 @@ def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_
         self.fc3 = nn.Linear(layer_list[1], 1)
 
     def forward(self, x):
-        action, x = x[:, self.state_size:], x[:, :self.state_size]
+        x, action = x
 
         x = nn.LeakyReLU()(self.fc1(x))
         x = nn.LeakyReLU()(self.fc2(torch.cat((x, action), 1)))
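
Both critics now receive a (state, action) tuple and fuse the action into the state features internally, matching the updated calls in optimize further down. Below is a minimal, self-contained sketch of this calling convention; the class name TupleCritic, the layer sizes, and the tensor shapes are illustrative stand-ins, not the library's actual defaults.

import torch
import torch.nn as nn


class TupleCritic(nn.Module):
    """Toy stand-in for NNCritic: estimates Q(s, a) from a (state, action) pair."""

    def __init__(self, state_size=8, action_size=2, hidden=64):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden)
        self.fc2 = nn.Linear(hidden + action_size, hidden)
        self.fc3 = nn.Linear(hidden, 1)

    def forward(self, x):
        x, action = x                                             # unpack the (state, action) tuple
        x = nn.LeakyReLU()(self.fc1(x))                           # state features first
        x = nn.LeakyReLU()(self.fc2(torch.cat((x, action), 1)))   # then concatenate the action
        return self.fc3(x)


critic = TupleCritic()
q = critic((torch.randn(64, 8), torch.randn(64, 2)))  # -> Q values of shape (64, 1)
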
@@ -68,17 +71,41 @@ def forward(self, x):
         return x
 
 
+class CNNCritic(nn.Module):
+    def __init__(self, layer_list: list, action_size, state_size, activation_function=None):
+        super().__init__()
+        self.action_size = action_size[0]
+        self.state_size = state_size[0]
+
+        layers = []
+        layers, input_size = get_conv(self.state_size, nn.LeakyReLU(), 8, 2, 3, layers)
+        layers += [Flatten()]
+        self.conv_layers = nn.Sequential(*layers)
+
+        self.fc1 = nn.Linear(input_size + self.action_size, 200)
+        self.fc2 = nn.Linear(200, 1)
+
+    def forward(self, x):
+        x, action = x
+
+        x = nn.LeakyReLU()(self.conv_layers(x))
+        x = nn.LeakyReLU()(self.fc1(torch.cat((x, action), 1)))
+        x = nn.LeakyReLU()(self.fc2(x))
+
+        return x
+
+
 class DDPG(BaseAgent):
 
     def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount=0.99,
-                 lr=1e-3, actor_lr=1e-4, exploration_strategy=None, env_was_discrete=False):
+                 lr=1e-3, actor_lr=1e-4, exploration_strategy=None):
         """
         Implementation of a continuous control algorithm using an actor/critic architecture.
 
         Notes:
             Uses 4 networks, 2 actors, 2 critics.
             All models use batch norm for feature invariance.
-            Critic simply predicts Q while the Actor proposes the actions to take given a state s.
+            NNCritic simply predicts Q while the Actor proposes the actions to take given a state s.
 
         References:
             [1] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning."
@@ -93,7 +120,6 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
             lr: Rate that the opt will learn parameter gradients.
         """
         super().__init__(data)
-        self.env_was_discrete = env_was_discrete
         self.name = 'DDPG'
         self.lr = lr
         self.discount = discount
@@ -122,21 +148,30 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
                                                       do_exploration=self.training))
 
     def initialize_action_model(self, layers, data):
-        return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
-                               final_activation_function=nn.Tanh)
+        actions, state = data.get_action_state_size()
+        if type(state[0]) is tuple and len(state[0]) == 3:
+            # actions, state = actions[0], state[0]
+            # If the shape has 3 dimensions, we will try using cnn's instead.
+            return create_cnn_model([200, 200], actions, state, False, kernel_size=8,
+                                    final_activation_function=nn.Tanh, action_val_to_dim=False)
+        else:
+            return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
+                                   final_activation_function=nn.Tanh, action_val_to_dim=False)
 
     def initialize_critic_model(self, layers, data):
         """ Instead of state -> action, we are going state + action -> single expected reward. """
-        return Critic(layers, *data.get_action_state_size())
+        actions, state = data.get_action_state_size()
+        if type(state[0]) is tuple and len(state[0]) == 3:
+            return CNNCritic(layers, *data.get_action_state_size())
+        else:
+            return NNCritic(layers, *data.get_action_state_size())
 
     def pick_action(self, x):
         if self.training: self.action_model.eval()
         with torch.no_grad():
-            action, x = super(DDPG, self).pick_action(x)
+            action = super(DDPG, self).pick_action(x)
         if self.training: self.action_model.train()
-
-        if not self.env_was_discrete: action = np.clip(action, -1, 1)
-        return action, np.clip(x, -1, 1)
+        return np.clip(action, -1, 1)
 
     def optimize(self):
         """
@@ -160,12 +195,11 @@ def optimize(self):
         s_prime = torch.from_numpy(np.array([item.result_state for item in sampled])).float()
         s = torch.from_numpy(np.array([item.current_state for item in sampled])).float()
         a = torch.from_numpy(np.array([item.actions for item in sampled]).astype(float)).float()
-        if self.env_was_discrete: a = torch.from_numpy(np.array([item.raw_action for item in sampled]).astype(float)).float()
 
         with torch.no_grad():
-            y = r + self.discount * self.t_critic_model(torch.cat((s_prime, self.t_action_model(s_prime)), 1))
+            y = r + self.discount * self.t_critic_model((s_prime, self.t_action_model(s_prime)))
 
-        y_hat = self.critic_model(torch.cat((s, a), 1))
+        y_hat = self.critic_model((s, a))
 
         critic_loss = self.loss_func(y_hat, y)
 
@@ -175,7 +209,7 @@ def optimize(self):
         critic_loss.backward()
         self.critic_optimizer.step()
 
-        actor_loss = -self.critic_model(torch.cat((s, self.action_model(s)), 1)).mean()
+        actor_loss = -self.critic_model((s, self.action_model(s))).mean()
 
         self.loss = critic_loss.cpu().detach()
 
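
With the tuple interface, optimize computes the standard DDPG critic target y = r + discount * Q_target(s', actor_target(s')). The soft update of the target networks controlled by tau is not part of this diff (target_copy_over is referenced but commented out earlier), so the helper below is only an assumed Polyak-averaging sketch in the spirit of the referenced Lillicrap et al. paper, not the repository's actual implementation.

import torch


def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 1e-3):
    # Assumed helper, not taken from fast_rl: Polyak averaging of target parameters,
    # theta_target <- tau * theta_source + (1 - tau) * theta_target.
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.copy_(tau * s_param + (1.0 - tau) * t_param)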