@@ -21,7 +21,7 @@
 import tqdm
 from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
-from torchrl._utils import timeit
+from torchrl._utils import get_available_device, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -39,21 +39,14 @@

 @hydra.main(version_base="1.1", config_path="", config_name="config")
 def main(cfg: DictConfig):  # noqa: F821
-    device = cfg.optim.device
-    if device in ("", None):
-        if torch.cuda.is_available():
-            device = "cuda:0"
-        else:
-            device = "cpu"
-    device = torch.device(device)
-
-    collector_device = cfg.collector.device
-    if collector_device in ("", None):
-        if torch.cuda.is_available():
-            collector_device = "cuda:0"
-        else:
-            collector_device = "cpu"
-    collector_device = torch.device(collector_device)
+    device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+    collector_device = (
+        torch.device(cfg.collector.device)
+        if cfg.collector.device
+        else get_available_device()
+    )

     # Create logger
     exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
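For context, the diff folds the duplicated device-selection branches into the `get_available_device` helper imported from `torchrl._utils`. The sketch below is an assumption of what that helper returns, reconstructed only from the branching logic the diff deletes (cuda:0 when CUDA is available, otherwise cpu); the real torchrl helper may cover additional cases.

# Minimal sketch, assuming get_available_device() mirrors the removed branches.
# The helper name comes from the diff; this body is a reconstruction, not
# torchrl's actual implementation.
import torch


def get_available_device() -> torch.device:
    """Fallback device when the config leaves the device field empty."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


# Usage matching the new code path in main(): an explicit config value wins,
# an empty value falls back to the helper. cfg_device stands in for
# cfg.optim.device / cfg.collector.device from the Hydra config.
cfg_device = ""
device = torch.device(cfg_device) if cfg_device else get_available_device()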