-
Notifications
You must be signed in to change notification settings - Fork 398
[Feature, Example] A3C Atari Implementation for TorchRL #3001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
simeetnayan81
wants to merge
15
commits into
pytorch:main
Choose a base branch
from
simeetnayan81:a3c-implementation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
5d38241
Add code for A3C
simeetnayan81 748b673
Add readme
simeetnayan81 ecbec8b
log only worker-0 stats
simeetnayan81 72eea77
Add linux test script file, and sota-check for a3c
simeetnayan81 7b9ba6b
Merge branch 'main' into a3c-implementation
simeetnayan81 d95de87
modify sota-check a3c
simeetnayan81 c4184f6
Add code for A3C
simeetnayan81 7cfa7d7
Add readme
simeetnayan81 87ec6f3
log only worker-0 stats
simeetnayan81 bba7ba5
Add linux test script file, and sota-check for a3c
simeetnayan81 b49e35a
modify sota-check a3c
simeetnayan81 a6eb18d
amend
vmoens 836f03e
Merge branch 'pytorch:main' into a3c-implementation
simeetnayan81 19209e2
Move SharedAdam to utils
simeetnayan81 3ebad77
Move SharedAdam to utils
simeetnayan81 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/bin/bash | ||
|
||
#SBATCH --job-name=a3c_atari | ||
#SBATCH --ntasks=32 | ||
#SBATCH --cpus-per-task=1 | ||
#SBATCH --gres=gpu:1 | ||
#SBATCH --output=slurm_logs/a3c_atari_%j.txt | ||
#SBATCH --error=slurm_errors/a3c_atari_%j.txt | ||
|
||
current_commit=$(git rev-parse --short HEAD) | ||
project_name="torchrl-example-check-$current_commit" | ||
group_name="a3c_atari" | ||
|
||
export PYTHONPATH=$(dirname $(dirname $PWD)) | ||
python $PYTHONPATH/sota-implementations/a3c/a3c_atari.py \ | ||
logger.backend=wandb \ | ||
logger.project_name="$project_name" \ | ||
logger.group_name="$group_name" | ||
|
||
# Capture the exit status of the Python command | ||
exit_status=$? | ||
# Write the exit status to a file | ||
if [ $exit_status -eq 0 ]; then | ||
echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log | ||
else | ||
echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log | ||
fi |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Reproducing Asynchronous Advantage Actor Critic (A3C) Algorithm Results | ||
|
||
This repository contains scripts that enable training agents using the Asynchronous Advantage Actor Critic (A3C) Algorithm on Atari environments. We follow the original paper [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783) by Mnih et al. (2016) to implement the A3C algorithm with a fixed number of steps during the collection phase. | ||
|
||
## Examples Structure | ||
|
||
Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: | ||
|
||
1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. `a3c_atari.py`). | ||
|
||
2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. `utils_atari.py`). | ||
|
||
3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. `config_atari.yaml`). | ||
|
||
## Running the Examples | ||
|
||
You can execute the A3C algorithm on Atari environments by running the following command: | ||
|
||
```bash | ||
python a3c_atari.py | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
|
||
import hydra | ||
import torch | ||
|
||
import torch.multiprocessing as mp | ||
import torch.nn as nn | ||
import torch.optim | ||
import tqdm | ||
from tensordict import from_module | ||
|
||
from torchrl.collectors import SyncDataCollector | ||
from torchrl.objectives import A2CLoss | ||
from torchrl.objectives.value.advantages import GAE | ||
|
||
from torchrl.record.loggers import generate_exp_name, get_logger | ||
from utils_atari import make_parallel_env, make_ppo_models, SharedAdam | ||
|
||
|
||
torch.set_float32_matmul_precision("high") | ||
|
||
|
||
class A3CWorker(mp.Process): | ||
def __init__( | ||
self, name, cfg, global_actor, global_critic, optimizer, use_logger=False | ||
): | ||
super().__init__() | ||
self.name = name | ||
self.cfg = cfg | ||
|
||
self.optimizer = optimizer | ||
|
||
self.device = cfg.loss.device or torch.device( | ||
"cuda:0" if torch.cuda.is_available() else "cpu" | ||
) | ||
|
||
self.frame_skip = 4 | ||
self.total_frames = cfg.collector.total_frames // self.frame_skip | ||
self.frames_per_batch = cfg.collector.frames_per_batch // self.frame_skip | ||
self.mini_batch_size = cfg.loss.mini_batch_size // self.frame_skip | ||
self.test_interval = cfg.logger.test_interval // self.frame_skip | ||
|
||
self.global_actor = global_actor | ||
self.global_critic = global_critic | ||
self.local_actor = self.copy_model(global_actor) | ||
self.local_critic = self.copy_model(global_critic) | ||
|
||
logger = None | ||
if use_logger and cfg.logger.backend: | ||
exp_name = generate_exp_name( | ||
"A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}" | ||
) | ||
logger = get_logger( | ||
cfg.logger.backend, | ||
logger_name="a3c", | ||
experiment_name=exp_name, | ||
wandb_kwargs={ | ||
"config": dict(cfg), | ||
"project": cfg.logger.project_name, | ||
"group": cfg.logger.group_name, | ||
}, | ||
) | ||
|
||
self.logger = logger | ||
|
||
self.adv_module = GAE( | ||
gamma=cfg.loss.gamma, | ||
lmbda=cfg.loss.gae_lambda, | ||
value_network=self.local_critic, | ||
average_gae=True, | ||
vectorized=not cfg.compile.compile, | ||
device=self.device, | ||
) | ||
self.loss_module = A2CLoss( | ||
actor_network=self.local_actor, | ||
critic_network=self.local_critic, | ||
loss_critic_type=cfg.loss.loss_critic_type, | ||
entropy_coef=cfg.loss.entropy_coef, | ||
critic_coef=cfg.loss.critic_coef, | ||
) | ||
|
||
self.adv_module.set_keys(done="end-of-life", terminated="end-of-life") | ||
self.loss_module.set_keys(done="end-of-life", terminated="end-of-life") | ||
|
||
def copy_model(self, model): | ||
td_params = from_module(model) | ||
td_new_params = td_params.data.clone() | ||
td_new_params = td_new_params.apply( | ||
lambda p0, p1: torch.nn.Parameter(p0) | ||
if isinstance(p1, torch.nn.Parameter) | ||
else p0, | ||
td_params, | ||
) | ||
with td_params.data.to("meta").to_module(model): | ||
# we don't copy any param here | ||
new_model = deepcopy(model) | ||
td_new_params.to_module(new_model) | ||
return new_model | ||
|
||
def update(self, batch, max_grad_norm=None): | ||
if max_grad_norm is None: | ||
max_grad_norm = self.cfg.optim.max_grad_norm | ||
|
||
loss = self.loss_module(batch) | ||
loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] | ||
loss_sum.backward() | ||
|
||
for local_param, global_param in zip( | ||
self.local_actor.parameters(), self.global_actor.parameters() | ||
): | ||
global_param._grad = local_param.grad | ||
|
||
for local_param, global_param in zip( | ||
self.local_critic.parameters(), self.global_critic.parameters() | ||
): | ||
global_param._grad = local_param.grad | ||
|
||
gn = torch.nn.utils.clip_grad_norm_( | ||
self.loss_module.parameters(), max_norm=max_grad_norm | ||
) | ||
|
||
self.optimizer.step() | ||
self.optimizer.zero_grad(set_to_none=True) | ||
|
||
return ( | ||
loss.select("loss_critic", "loss_entropy", "loss_objective") | ||
.detach() | ||
.set("grad_norm", gn) | ||
) | ||
|
||
def run(self): | ||
cfg = self.cfg | ||
|
||
collector = SyncDataCollector( | ||
create_env_fn=make_parallel_env( | ||
cfg.env.env_name, | ||
num_envs=cfg.env.num_envs, | ||
device=self.device, | ||
gym_backend=cfg.env.backend, | ||
), | ||
policy=self.local_actor, | ||
frames_per_batch=self.frames_per_batch, | ||
total_frames=self.total_frames, | ||
device=self.device, | ||
storing_device=self.device, | ||
policy_device=self.device, | ||
compile_policy=False, | ||
cudagraph_policy=False, | ||
) | ||
|
||
collected_frames = 0 | ||
num_network_updates = 0 | ||
pbar = tqdm.tqdm(total=self.total_frames) | ||
num_mini_batches = self.frames_per_batch // self.mini_batch_size | ||
total_network_updates = ( | ||
self.total_frames // self.frames_per_batch | ||
) * num_mini_batches | ||
lr = cfg.optim.lr | ||
|
||
c_iter = iter(collector) | ||
total_iter = len(collector) | ||
|
||
for _ in range(total_iter): | ||
data = next(c_iter) | ||
|
||
metrics_to_log = {} | ||
frames_in_batch = data.numel() | ||
collected_frames += self.frames_per_batch * self.frame_skip | ||
pbar.update(frames_in_batch) | ||
|
||
episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] | ||
if len(episode_rewards) > 0: | ||
episode_length = data["next", "step_count"][data["next", "terminated"]] | ||
metrics_to_log["train/reward"] = episode_rewards.mean().item() | ||
metrics_to_log[ | ||
"train/episode_length" | ||
] = episode_length.sum().item() / len(episode_length) | ||
|
||
with torch.no_grad(): | ||
data = self.adv_module(data) | ||
data_reshape = data.reshape(-1) | ||
losses = [] | ||
|
||
mini_batches = data_reshape.split(self.mini_batch_size) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To shuffle things a bit I usually rely on a replay buffer instance rather than just splitting the data |
||
for batch in mini_batches: | ||
alpha = 1.0 | ||
if cfg.optim.anneal_lr: | ||
alpha = 1 - (num_network_updates / total_network_updates) | ||
for group in self.optimizer.param_groups: | ||
group["lr"] = lr * alpha | ||
|
||
num_network_updates += 1 | ||
loss = self.update(batch).clone() | ||
losses.append(loss) | ||
|
||
losses = torch.stack(losses).float().mean() | ||
|
||
for key, value in losses.items(): | ||
metrics_to_log[f"train/{key}"] = value.item() | ||
|
||
metrics_to_log["train/lr"] = lr * alpha | ||
|
||
# Logging only on the first worker in the dashboard. | ||
# Alternatively, you can use a distributed logger, or aggregate metrics from all workers. | ||
if self.logger: | ||
for key, value in metrics_to_log.items(): | ||
self.logger.log_scalar(key, value, collected_frames) | ||
collector.shutdown() | ||
|
||
|
||
@hydra.main(config_path="", config_name="config_atari", version_base="1.1") | ||
def main(cfg: DictConfig): # noqa: F821 | ||
|
||
global_actor, global_critic, global_critic_head = make_ppo_models( | ||
cfg.env.env_name, device=cfg.loss.device, gym_backend=cfg.env.backend | ||
) | ||
global_model = nn.ModuleList([global_actor, global_critic_head]) | ||
global_model.share_memory() | ||
optimizer = SharedAdam(global_model.parameters(), lr=cfg.optim.lr) | ||
|
||
num_workers = cfg.multiprocessing.num_workers | ||
|
||
workers = [ | ||
A3CWorker( | ||
f"worker_{i}", | ||
cfg, | ||
global_actor, | ||
global_critic, | ||
optimizer, | ||
use_logger=i == 0, | ||
) | ||
for i in range(num_workers) | ||
] | ||
[w.start() for w in workers] | ||
[w.join() for w in workers] | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Environment | ||
env: | ||
env_name: PongNoFrameskip-v4 | ||
backend: gymnasium | ||
num_envs: 1 | ||
|
||
# collector | ||
collector: | ||
frames_per_batch: 800 | ||
total_frames: 40_000_00 | ||
|
||
# logger | ||
logger: | ||
backend: wandb | ||
project_name: torchrl_example_a3c | ||
group_name: null | ||
exp_name: a3c_atari_training | ||
test_interval: 40_000_000 | ||
num_test_episodes: 3 | ||
video: False | ||
|
||
# Optim | ||
optim: | ||
lr: 0.0001 | ||
eps: 1.0e-8 | ||
weight_decay: 0.0 | ||
max_grad_norm: 40.0 | ||
anneal_lr: True | ||
|
||
# loss | ||
loss: | ||
gamma: 0.99 | ||
mini_batch_size: 80 | ||
gae_lambda: 0.95 | ||
critic_coef: 0.25 | ||
entropy_coef: 0.01 | ||
loss_critic_type: l2 | ||
device: | ||
|
||
compile: | ||
compile: False | ||
compile_mode: | ||
cudagraphs: False | ||
|
||
multiprocessing: | ||
num_workers: 16 |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you explain what we do here? What do we use the _grad for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_grad is used to store the gradients for each parameter.
We copy local gradients to the global model so the global model can be updated with the optimizer.
This is a key step in A3C, where multiple workers asynchronously update a shared global model.