diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index 24dc45cca3..ae2752de89 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -53,6 +53,8 @@ def set_torch_config(torch_settings: TorchSettings) -> None: if _device.type == "cuda": torch.set_default_tensor_type(torch.cuda.FloatTensor) + elif _device.type == 'mps': + torch.set_default_device(device_str) else: torch.set_default_tensor_type(torch.FloatTensor) logger.debug(f"default Torch device: {_device}")