Skip to content

Commit c43f212

Browse files
lowdy1 and vmoens authored
[Feature] Add NPU Support for Single Agent (#3229)
Co-authored-by: vmoens <[email protected]>
1 parent 807e9fe commit c43f212

File tree

24 files changed, with 147 additions and 180 deletions.

docs/source/reference/utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Set of utility methods that are used internally by the library.
1010
:toctree: generated/
1111
:template: rl_template.rst
1212

13+
get_available_device
1314
implement_for
1415
set_auto_unwrap_transformed_env
1516
auto_unwrap_transformed_env

sota-implementations/a2c/a2c_atari.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def main(cfg: DictConfig): # noqa: F821
2222
from tensordict import from_module
2323
from tensordict.nn import CudaGraphModule
2424

25-
from torchrl._utils import timeit
25+
from torchrl._utils import get_available_device, timeit
2626
from torchrl.collectors import SyncDataCollector
2727
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
2828
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -33,11 +33,9 @@ def main(cfg: DictConfig): # noqa: F821
3333
from torchrl.record.loggers import generate_exp_name, get_logger
3434
from utils_atari import eval_model, make_parallel_env, make_ppo_models
3535

36-
device = cfg.loss.device
37-
if not device:
38-
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
39-
else:
40-
device = torch.device(device)
36+
device = (
37+
torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
38+
)
4139

4240
# Correct for frame_skip
4341
frame_skip = 4

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def main(cfg: DictConfig): # noqa: F821
2323
from tensordict import from_module
2424
from tensordict.nn import CudaGraphModule
2525

26-
from torchrl._utils import timeit
26+
from torchrl._utils import get_available_device, timeit
2727
from torchrl.collectors import SyncDataCollector
2828
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
2929
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -36,11 +36,9 @@ def main(cfg: DictConfig): # noqa: F821
3636

3737
# Define paper hyperparameters
3838

39-
device = cfg.loss.device
40-
if not device:
41-
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
42-
else:
43-
device = torch.device(device)
39+
device = (
40+
torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
41+
)
4442

4543
num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
4644
total_network_updates = (

sota-implementations/cql/cql_offline.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
import tqdm
2020
from tensordict.nn import CudaGraphModule
21-
from torchrl._utils import timeit
21+
from torchrl._utils import get_available_device, timeit
2222
from torchrl.envs.utils import ExplorationType, set_exploration_type
2323
from torchrl.objectives import group_optimizers
2424
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -55,13 +55,9 @@ def main(cfg: DictConfig): # noqa: F821
5555
# Set seeds
5656
torch.manual_seed(cfg.env.seed)
5757
np.random.seed(cfg.env.seed)
58-
device = cfg.optim.device
59-
if device in ("", None):
60-
if torch.cuda.is_available():
61-
device = "cuda:0"
62-
else:
63-
device = "cpu"
64-
device = torch.device(device)
58+
device = (
59+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
60+
)
6561

6662
# Create replay buffer
6763
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

sota-implementations/cql/cql_online.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
from torchrl._utils import timeit
24+
from torchrl._utils import get_available_device, timeit
2525
from torchrl.envs.utils import ExplorationType, set_exploration_type
2626
from torchrl.objectives import group_optimizers
2727
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -60,13 +60,9 @@ def main(cfg: DictConfig): # noqa: F821
6060
# Set seeds
6161
torch.manual_seed(cfg.env.seed)
6262
np.random.seed(cfg.env.seed)
63-
device = cfg.optim.device
64-
if device in ("", None):
65-
if torch.cuda.is_available():
66-
device = "cuda:0"
67-
else:
68-
device = "cpu"
69-
device = torch.device(device)
63+
device = (
64+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
65+
)
7066

7167
# Create env
7268
train_env, eval_env = make_environment(

sota-implementations/cql/discrete_cql_offline.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
import tqdm
2020
from tensordict.nn import CudaGraphModule
21-
from torchrl._utils import timeit
21+
from torchrl._utils import get_available_device, timeit
2222
from torchrl.envs.utils import ExplorationType, set_exploration_type
2323
from torchrl.record.loggers import generate_exp_name, get_logger
2424
from utils import (
@@ -36,10 +36,9 @@
3636

3737
@hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
3838
def main(cfg): # noqa: F821
39-
device = cfg.optim.device
40-
if device in ("", None):
41-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
42-
device = torch.device(device)
39+
device = (
40+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
41+
)
4342

4443
# Create logger
4544
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)

sota-implementations/cql/discrete_cql_online.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.cuda
2121
import tqdm
2222
from tensordict.nn import CudaGraphModule
23-
from torchrl._utils import timeit
23+
from torchrl._utils import get_available_device, timeit
2424
from torchrl.envs.utils import ExplorationType, set_exploration_type
2525
from torchrl.record.loggers import generate_exp_name, get_logger
2626
from utils import (
@@ -38,13 +38,9 @@
3838

3939
@hydra.main(version_base="1.1", config_path="", config_name="discrete_online_config")
4040
def main(cfg: DictConfig): # noqa: F821
41-
device = cfg.optim.device
42-
if device in ("", None):
43-
if torch.cuda.is_available():
44-
device = "cuda:0"
45-
else:
46-
device = "cpu"
47-
device = torch.device(device)
41+
device = (
42+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
43+
)
4844

4945
# Create logger
5046
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)

sota-implementations/crossq/crossq.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
from torchrl._utils import timeit
24+
from torchrl._utils import get_available_device, timeit
2525
from torchrl.envs.utils import ExplorationType, set_exploration_type
2626
from torchrl.objectives import group_optimizers
2727
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -40,13 +40,11 @@
4040

4141
@hydra.main(version_base="1.1", config_path=".", config_name="config")
4242
def main(cfg: DictConfig): # noqa: F821
43-
device = cfg.network.device
44-
if device in ("", None):
45-
if torch.cuda.is_available():
46-
device = torch.device("cuda:0")
47-
else:
48-
device = torch.device("cpu")
49-
device = torch.device(device)
43+
device = (
44+
torch.device(cfg.network.device)
45+
if cfg.network.device
46+
else get_available_device()
47+
)
5048

5149
# Create logger
5250
exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)

sota-implementations/ddpg/ddpg.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
from torchrl._utils import timeit
24+
from torchrl._utils import get_available_device, timeit
2525
from torchrl.envs.utils import ExplorationType, set_exploration_type
2626
from torchrl.objectives import group_optimizers
2727
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -39,21 +39,14 @@
3939

4040
@hydra.main(version_base="1.1", config_path="", config_name="config")
4141
def main(cfg: DictConfig): # noqa: F821
42-
device = cfg.optim.device
43-
if device in ("", None):
44-
if torch.cuda.is_available():
45-
device = "cuda:0"
46-
else:
47-
device = "cpu"
48-
device = torch.device(device)
49-
50-
collector_device = cfg.collector.device
51-
if collector_device in ("", None):
52-
if torch.cuda.is_available():
53-
collector_device = "cuda:0"
54-
else:
55-
collector_device = "cpu"
56-
collector_device = torch.device(collector_device)
42+
device = (
43+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
44+
)
45+
collector_device = (
46+
torch.device(cfg.collector.device)
47+
if cfg.collector.device
48+
else get_available_device()
49+
)
5750

5851
# Create logger
5952
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)

sota-implementations/decision_transformer/dt.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tqdm
1818
from tensordict import TensorDict
1919
from tensordict.nn import CudaGraphModule
20-
from torchrl._utils import logger as torchrl_logger, timeit
20+
from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
2121
from torchrl.envs.libs.gym import set_gym_backend
2222
from torchrl.envs.utils import ExplorationType, set_exploration_type
2323
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
@@ -38,13 +38,9 @@
3838
def main(cfg: DictConfig): # noqa: F821
3939
set_gym_backend(cfg.env.backend).set()
4040

41-
model_device = cfg.optim.device
42-
if model_device in ("", None):
43-
if torch.cuda.is_available():
44-
model_device = "cuda:0"
45-
else:
46-
model_device = "cpu"
47-
model_device = torch.device(model_device)
41+
model_device = (
42+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
43+
)
4844

4945
# Set seeds
5046
torch.manual_seed(cfg.env.seed)

0 commit comments

Comments
 (0)