Skip to content

Commit ac59796

Browse files
[Feature] Buffer device (#87)
* amend * amend * empty
1 parent 0a28a16 commit ac59796

File tree

6 files changed

+18
-8
lines changed

6 files changed

+18
-8
lines changed

benchmarl/algorithms/common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self, experiment):
4141
self.experiment = experiment
4242

4343
self.device: DEVICE_TYPING = experiment.config.train_device
44+
self.buffer_device: DEVICE_TYPING = experiment.config.buffer_device
4445
self.experiment_config = experiment.config
4546
self.model_config = experiment.model_config
4647
self.critic_model_config = experiment.critic_model_config
@@ -141,11 +142,12 @@ def get_replay_buffer(
141142
"""
142143
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
143144
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
144-
storing_device = self.device
145145
sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
146-
147146
return TensorDictReplayBuffer(
148-
storage=LazyTensorStorage(memory_size, device=storing_device),
147+
storage=LazyTensorStorage(
148+
memory_size,
149+
device=self.device if self.on_policy else self.buffer_device,
150+
),
149151
sampler=sampler,
150152
batch_size=sampling_size,
151153
priority_key=(group, "td_error"),

benchmarl/conf/experiment/base_experiment.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ defaults:
66
sampling_device: "cpu"
77
# The device for training (e.g. cuda)
88
train_device: "cpu"
9+
# The device for the replay buffer of off-policy algorithms (e.g. cuda)
10+
buffer_device: "cpu"
911

1012
# Whether to share the parameters of the policy within agent groups
1113
share_policy_params: True

benchmarl/experiment/experiment.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class ExperimentConfig:
5050

5151
sampling_device: str = MISSING
5252
train_device: str = MISSING
53+
buffer_device: str = MISSING
5354

5455
share_policy_params: bool = MISSING
5556
prefer_continuous_actions: bool = MISSING
@@ -462,9 +463,9 @@ def _setup_collector(self):
462463
storing_device=self.config.train_device,
463464
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
464465
total_frames=self.config.get_max_n_frames(self.on_policy),
465-
init_random_frames=self.config.off_policy_init_random_frames
466-
if not self.on_policy
467-
else 0,
466+
init_random_frames=(
467+
self.config.off_policy_init_random_frames if not self.on_policy else 0
468+
),
468469
)
469470

470471
def _setup_name(self):
@@ -647,7 +648,7 @@ def _get_excluded_keys(self, group: str):
647648
return excluded_keys
648649

649650
def _optimizer_loop(self, group: str) -> TensorDictBase:
650-
subdata = self.replay_buffers[group].sample()
651+
subdata = self.replay_buffers[group].sample().to(self.config.train_device)
651652
loss_vals = self.losses[group](subdata)
652653
training_td = loss_vals.detach()
653654
loss_vals = self.algorithm.process_loss_vals(group, loss_vals)

benchmarl/hydra_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from omegaconf import DictConfig, OmegaConf
1818

1919

20-
def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
20+
def load_experiment_from_hydra(
21+
cfg: DictConfig, task_name: str, callbacks=()
22+
) -> Experiment:
2123
"""Creates an :class:`~benchmarl.experiment.Experiment` from hydra config.
2224
2325
Args:
@@ -41,6 +43,7 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
4143
critic_model_config=critic_model_config,
4244
seed=cfg.seed,
4345
config=experiment_config,
46+
callbacks=callbacks,
4447
)
4548

4649

fine_tuned/smacv2/conf/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ seed: 0
1616
experiment:
1717
sampling_device: "cpu"
1818
train_device: "cuda"
19+
buffer_device: "cuda"
1920

2021
share_policy_params: True
2122
prefer_continuous_actions: True

fine_tuned/vmas/conf/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ experiment:
1717

1818
sampling_device: "cuda"
1919
train_device: "cuda"
20+
buffer_device: "cuda"
2021

2122
share_policy_params: True
2223
prefer_continuous_actions: True

0 commit comments

Comments (0)