4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict

@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
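Without an explicit device, `torch.zeros` allocates on the CPU, so a critic whose parameters live on the GPU (or Apple's MPS backend) pays an implicit cross-device copy every time a fresh recurrent memory is initialized. A minimal sketch of the difference in plain PyTorch, where `memory_size` is a stand-in for `self.critic.memory_size`:

```python
import torch

# Stand-in for mlagents' default_device(): whatever device training runs on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
memory_size = 256  # placeholder for self.critic.memory_size

cpu_memory = torch.zeros((1, 1, memory_size))                   # implicitly CPU
local_memory = torch.zeros((1, 1, memory_size), device=device)  # allocated in place
print(cpu_memory.device, local_memory.device)
```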
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/poca/optimizer_torch.py
@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
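POCA keeps separate value and baseline memory streams, and both initializations get the same treatment. The sketch below shows the failure mode this avoids, assuming a recurrent critic on the GPU (hidden sizes are illustrative):

```python
import torch

if torch.cuda.is_available():
    lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True).to("cuda")
    x = torch.randn(1, 1, 8, device="cuda")
    cpu_h0 = torch.zeros(1, 1, 8)  # CPU: lstm(x, (cpu_h0, cpu_h0)) raises a
                                   # device-mismatch RuntimeError
    h0 = torch.zeros(1, 1, 8, device="cuda")
    out, _ = lstm(x, (h0, h0))     # runs: input, weights, and state colocated
```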
17 changes: 11 additions & 6 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size

     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask

@@ -91,11 +95,12 @@ def evaluate(
"""
obs = decision_requests.obs
masks = self._extract_masks(decision_requests)
tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
device = default_device()
tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]

memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
0
)
memories = torch.as_tensor(
self.retrieve_memories(global_agent_ids), device=device
).unsqueeze(0)
with torch.no_grad():
action, run_out, memories = self.actor.get_action_and_stats(
tensor_obs, masks=masks, memories=memories
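Here the device is looked up once per method and threaded through every tensor constructor, instead of being re-queried (or silently defaulting to CPU) at each call site. A rough sketch of the pattern, with hypothetical names:

```python
import numpy as np
import torch

def make_inputs(num_agents: int, num_actions: int, device: torch.device):
    # All-ones mask (every action allowed), created directly on the target device.
    mask = torch.ones([num_agents, num_actions], device=device)
    # as_tensor with a device argument copies the numpy data onto that device.
    obs = torch.as_tensor(np.zeros((num_agents, 8), dtype=np.float32), device=device)
    return mask, obs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mask, obs = make_inputs(3, 5, device)
assert mask.device == obs.device
```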
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)
@@ -162,7 +162,7 @@ def compute_loss(
"""
Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
"""
total_loss = torch.zeros(1)
total_loss = torch.zeros(1, device=default_device())
stats_dict: Dict[str, np.ndarray] = {}
policy_estimate, policy_mu = self.compute_estimate(
policy_batch, use_vail_noise=True
@@ -219,21 +219,21 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(policy_action.shape, device=policy_action.device)
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
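In the discriminator's gradient-penalty path, the epsilon tensors are derived from the device of the tensors they blend (`policy_input.device` and so on) rather than from `default_device()`, which keeps the interpolated batch colocated with its inputs even if those ever live elsewhere. A sketch of the interpolation with assumed shapes:

```python
import torch

policy_obs = torch.randn(32, 16)  # stand-ins for policy / expert state batches
expert_obs = torch.randn(32, 16)

# Epsilon lives wherever the data lives, so the mix never straddles devices.
obs_epsilon = torch.rand(policy_obs.shape, device=policy_obs.device)
interp = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
interp.requires_grad = True  # gradients are taken w.r.t. the interpolated points

grad = torch.autograd.grad(interp.sum(), interp)[0]
assert grad.device == interp.device
```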
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc

-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device

 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel
@@ -86,8 +86,10 @@ def total_goal_enc_size(self) -> int:
     def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
-            if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+            if isinstance(enc, VectorInput):
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )

     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:
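One subtlety of `torch.as_tensor` worth noting for this change: without a device argument it shares memory with the source numpy array where possible, while `device="cuda"` (or `"mps"`) forces a copy. A small sketch of both behaviors:

```python
import numpy as np
import torch

arr = np.ones(4, dtype=np.float32)

shared = torch.as_tensor(arr)   # no copy: shares memory with arr
arr[0] = 5.0
assert shared[0].item() == 5.0  # mutation is visible through the tensor

if torch.cuda.is_available():
    moved = torch.as_tensor(arr, device="cuda")  # cross-device: always a copy
    assert moved.device.type == "cuda"
```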
11 changes: 8 additions & 3 deletions ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np

@@ -233,7 +233,10 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
         """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(
+            np.asanyarray(ndarray_list), dtype=dtype, device=device
+        )

     @staticmethod
     def list_to_tensor_list(
@@ -243,8 +246,10 @@ def list_to_tensor_list(
         Converts a list of numpy arrays into a list of tensors. MUCH faster than
         calling as_tensor on the list directly.
         """
+        device = default_device()
         return [
-            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
+            torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device)
+            for _arr in ndarray_list
         ]

     @staticmethod
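The docstring's speed claim is about consolidating through numpy first: `torch.as_tensor` on a Python list of arrays converts element by element, while a single ndarray converts in one bulk copy. An illustrative micro-benchmark (not part of the PR):

```python
import time
import numpy as np
import torch

chunks = [np.random.rand(100).astype(np.float32) for _ in range(10_000)]

t0 = time.perf_counter()
slow = torch.as_tensor(chunks)                 # element-wise; warns in recent torch
t1 = time.perf_counter()
fast = torch.as_tensor(np.asanyarray(chunks))  # one contiguous bulk conversion
t2 = time.perf_counter()

assert torch.allclose(slow.float(), fast.float())
print(f"list: {t1 - t0:.3f}s  ndarray: {t2 - t1:.3f}s")
```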