Skip to content

Commit f3018a9

Browse files
author
Ervin T
authored
[bug-fix] Move POCA critic to default device (#5124) (#5131)
* Move critic to default device * Make sure to clone onto default device * Add some debug stuff * Some more debug * Fix issue * Fix bool tensor too
1 parent f631fa6 commit f3018a9

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ml-agents/mlagents/trainers/poca/optimizer_torch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
)
55
import numpy as np
66
import math
7-
from mlagents.torch_utils import torch
7+
from mlagents.torch_utils import torch, default_device
88

99
from mlagents.trainers.buffer import (
1010
AgentBuffer,
@@ -155,6 +155,8 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
155155
network_settings=trainer_settings.network_settings,
156156
action_spec=policy.behavior_spec.action_spec,
157157
)
158+
# Move to GPU if needed
159+
self._critic.to(default_device())
158160

159161
params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
160162
self.hyperparameters: POCASettings = cast(

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
268268
[_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1
269269
)
270270
# Get the mask from NaNs
271-
attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor)
271+
attn_mask = only_first_obs_flat.isnan().float()
272272
return attn_mask
273273

274274
def _copy_and_remove_nans_from_obs(
@@ -283,7 +283,7 @@ def _copy_and_remove_nans_from_obs(
283283
for obs in single_agent_obs:
284284
new_obs = obs.clone()
285285
new_obs[
286-
attention_mask.type(torch.BoolTensor)[:, i_agent], ::
286+
attention_mask.bool()[:, i_agent], ::
287287
] = 0.0 # Remoove NaNs fast
288288
no_nan_obs.append(new_obs)
289289
obs_with_no_nans.append(no_nan_obs)

0 commit comments

Comments
 (0)