From 4e7e0b8e59a90a461fe93a34e7b9528e6a3972eb Mon Sep 17 00:00:00 2001
From: Viktor Zatorskyi
Date: Mon, 15 Sep 2025 15:49:24 -0700
Subject: [PATCH 1/4] Ensure tensors use default device in torch policy and utils

Updated tensor creation in torch_policy.py and utils.py to explicitly use the
default device, ensuring consistency across devices (CPU/GPU). Also set torch
config in TrainerController to use the default device. This improves device
management and prevents potential device mismatch errors.
---
 .../mlagents/trainers/policy/torch_policy.py      | 17 +++++++++++------
 .../mlagents/trainers/torch_entities/utils.py     | 11 ++++++++---
 .../mlagents/trainers/trainer_controller.py       |  4 ++++
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py
index fceacda6e9..f7cdb75c4e 100644
--- a/ml-agents/mlagents/trainers/policy/torch_policy.py
+++ b/ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size
 
     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask
 
@@ -91,11 +95,12 @@ def evaluate(
         """
         obs = decision_requests.obs
         masks = self._extract_masks(decision_requests)
-        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
+        device = default_device()
+        tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]
 
-        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
-            0
-        )
+        memories = torch.as_tensor(
+            self.retrieve_memories(global_agent_ids), device=device
+        ).unsqueeze(0)
         with torch.no_grad():
             action, run_out, memories = self.actor.get_action_and_stats(
                 tensor_obs, masks=masks, memories=memories
diff --git a/ml-agents/mlagents/trainers/torch_entities/utils.py b/ml-agents/mlagents/trainers/torch_entities/utils.py
index d5381cbecb..6efb571679 100644
--- a/ml-agents/mlagents/trainers/torch_entities/utils.py
+++ b/ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 
 import numpy as np
@@ -233,7 +233,10 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
         """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(
+            np.asanyarray(ndarray_list), dtype=dtype, device=device
+        )
 
     @staticmethod
     def list_to_tensor_list(
@@ -243,8 +246,10 @@ def list_to_tensor_list(
         Converts a list of numpy arrays into a list of tensors. MUCH faster than
         calling as_tensor on the list directly.
""" + device = default_device() return [ - torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list + torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device) + for _arr in ndarray_list ] @staticmethod diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 69da1e5694..a5c97b29af 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -28,6 +28,7 @@ from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers from mlagents.trainers.agent_processor import AgentManager from mlagents import torch_utils +from mlagents.trainers.settings import TorchSettings from mlagents.torch_utils.globals import get_rank @@ -292,6 +293,9 @@ def join_threads(self, timeout_seconds: float = 1.0) -> None: merge_gauges(thread_timer_stack.gauges) def trainer_update_func(self, trainer: Trainer) -> None: + torch_utils.set_torch_config( + TorchSettings(device=str(torch_utils.default_device())) + ) while not self.kill_trainers: with hierarchical_timer("trainer_advance"): trainer.advance() From 87d6c9c22d81841c71de377e4bb109a7b77292df Mon Sep 17 00:00:00 2001 From: Viktor Zatorskyi Date: Mon, 15 Sep 2025 16:36:19 -0700 Subject: [PATCH 2/4] Ensure tensors are created on default device Updated tensor creation in optimizers, reward providers, and network normalization to explicitly use the configured default_device. Removed redundant set_torch_config call in trainer_controller to avoid interfering with PyTorch's global device context. These changes improve device consistency and prevent device mismatch errors in multi-threaded or multi-device training scenarios. --- .../mlagents/trainers/optimizer/torch_optimizer.py | 4 ++-- .../mlagents/trainers/poca/optimizer_torch.py | 4 ++-- ml-agents/mlagents/trainers/policy/torch_policy.py | 1 + .../reward_providers/gail_reward_provider.py | 14 +++++++------- .../mlagents/trainers/torch_entities/networks.py | 8 +++++--- ml-agents/mlagents/trainers/trainer_controller.py | 6 +++--- 6 files changed, 20 insertions(+), 17 deletions(-) diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py index 9ee3845515..40663b38bd 100644 --- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py @@ -1,5 +1,5 @@ from typing import Dict, Optional, Tuple, List -from mlagents.torch_utils import torch +from mlagents.torch_utils import torch, default_device import numpy as np from collections import defaultdict @@ -162,7 +162,7 @@ def get_trajectory_value_estimates( memory = self.critic_memory_dict[agent_id] else: memory = ( - torch.zeros((1, 1, self.critic.memory_size)) + torch.zeros((1, 1, self.critic.memory_size), device=default_device()) if self.policy.use_recurrent else None ) diff --git a/ml-agents/mlagents/trainers/poca/optimizer_torch.py b/ml-agents/mlagents/trainers/poca/optimizer_torch.py index de17f3d3b2..9f5ecf11d3 100644 --- a/ml-agents/mlagents/trainers/poca/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/poca/optimizer_torch.py @@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates( _init_baseline_mem = self.baseline_memory_dict[agent_id] else: _init_value_mem = ( - torch.zeros((1, 1, self.critic.memory_size)) + torch.zeros((1, 1, self.critic.memory_size), device=default_device()) if self.policy.use_recurrent else None ) _init_baseline_mem = ( - torch.zeros((1, 1, 
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py
index f7cdb75c4e..8eb50c24e4 100644
--- a/ml-agents/mlagents/trainers/policy/torch_policy.py
+++ b/ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -109,6 +109,7 @@ def evaluate(
         if "log_probs" in run_out:
             run_out["log_probs"] = run_out["log_probs"].to_log_probs_tuple()
         if "entropy" in run_out:
+            # Ensure entropy is detached and moved to CPU before NumPy conversion
             run_out["entropy"] = ModelUtils.to_numpy(run_out["entropy"])
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
diff --git a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
index 0ae77ba143..73785e331d 100644
--- a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)
@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True
@@ -219,21 +219,21 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(policy_action.shape, device=policy_action.device)
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
diff --git a/ml-agents/mlagents/trainers/torch_entities/networks.py b/ml-agents/mlagents/trainers/torch_entities/networks.py
index 555268075c..ec302e6432 100644
--- a/ml-agents/mlagents/trainers/torch_entities/networks.py
+++ b/ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc
 
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 
 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel
@@ -86,8 +86,10 @@ def total_goal_enc_size(self) -> int:
     def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
-            if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+                if isinstance(enc, VectorInput):
+                    enc.update_normalization(
+                        torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                    )
 
     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index a5c97b29af..eb31f1a1c1 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -293,9 +293,9 @@ def join_threads(self, timeout_seconds: float = 1.0) -> None:
             merge_gauges(thread_timer_stack.gauges)
 
     def trainer_update_func(self, trainer: Trainer) -> None:
-        torch_utils.set_torch_config(
-            TorchSettings(device=str(torch_utils.default_device()))
-        )
+        # Note: Avoid calling torch.set_default_device in worker threads; it can
+        # interfere with PyTorch's global device context manager. The policy and
+        # optimizer code explicitly places tensors on the configured default_device().
         while not self.kill_trainers:
             with hierarchical_timer("trainer_advance"):
                 trainer.advance()

From 2486c58b88406b8c8b253680bb728008f37b2609 Mon Sep 17 00:00:00 2001
From: Viktor Zatorskyi
Date: Mon, 15 Sep 2025 16:50:02 -0700
Subject: [PATCH 3/4] Cleanup

---
 ml-agents/mlagents/trainers/policy/torch_policy.py | 1 -
 ml-agents/mlagents/trainers/trainer_controller.py  | 4 ----
 2 files changed, 5 deletions(-)

diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py
index 8eb50c24e4..f7cdb75c4e 100644
--- a/ml-agents/mlagents/trainers/policy/torch_policy.py
+++ b/ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -109,7 +109,6 @@ def evaluate(
         if "log_probs" in run_out:
             run_out["log_probs"] = run_out["log_probs"].to_log_probs_tuple()
         if "entropy" in run_out:
-            # Ensure entropy is detached and moved to CPU before NumPy conversion
             run_out["entropy"] = ModelUtils.to_numpy(run_out["entropy"])
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index eb31f1a1c1..69da1e5694 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -28,7 +28,6 @@
 from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
 from mlagents.trainers.agent_processor import AgentManager
 from mlagents import torch_utils
-from mlagents.trainers.settings import TorchSettings
 from mlagents.torch_utils.globals import get_rank
 
 
@@ -293,9 +292,6 @@ def join_threads(self, timeout_seconds: float = 1.0) -> None:
             merge_gauges(thread_timer_stack.gauges)
 
     def trainer_update_func(self, trainer: Trainer) -> None:
-        # Note: Avoid calling torch.set_default_device in worker threads; it can
-        # interfere with PyTorch's global device context manager. The policy and
-        # optimizer code explicitly places tensors on the configured default_device().
         while not self.kill_trainers:
             with hierarchical_timer("trainer_advance"):
                 trainer.advance()

From 202629dbc9113973ec64ce59ac2a67d5e3955a76 Mon Sep 17 00:00:00 2001
From: maryam-zia
Date: Tue, 16 Sep 2025 11:23:41 -0400
Subject: [PATCH 4/4] Black reformatting

---
 .../components/reward_providers/gail_reward_provider.py | 4 +++-
 ml-agents/mlagents/trainers/torch_entities/networks.py  | 8 ++++----
 ml-agents/mlagents/trainers/torch_entities/utils.py     | 4 +---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
index 73785e331d..906f9e32c1 100644
--- a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
@@ -226,7 +226,9 @@ def compute_gradient_magnitude(
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape, device=policy_action.device)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
             policy_dones = torch.as_tensor(
                 policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
diff --git a/ml-agents/mlagents/trainers/torch_entities/networks.py b/ml-agents/mlagents/trainers/torch_entities/networks.py
index ec302e6432..196d23698f 100644
--- a/ml-agents/mlagents/trainers/torch_entities/networks.py
+++ b/ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -86,10 +86,10 @@ def total_goal_enc_size(self) -> int:
     def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
-                if isinstance(enc, VectorInput):
-                    enc.update_normalization(
-                        torch.as_tensor(vec_input.to_ndarray(), device=default_device())
-                    )
+            if isinstance(enc, VectorInput):
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )
 
     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:
diff --git a/ml-agents/mlagents/trainers/torch_entities/utils.py b/ml-agents/mlagents/trainers/torch_entities/utils.py
index 6efb571679..7f2cea40ab 100644
--- a/ml-agents/mlagents/trainers/torch_entities/utils.py
+++ b/ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -234,9 +234,7 @@ def list_to_tensor(
         calling as_tensor on the list directly.
         """
         device = default_device()
-        return torch.as_tensor(
-            np.asanyarray(ndarray_list), dtype=dtype, device=device
-        )
+        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)
 
     @staticmethod
     def list_to_tensor_list(