
[BUG] MPS device crashes on float64 specs when moving environments with .to('mps') #3549

@bsprenger

Description


Describe the bug

Creating a `SyncDataCollector` (or `Collector`) with `env_device="mps"` crashes with `TypeError: Cannot convert a MPS Tensor to float64 dtype`. Gymnasium/MuJoCo environments expose float64 specs (for example, the observation spec), and MPS does not support float64 tensors, so the collector's call to `env.to("mps")` fails while casting those specs to the new device.
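The underlying failure can be reproduced without torchrl or MuJoCo. This minimal sketch assumes only PyTorch. `try_alloc` is a hypothetical helper that mirrors the `torch.full(...)` call at the bottom of the traceback below. On CPU both dtypes allocate fine. On an Apple Silicon machine with MPS, the float64 call is expected to raise the same `TypeError`.

```python
import torch


def try_alloc(dtype: torch.dtype, device: str) -> str:
    """Hypothetical helper: allocate a small filled tensor (as the spec
    constructor in the traceback does) and report success or the error."""
    try:
        torch.full((3,), 0.0, dtype=dtype, device=device)
        return "ok"
    except TypeError as err:
        # On MPS, float64 allocations are rejected with a TypeError
        return f"TypeError: {err}"


# CPU supports both dtypes
print("cpu float64:", try_alloc(torch.float64, "cpu"))
print("cpu float32:", try_alloc(torch.float32, "cpu"))

# Only exercise MPS where the backend is actually available
if torch.backends.mps.is_available():
    print("mps float64:", try_alloc(torch.float64, "mps"))  # expected to fail
    print("mps float32:", try_alloc(torch.float32, "mps"))
```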

To Reproduce

On a Mac with MPS:

```python
import torch
from torchrl.envs import GymEnv
from torchrl.collectors import SyncDataCollector

env = GymEnv("HalfCheetah-v4")

collector = SyncDataCollector(
    env,
    policy=env.rand_action,
    frames_per_batch=10,
    total_frames=10,
    device="mps",
    env_device="mps",
)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_base.py", line 1050, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_single.py", line 519, in __init__
    self._apply_env_device()
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_single.py", line 872, in _apply_env_device
    self.env: EnvBase = self.env.to(self.env_device)
  File "/Users/bensprenger/repositories/torchRL/torchrl/envs/common.py", line 76, in wrapper
    result = func(self, *args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/envs/common.py", line 4001, in to
    self.__dict__["_output_spec"] = self.output_spec.to(device)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 5962, in to
    kwargs[key] = value.to(dest)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 5962, in to
    kwargs[key] = value.to(dest)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3165, in to
    return Unbounded(shape=self.shape, device=dest_device, dtype=dest_dtype)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3034, in __call__
    instance = super().__call__(*args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3130, in __init__
    torch.full(
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```

Expected behavior

Creating the collector with `env_device="mps"` should succeed instead of raising a `TypeError` while the environment's float64 specs are moved to MPS.
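A minimal sketch of one possible fix, assuming float64 entries are downcast to float32 only when their destination is an MPS device. `dtype_for_device` and `move_to_device` are hypothetical helpers for illustration, not torchrl APIs.

```python
from __future__ import annotations

import torch


def dtype_for_device(dtype: torch.dtype, device: torch.device | str) -> torch.dtype:
    """Hypothetical helper: pick a dtype the destination device can hold."""
    device = torch.device(device)
    # MPS has no float64 support, so fall back to float32 there
    if device.type == "mps" and dtype == torch.float64:
        return torch.float32
    return dtype


def move_to_device(tensor: torch.Tensor, device: torch.device | str) -> torch.Tensor:
    """Hypothetical helper: move a tensor, downcasting float64 only when required."""
    return tensor.to(device=device, dtype=dtype_for_device(tensor.dtype, device))


obs = torch.zeros(3, dtype=torch.float64)
print(dtype_for_device(obs.dtype, "mps"))  # torch.float32
print(move_to_device(obs, "cpu").dtype)    # torch.float64 (CPU keeps float64)
```

Silent downcasting trades precision for compatibility; an alternative is to raise an error that explains the float64 limitation up front instead of failing inside `torch.full`.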


System info

  • Installed from source in a micromamba environment
  • macOS (Apple Silicon M-series)
  • Python version 3.10
```
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.10.0 2.2.6 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:46:00) [Clang 18.1.8 ] darwin
```

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)


Labels: bug
