-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import argparse
-import gym
-
+import parl
 import numpy as np
-
 import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-
-import parl
+from parl.utils.scheduler import LinearDecayScheduler


 class AtariAgent(parl.Agent):
29- """Base class of the Agent .
22+ """Agent of Atari env .
3023
3124 Args:
32- algorithm (object): Algorithm used by this agent.
33- args (argparse.Namespace): Model configurations.
34- device (torch.device): use cpu or gpu.
25+ algorithm (`parl.Algorithm`): algorithm to be used in this agent.
26+ act_dim (int): action space dimension
27+ total_step (int): total epsilon decay steps
28+ start_lr (float): initial learning rate
29+ update_target_step (int): target network update frequency
3530 """
3631
37- def __init__ (self , algorithm , act_dim ):
38- assert isinstance (act_dim , int )
39- super (AtariAgent , self ).__init__ (algorithm )
32+ def __init__ (self , algorithm , act_dim , start_lr , total_step ,
33+ update_target_step ):
34+ super ().__init__ (algorithm )
35+ self .global_update_step = 0
36+ self .update_target_step = update_target_step
4037 self .act_dim = act_dim
41- self .exploration = 1
42- self .global_step = 0
43- self .update_target_steps = 10000 // 4
44-
38+ self .curr_ep = 1
39+ self .ep_end = 0.1
40+ self .lr_end = 0.00001
4541 self .device = torch .device ('cuda' if torch .cuda .
4642 is_available () else 'cpu' )

-    def save(self, filepath):
-        state = {
-            'model': self.alg.model.state_dict(),
-            'target_model': self.alg.target_model.state_dict(),
-            'optimizer': self.alg.optimizer.state_dict(),
-            'scheduler': self.alg.scheduler.state_dict(),
-            'exploration': self.exploration,
-        }
-        torch.save(state, filepath)
-
-    def restore(self, filepath):
-        checkpoint = torch.load(filepath)
-        self.exploration = checkpoint['exploration']
-        self.alg.model.load_state_dict(checkpoint['model'])
-        self.alg.target_model.load_state_dict(checkpoint['target_model'])
-        self.alg.optimizer.load_state_dict(checkpoint['optimizer'])
-        self.alg.scheduler.load_state_dict(checkpoint['scheduler'])
+        self.ep_scheduler = LinearDecayScheduler(1, total_step)
+        self.lr_scheduler = LinearDecayScheduler(start_lr, total_step)

     def sample(self, obs):
-        sample = np.random.random()
-        if sample < self.exploration:
+        """Sample an action for the given observation. Based on the current
+        epsilon value, either a greedy or a random action is returned.
+
+        Args:
+            obs (np.float32): current observation, with shape (3, 84, 84) or (1, 3, 84, 84).
+
+        Returns:
+            act (int): action
+        """
+        explore = np.random.choice([True, False],
+                                   p=[self.curr_ep, 1 - self.curr_ep])
+        if explore:
             act = np.random.randint(self.act_dim)
         else:
-            if np.random.random() < 0.01:
-                act = np.random.randint(self.act_dim)
-            else:
-                act = self.predict(obs)
-        self.exploration = max(0.1, self.exploration - 1e-6)
+            act = self.predict(obs)
+
+        self.curr_ep = max(self.ep_scheduler.step(1), self.ep_end)
         return act

     def predict(self, obs):
-        obs = np.expand_dims(obs, 0)
+        """Predict an action for the given observation; a greedy action is returned.
+
+        Args:
+            obs (np.float32): current observation, with shape (3, 84, 84) or (1, 3, 84, 84).
+
+        Returns:
+            act (int): action
+        """
+        if obs.ndim == 3:  # expand a single observation to a batch of size 1
+            obs = np.expand_dims(obs, axis=0)
+
         obs = torch.tensor(obs, dtype=torch.float, device=self.device)
-        pred_q = self.alg.predict(obs)
-        action = pred_q.max(1)[1].item()
-        return action
+        pred_q = self.alg.predict(obs).cpu().detach().numpy().squeeze()
+
+        best_actions = np.where(pred_q == pred_q.max())[0]  # break ties among equally valued actions
+        act = np.random.choice(best_actions)
+        return act

     def learn(self, obs, act, reward, next_obs, terminal):
-        if self.global_step % self.update_target_steps == 0:
+        """Update the model with a batch of sampled transitions.
+
+        Args:
+            obs (np.float32): shape (batch_size, obs_dim)
+            act (np.int32): shape (batch_size)
+            reward (np.float32): shape (batch_size)
+            next_obs (np.float32): shape (batch_size, obs_dim)
+            terminal (np.float32): shape (batch_size)
+
+        Returns:
+            loss (float)
+        """
+        if self.global_update_step % self.update_target_step == 0:
             self.alg.sync_target()
-        self.global_step += 1

-        act = np.expand_dims(act, -1)
-        terminal = np.expand_dims(terminal, -1)
-        reward = np.expand_dims(reward, -1)
+        self.global_update_step += 1
+
         reward = np.clip(reward, -1, 1)
+        act = np.expand_dims(act, axis=-1)
+        reward = np.expand_dims(reward, axis=-1)
+        terminal = np.expand_dims(terminal, axis=-1)

         obs = torch.tensor(obs, dtype=torch.float, device=self.device)
         next_obs = torch.tensor(
@@ -100,5 +114,10 @@ def learn(self, obs, act, reward, next_obs, terminal):
         terminal = torch.tensor(
             terminal, dtype=torch.float, device=self.device)

-        cost = self.alg.learn(obs, act, reward, next_obs, terminal)
-        return cost
+        loss = self.alg.learn(obs, act, reward, next_obs, terminal)
+
+        # learning rate decay
+        for param_group in self.alg.optimizer.param_groups:
+            param_group['lr'] = max(self.lr_scheduler.step(1), self.lr_end)
+
+        return loss
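
Below is a minimal usage sketch showing how the reworked agent would be constructed and driven. The `AtariModel` class, the module paths, the `DQN(model, gamma, lr)` constructor form, and every hyperparameter value here are illustrative assumptions rather than part of this commit; treat it as an outline, not as the repository's actual training script.

# Hypothetical usage sketch -- names and values below are assumptions, not part of this commit.
import parl
from atari_model import AtariModel  # hypothetical model module
from atari_agent import AtariAgent  # the module changed in this diff (path assumed)

ACT_DIM = 4              # e.g. number of actions in the chosen Atari game (assumed)
TOTAL_STEP = int(1e6)    # decay horizon shared by the epsilon and lr schedulers (assumed)
START_LR = 3e-4          # initial learning rate (assumed)

model = AtariModel(act_dim=ACT_DIM)
algorithm = parl.algorithms.DQN(model, gamma=0.99, lr=START_LR)  # constructor arguments assumed
agent = AtariAgent(
    algorithm,
    act_dim=ACT_DIM,
    start_lr=START_LR,
    total_step=TOTAL_STEP,
    update_target_step=3000)

# Inside the training loop: epsilon-greedy action selection per environment step,
# then a gradient update from a replay-buffer batch of numpy arrays.
#   act = agent.sample(obs)
#   loss = agent.learn(batch_obs, batch_act, batch_reward, batch_next_obs, batch_terminal)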