
Commit 87d6c9c

Ensure tensors are created on default device
Updated tensor creation in optimizers, reward providers, and network normalization to explicitly use the configured default_device. Removed redundant set_torch_config call in trainer_controller to avoid interfering with PyTorch's global device context. These changes improve device consistency and prevent device mismatch errors in multi-threaded or multi-device training scenarios.
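For reference, a minimal sketch of the pattern this commit applies (the shape below is illustrative; torch and default_device are the real imports from mlagents.torch_utils):

from mlagents.torch_utils import torch, default_device

# Before: a bare torch.zeros(...) allocates on the CPU even when training runs on
# a GPU, which can trigger "expected all tensors to be on the same device" errors.
cpu_memory = torch.zeros((1, 1, 256))

# After: the tensor is created directly on the device ML-Agents was configured with.
memory = torch.zeros((1, 1, 256), device=default_device())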
1 parent 4e7e0b8 commit 87d6c9c

File tree

6 files changed: +20 / -17 lines

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict

@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

ml-agents/mlagents/trainers/poca/optimizer_torch.py

Lines changed: 2 additions & 2 deletions
@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ def evaluate(
         if "log_probs" in run_out:
             run_out["log_probs"] = run_out["log_probs"].to_log_probs_tuple()
         if "entropy" in run_out:
+            # Ensure entropy is detached and moved to CPU before NumPy conversion
             run_out["entropy"] = ModelUtils.to_numpy(run_out["entropy"])
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)

ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py

Lines changed: 7 additions & 7 deletions
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)

@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True

@@ -219,21 +219,21 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(policy_action.shape, device=policy_action.device)
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
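As a side note, here is a self-contained sketch (not the ML-Agents implementation; the discriminator and shapes are placeholders) of why the interpolation epsilon in a WGAN-style gradient penalty must live on the same device as the inputs it mixes:

import torch

def gradient_penalty(discriminator, policy_obs, expert_obs):
    # Drawing epsilon on policy_obs.device keeps the elementwise mix on one device;
    # a bare torch.rand(shape) would allocate on the CPU and break GPU batches.
    epsilon = torch.rand(policy_obs.shape, device=policy_obs.device)
    interp = epsilon * policy_obs + (1 - epsilon) * expert_obs
    interp.requires_grad_(True)
    estimate = discriminator(interp).sum()
    grad = torch.autograd.grad(estimate, interp, create_graph=True)[0]
    return torch.mean((torch.norm(grad, dim=-1) - 1) ** 2)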

ml-agents/mlagents/trainers/torch_entities/networks.py

Lines changed: 5 additions & 3 deletions
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc

-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device

 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel

@@ -86,8 +86,10 @@ def total_goal_enc_size(self) -> int:
     def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
-            if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+            if isinstance(enc, VectorInput):
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )

     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:

ml-agents/mlagents/trainers/trainer_controller.py

Lines changed: 3 additions & 3 deletions
@@ -293,9 +293,9 @@ def join_threads(self, timeout_seconds: float = 1.0) -> None:
             merge_gauges(thread_timer_stack.gauges)

     def trainer_update_func(self, trainer: Trainer) -> None:
-        torch_utils.set_torch_config(
-            TorchSettings(device=str(torch_utils.default_device()))
-        )
+        # Note: Avoid calling torch.set_default_device in worker threads; it can
+        # interfere with PyTorch's global device context manager. The policy and
+        # optimizer code explicitly places tensors on the configured default_device().
         while not self.kill_trainers:
             with hierarchical_timer("trainer_advance"):
                 trainer.advance()
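To illustrate the convention described in the comment above, a simplified sketch (the thread body and shape are illustrative; torch and default_device are the real imports): worker threads no longer reset a process-wide default device, and any tensor a trainer thread needs is placed explicitly.

import threading
from mlagents.torch_utils import torch, default_device

def trainer_update_func():
    # No global default-device mutation here; tensors are placed explicitly, so the
    # result does not depend on which thread configured PyTorch's default device.
    scratch = torch.zeros((1, 1, 64), device=default_device())
    # ... the real loop would call trainer.advance() until kill_trainers is set ...

threading.Thread(target=trainer_update_func, daemon=True).start()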
