
Commit 5d38241

Add code for A3C

1 parent 16b70be commit 5d38241

File tree

3 files changed: +514 −0 lines changed

sota-implementations/a3c/a3c_atari.py

Lines changed: 227 additions & 0 deletions

@@ -0,0 +1,227 @@
from __future__ import annotations

from copy import deepcopy

import hydra
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim
import tqdm

from omegaconf import DictConfig
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import A2CLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import make_parallel_env, make_ppo_models

torch.set_float32_matmul_precision("high")

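# Adam whose state tensors are allocated up front and moved to shared memory,
# so every worker process reads and updates the same optimizer statistics.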
class SharedAdam(torch.optim.Adam):
    def __init__(self, params, **kwargs):
        super().__init__(params, **kwargs)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = torch.zeros(1)
                state["exp_avg"] = torch.zeros_like(p.data)
                state["exp_avg_sq"] = torch.zeros_like(p.data)
                state["exp_avg"].share_memory_()
                state["exp_avg_sq"].share_memory_()
                state["step"].share_memory_()

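# One asynchronous worker: keeps local copies of the global actor/critic,
# collects its own rollouts, and applies gradients to the shared global model.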
class A3CWorker(mp.Process):
    def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
        super().__init__()
        self.name = name
        self.cfg = cfg

        self.optimizer = optimizer

        self.device = cfg.loss.device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        self.frame_skip = 4
        self.total_frames = cfg.collector.total_frames // self.frame_skip
        self.frames_per_batch = cfg.collector.frames_per_batch // self.frame_skip
        self.mini_batch_size = cfg.loss.mini_batch_size // self.frame_skip
        self.test_interval = cfg.logger.test_interval // self.frame_skip

        self.global_actor = global_actor
        self.global_critic = global_critic
        self.local_actor = deepcopy(global_actor)
        self.local_critic = deepcopy(global_critic)

        self.logger = logger

        self.adv_module = GAE(
            gamma=cfg.loss.gamma,
            lmbda=cfg.loss.gae_lambda,
            value_network=self.local_critic,
            average_gae=True,
            vectorized=not cfg.compile.compile,
            device=self.device,
        )
        self.loss_module = A2CLoss(
            actor_network=self.local_actor,
            critic_network=self.local_critic,
            loss_critic_type=cfg.loss.loss_critic_type,
            entropy_coef=cfg.loss.entropy_coef,
            critic_coef=cfg.loss.critic_coef,
        )

        # Use the Atari end-of-life signal as the done/terminated flag for both
        # advantage estimation and the loss.
        self.adv_module.set_keys(done="end-of-life", terminated="end-of-life")
        self.loss_module.set_keys(done="end-of-life", terminated="end-of-life")

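    # Backprop on the local model, hand the gradients to the global parameters,
    # clip, and step the shared optimizer.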
    def update(self, batch, max_grad_norm=None):
        if max_grad_norm is None:
            max_grad_norm = self.cfg.optim.max_grad_norm

        loss = self.loss_module(batch)
        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
        loss_sum.backward()

        for local_param, global_param in zip(
            self.local_actor.parameters(), self.global_actor.parameters()
        ):
            global_param._grad = local_param.grad

        for local_param, global_param in zip(
            self.local_critic.parameters(), self.global_critic.parameters()
        ):
            global_param._grad = local_param.grad

        # The global grads alias the local grad tensors, so clipping the local
        # parameters' grads clips the shared update too.
        gn = torch.nn.utils.clip_grad_norm_(
            self.loss_module.parameters(), max_norm=max_grad_norm
        )

        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)

        return (
            loss.select("loss_critic", "loss_entropy", "loss_objective")
            .detach()
            .set("grad_norm", gn)
        )

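    # Worker loop: collect rollouts, compute GAE, and update over mini-batches,
    # optionally annealing the learning rate.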
    def run(self):
        cfg = self.cfg

        collector = SyncDataCollector(
            create_env_fn=make_parallel_env(
                cfg.env.env_name,
                num_envs=cfg.env.num_envs,
                device=self.device,
                gym_backend=cfg.env.backend,
            ),
            policy=self.local_actor,
            frames_per_batch=self.frames_per_batch,
            total_frames=self.total_frames,
            device=self.device,
            storing_device=self.device,
            policy_device=self.device,
            compile_policy=False,
            cudagraph_policy=False,
        )

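        # Frame counts were divided by frame_skip at init, so scale back up by
        # frame_skip when reporting environment frames collected.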
        collected_frames = 0
        num_network_updates = 0
        pbar = tqdm.tqdm(total=self.total_frames)
        num_mini_batches = self.frames_per_batch // self.mini_batch_size
        total_network_updates = (
            self.total_frames // self.frames_per_batch
        ) * num_mini_batches
        lr = cfg.optim.lr

        c_iter = iter(collector)
        total_iter = len(collector)
        for _ in range(total_iter):
            data = next(c_iter)

            metrics_to_log = {}
            frames_in_batch = data.numel()
            collected_frames += self.frames_per_batch * self.frame_skip
            pbar.update(frames_in_batch)

            episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
            if len(episode_rewards) > 0:
                episode_length = data["next", "step_count"][data["next", "terminated"]]
                metrics_to_log["train/reward"] = episode_rewards.mean().item()
                metrics_to_log[
                    "train/episode_length"
                ] = episode_length.sum().item() / len(episode_length)

            with torch.no_grad():
                data = self.adv_module(data)
            data_reshape = data.reshape(-1)
            losses = []

            mini_batches = data_reshape.split(self.mini_batch_size)
            for batch in mini_batches:
                alpha = 1.0
                if cfg.optim.anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in self.optimizer.param_groups:
                        group["lr"] = lr * alpha

                num_network_updates += 1
                loss = self.update(batch).clone()
                losses.append(loss)

            losses = torch.stack(losses).float().mean()

            for key, value in losses.items():
                metrics_to_log[f"train/{key}"] = value.item()

            metrics_to_log["train/lr"] = lr * alpha
            if self.logger:
                for key, value in metrics_to_log.items():
                    self.logger.log_scalar(key, value, collected_frames)
        collector.shutdown()


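# Entry point: build the shared global model and optimizer, then launch one
# A3CWorker per process; all workers step the same SharedAdam.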
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: DictConfig):
    global_actor, global_critic, global_critic_head = make_ppo_models(
        cfg.env.env_name, device=cfg.loss.device, gym_backend=cfg.env.backend
    )
    global_model = nn.ModuleList([global_actor, global_critic_head])
    global_model.share_memory()
    optimizer = SharedAdam(global_model.parameters(), lr=cfg.optim.lr)

    num_workers = cfg.multiprocessing.num_workers
    if num_workers is None:
        num_workers = mp.cpu_count()

    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
        logger = get_logger(
            cfg.logger.backend,
            logger_name="a3c",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    workers = [
        A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger)
        for i in range(num_workers)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
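
The A3C mechanics above rest on two lines: `share_memory()` on the global model and `global_param._grad = local_param.grad` in the workers. A minimal sketch, not part of this commit and independent of torchrl, that exercises just that mechanism (it assumes a platform where torch.multiprocessing can share the module's storage, as on Linux):

import torch
import torch.multiprocessing as mp
import torch.nn as nn


def worker(global_net):
    # Local copy, as A3CWorker does with deepcopy of the global networks.
    local_net = nn.Linear(4, 2)
    local_net.load_state_dict(global_net.state_dict())
    local_net(torch.randn(8, 4)).sum().backward()
    # Hand the local gradients to the shared parameters, then step.
    for lp, gp in zip(local_net.parameters(), global_net.parameters()):
        gp._grad = lp.grad
    torch.optim.SGD(global_net.parameters(), lr=0.1).step()


if __name__ == "__main__":
    net = nn.Linear(4, 2)
    net.share_memory()  # parameters become visible to child processes
    before = net.weight.detach().clone()
    p = mp.Process(target=worker, args=(net,))
    p.start()
    p.join()
    # The child's optimizer step moved the shared weights in this process too.
    print("weights changed:", not torch.equal(before, net.weight))
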
sota-implementations/a3c/config_atari.yaml

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
# Environment
env:
  env_name: PongNoFrameskip-v4
  backend: gymnasium
  num_envs: 1

# Collector
collector:
  frames_per_batch: 800
  total_frames: 40_000_000

# Logger
logger:
  backend: wandb
  project_name: torchrl_example_a2c
  group_name: null
  exp_name: Atari_Schulman17
  test_interval: 40_000_000
  num_test_episodes: 3
  video: False

# Optim
optim:
  lr: 0.0001
  eps: 1.0e-8
  weight_decay: 0.0
  max_grad_norm: 40.0
  anneal_lr: True

# Loss
loss:
  gamma: 0.99
  mini_batch_size: 80
  gae_lambda: 0.95
  critic_coef: 0.25
  entropy_coef: 0.01
  loss_critic_type: l2
  device:

compile:
  compile: False
  compile_mode:
  cudagraphs: False

multiprocessing:
  num_workers: 16
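
Note that A3CWorker divides the collector and batch sizes by frame_skip=4, so the effective per-worker numbers differ from the YAML values. A small sketch, not part of the commit, that composes the config with hydra and prints the effective values (it assumes it runs from the directory containing config_atari.yaml):

from hydra import compose, initialize

with initialize(config_path=".", version_base="1.1"):
    cfg = compose(config_name="config_atari")

frame_skip = 4
print(cfg.collector.frames_per_batch // frame_skip)  # 200 frames per batch
print(cfg.loss.mini_batch_size // frame_skip)        # 20 mini-batch size
print(cfg.collector.total_frames // frame_skip)      # 10_000_000 total frames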

0 commit comments