
Commit d1eef93

update torch dqn, add benchmark results for dqn (#694)
* update torch dqn
* add benchmarks
* update benchmarks
* replace no_grad with detach
* yapf
1 parent 4d9ac1d commit d1eef93

File tree

9 files changed: +291 -188 lines changed

benchmark/torch/dqn/README.md

Lines changed: 16 additions & 0 deletions
@@ -22,6 +22,22 @@ Performance of **DQN** on various environments:
 <img src=".benchmark/dqn.png" alt="result"/>
 </p>
 
+Performance of **Dueling DQN** on 55 Atari environments:
+
+| | | | | |
+|---------------------|----------------------|----------------------|--------------------|-----------------|
+|Alien (2390) | Amidar (468) | Assault (13898) |Asterix (24067) | Asteroids (450) |
+|Atlantis (136833) | WizardOfWor (1767) | BankHeist (953) |BattleZone (26667) | BeamRider (9771) |
+|Berzerk (531) | Bowling (30) | Boxing (100) |Breakout (531) | Centipede (7416) |
+|ChopperCommand (1533)| CrazyClimber (102072)| DemonAttack (83478) |DoubleDunk (0) | Enduro (1634) |
+|FishingDerby (26) | Freeway (32) | Frostbite (4803) |Gopher (8128) | Gravitar (83) |
+|Hero (11810) | IceHockey (-3) | Jamesbond (616) |Kangaroo (4900) | Krull (8789) |
+|KungFuMaster (33144) | MontezumaRevenge (0) | MsPacman (2873) |NameThisGame (15010)| Phoenix (14837) |
+|Pitfall (0) | Pong (21) | PrivateEye (100) |Qbert (4850) | Riverraid (12453)|
+|RoadRunner (58000) | Robotank (26) | Seaquest (5960) |Skiing (-10584) | Solaris (347) |
+|SpaceInvaders (2068) | StarGunner (22100) | Tennis (1) |TimePilot (2967) | Tutankham (132) |
+|UpNDown (12350) | Venture (0) | VideoPinball (876611)|YarsRevenge (30281) | Zaxxon (4400) |
+
 ## How to use
 ### Dependencies:
 + python>=3.6.2

benchmark/torch/dqn/agent.py

Lines changed: 75 additions & 56 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,85 +12,99 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import argparse
-import gym
-
+import parl
 import numpy as np
-
 import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-
-import parl
+from parl.utils.scheduler import LinearDecayScheduler
 
 
 class AtariAgent(parl.Agent):
-    """Base class of the Agent.
+    """Agent of Atari env.
 
     Args:
-        algorithm (object): Algorithm used by this agent.
-        args (argparse.Namespace): Model configurations.
-        device (torch.device): use cpu or gpu.
+        algorithm (`parl.Algorithm`): algorithm to be used in this agent.
+        act_dim (int): action space dimension
+        total_step (int): total epsilon decay steps
+        start_lr (float): initial learning rate
+        update_target_step (int): target network update frequency
     """
 
-    def __init__(self, algorithm, act_dim):
-        assert isinstance(act_dim, int)
-        super(AtariAgent, self).__init__(algorithm)
+    def __init__(self, algorithm, act_dim, start_lr, total_step,
+                 update_target_step):
+        super().__init__(algorithm)
+        self.global_update_step = 0
+        self.update_target_step = update_target_step
         self.act_dim = act_dim
-        self.exploration = 1
-        self.global_step = 0
-        self.update_target_steps = 10000 // 4
-
+        self.curr_ep = 1
+        self.ep_end = 0.1
+        self.lr_end = 0.00001
         self.device = torch.device('cuda' if torch.cuda.
                                    is_available() else 'cpu')
 
-    def save(self, filepath):
-        state = {
-            'model': self.alg.model.state_dict(),
-            'target_model': self.alg.target_model.state_dict(),
-            'optimizer': self.alg.optimizer.state_dict(),
-            'scheduler': self.alg.scheduler.state_dict(),
-            'exploration': self.exploration,
-        }
-        torch.save(state, filepath)
-
-    def restore(self, filepath):
-        checkpoint = torch.load(filepath)
-        self.exploration = checkpoint['exploration']
-        self.alg.model.load_state_dict(checkpoint['model'])
-        self.alg.target_model.load_state_dict(checkpoint['target_model'])
-        self.alg.optimizer.load_state_dict(checkpoint['optimizer'])
-        self.alg.scheduler.load_state_dict(checkpoint['scheduler'])
+        self.ep_scheduler = LinearDecayScheduler(1, total_step)
+        self.lr_scheduler = LinearDecayScheduler(start_lr, total_step)
 
     def sample(self, obs):
-        sample = np.random.random()
-        if sample < self.exploration:
+        """Sample an action when given an observation, base on the current epsilon value,
+        either a greedy action or a random action will be returned.
+
+        Args:
+            obs (np.float32): shape of (3, 84, 84) or (1, 3, 84, 84), current observation
+
+        Returns:
+            act (int): action
+        """
+        explore = np.random.choice([True, False],
+                                   p=[self.curr_ep, 1 - self.curr_ep])
+        if explore:
             act = np.random.randint(self.act_dim)
         else:
-            if np.random.random() < 0.01:
-                act = np.random.randint(self.act_dim)
-            else:
-                act = self.predict(obs)
-        self.exploration = max(0.1, self.exploration - 1e-6)
+            act = self.predict(obs)
+
+        self.curr_ep = max(self.ep_scheduler.step(1), self.ep_end)
         return act
 
     def predict(self, obs):
-        obs = np.expand_dims(obs, 0)
+        """Predict an action when given an observation, a greedy action will be returned.
+
+        Args:
+            obs (np.float32): shape of (3, 84, 84) or (1, 3, 84, 84), current observation
+
+        Returns:
+            act(int): action
+        """
+        if obs.ndim == 3:  # if obs is 3 dimensional, we need to expand it to have batch_size = 1
+            obs = np.expand_dims(obs, axis=0)
+
         obs = torch.tensor(obs, dtype=torch.float, device=self.device)
-        pred_q = self.alg.predict(obs)
-        action = pred_q.max(1)[1].item()
-        return action
+        pred_q = self.alg.predict(obs).cpu().detach().numpy().squeeze()
+
+        best_actions = np.where(pred_q == pred_q.max())[0]
+        act = np.random.choice(best_actions)
+        return act
 
     def learn(self, obs, act, reward, next_obs, terminal):
-        if self.global_step % self.update_target_steps == 0:
+        """Update model with an episode data
+
+        Args:
+            obs (np.float32): shape of (batch_size, obs_dim)
+            act (np.int32): shape of (batch_size)
+            reward (np.float32): shape of (batch_size)
+            next_obs (np.float32): shape of (batch_size, obs_dim)
+            terminal (np.float32): shape of (batch_size)
+
+        Returns:
+            loss (float)
+        """
+        if self.global_update_step % self.update_target_step == 0:
             self.alg.sync_target()
-        self.global_step += 1
 
-        act = np.expand_dims(act, -1)
-        terminal = np.expand_dims(terminal, -1)
-        reward = np.expand_dims(reward, -1)
+        self.global_update_step += 1
+
         reward = np.clip(reward, -1, 1)
+        act = np.expand_dims(act, axis=-1)
+        reward = np.expand_dims(reward, axis=-1)
+        terminal = np.expand_dims(terminal, axis=-1)
 
         obs = torch.tensor(obs, dtype=torch.float, device=self.device)
         next_obs = torch.tensor(
@@ -100,5 +114,10 @@ def learn(self, obs, act, reward, next_obs, terminal):
         terminal = torch.tensor(
             terminal, dtype=torch.float, device=self.device)
 
-        cost = self.alg.learn(obs, act, reward, next_obs, terminal)
-        return cost
+        loss = self.alg.learn(obs, act, reward, next_obs, terminal)
+
+        # learning rate decay
+        for param_group in self.alg.optimizer.param_groups:
+            param_group['lr'] = max(self.lr_scheduler.step(1), self.lr_end)
+
+        return loss
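
A quick way to read the new exploration logic: the sketch below mirrors what `sample()` does, using a minimal stand-in for `parl.utils.scheduler.LinearDecayScheduler` (the `LinearDecay` class, the action count, and the step budget are illustrative assumptions, not PARL's actual implementation). The learning-rate decay added at the end of `learn()` follows the same pattern, stepping `lr_scheduler` once per update and flooring at `lr_end`.

```python
import numpy as np


class LinearDecay:
    """Illustrative stand-in for parl.utils.scheduler.LinearDecayScheduler:
    decays a value linearly from `start` toward 0, one decrement per step()."""

    def __init__(self, start, total_step):
        self.value = start
        self.decrement = start / total_step

    def step(self, n=1):
        self.value = max(self.value - n * self.decrement, 0.0)
        return self.value


act_dim = 4      # hypothetical number of actions
ep_end = 0.1     # exploration floor, as in AtariAgent
curr_ep = 1.0    # start fully exploratory
ep_scheduler = LinearDecay(start=1.0, total_step=1_000_000)  # step budget is made up

for _ in range(5):
    # With probability curr_ep take a random action, otherwise the greedy one.
    explore = np.random.rand() < curr_ep
    act = np.random.randint(act_dim) if explore else 0  # 0 stands in for predict(obs)
    # Decay epsilon after every sample, never dropping below ep_end.
    curr_ep = max(ep_scheduler.step(1), ep_end)
    print(act, round(curr_ep, 6))
```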

benchmark/torch/dqn/model.py

Lines changed: 58 additions & 30 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,38 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
 import parl
 
 
 class AtariModel(parl.Model):
-    """CNN network used in TensorPack examples.
+    """ Neural Network to solve Atari problem.
 
     Args:
-        input_channel (int): Input channel of states.
        act_dim (int): Dimension of action space.
-        algo (str): which ('DQN', 'Double', 'Dueling') model to use.
+        dueling (bool): True if use dueling architecture else False
     """
 
-    def __init__(self, input_channel, act_dim, algo='DQN'):
-        super(AtariModel, self).__init__()
+    def __init__(self, act_dim, dueling=False):
+        super().__init__()
         self.conv1 = nn.Conv2d(
-            input_channel, 32, kernel_size=8, stride=4, padding=2)
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=2)
-        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+            in_channels=4, out_channels=32, kernel_size=5, stride=1, padding=2)
+        self.conv2 = nn.Conv2d(
+            in_channels=32,
+            out_channels=32,
+            kernel_size=5,
+            stride=1,
+            padding=2)
+        self.conv3 = nn.Conv2d(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=4,
+            stride=1,
+            padding=1)
+        self.conv4 = nn.Conv2d(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+        self.relu = nn.ReLU()
+        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+        self.dueling = dueling
+
+        if dueling:
+            self.linear_1_adv = nn.Linear(in_features=6400, out_features=512)
+            self.linear_2_adv = nn.Linear(
+                in_features=512, out_features=act_dim)
+            self.linear_1_val = nn.Linear(in_features=6400, out_features=512)
+            self.linear_2_val = nn.Linear(in_features=512, out_features=1)
 
-        self.algo = algo
-        if self.algo == 'Dueling':
-            self.fc1_adv = nn.Linear(7744, 512)
-            self.fc1_val = nn.Linear(7744, 512)
-            self.fc2_adv = nn.Linear(512, act_dim)
-            self.fc2_val = nn.Linear(512, 1)
         else:
-            self.fc1 = nn.Linear(7744, 512)
-            self.fc2 = nn.Linear(512, act_dim)
+            self.linear_1 = nn.Linear(in_features=6400, out_features=act_dim)
 
         self.reset_params()
 
@@ -54,16 +71,27 @@ def reset_params(self):
                     m.weight, mode='fan_out', nonlinearity='relu')
                 nn.init.zeros_(m.bias)
 
-    def forward(self, x):
-        x = x / 255.0
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = x.view(x.size(0), -1)
-        if self.algo == 'Dueling':
-            As = self.fc2_adv(F.relu(self.fc1_adv(x)))
-            V = self.fc2_val(F.relu(self.fc1_val(x)))
+    def forward(self, obs):
+        """ Perform forward pass
+
+        Args:
+            obs (torch.Tensor): shape of (batch_size, 3, 84, 84), mini batch of observations
+        """
+        obs = obs / 255.0
+        out = self.max_pool(self.relu(self.conv1(obs)))
+        out = self.max_pool(self.relu(self.conv2(out)))
+        out = self.max_pool(self.relu(self.conv3(out)))
+        out = self.relu(self.conv4(out))
+        out = self.flatten(out)
+
+        if self.dueling:
+            As = self.relu(self.linear_1_adv(out))
+            As = self.linear_2_adv(As)
+            V = self.relu(self.linear_1_val(out))
+            V = self.linear_2_val(V)
             Q = As + (V - As.mean(dim=1, keepdim=True))
+
         else:
-            Q = self.fc2(F.relu(self.fc1(x)))
+            Q = self.linear_1(out)
+
         return Q
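
The single line on the dueling branch is the mean-subtracted aggregation of the value and advantage streams. Here it is in isolation with made-up tensors, just to show how the two heads combine (shapes and values are illustrative only):

```python
import torch

# Hypothetical outputs for a batch of 2 states and 4 actions.
As = torch.tensor([[1., 2., 3., 4.],
                   [0., 0., 0., 8.]])   # advantage stream
V = torch.tensor([[10.],
                  [5.]])                # value stream, one scalar per state

# Same combination as AtariModel.forward(): subtracting the per-state mean
# advantage keeps the advantages zero-mean, so V and As cannot absorb an
# arbitrary constant from each other.
Q = As + (V - As.mean(dim=1, keepdim=True))
print(Q)
# tensor([[ 8.5000,  9.5000, 10.5000, 11.5000],
#         [ 3.0000,  3.0000,  3.0000, 11.0000]])
```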
