Closed
Labels
bug (Something isn't working)
Description
Describe the bug
Creating a SyncDataCollector (or Collector) with env_device="mps" crashes with "TypeError: Cannot convert a MPS Tensor to float64 dtype". Gymnasium/MuJoCo environments produce float64 observation and action spaces, and the MPS backend does not support float64 tensors, so moving the environment's specs to MPS fails.
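The traceback shows the failure inside output_spec.to(device). The spec tree is walked recursively (the repeated tensor_specs.py frames), and each leaf is rebuilt on the destination device with its original float64 dtype. One possible mitigation is to downcast float64 leaves to float32 whenever the destination is MPS. The sketch below illustrates that idea under stated assumptions: downcast_for_mps is a hypothetical helper, not part of torchrl. The spec is modeled as nested dicts of dtype names that stand in for torchrl's spec objects and torch dtypes, and the leaf dtypes in the usage example are illustrative.

```python
# Hypothetical sketch, not torchrl's API: a spec tree is modeled as nested
# dicts whose leaves are dtype names (stand-ins for torch.dtype objects).
from typing import Union

SpecTree = Union[str, dict]

# dtypes assumed unsupported on MPS, mapped to the fallback used instead
_MPS_FALLBACK = {"float64": "float32"}


def downcast_for_mps(spec: SpecTree, device: str) -> SpecTree:
    """Return spec with float64 leaves replaced by float32 when device is an
    MPS device ("mps" or "mps:<index>"); other devices get spec back as-is."""
    if device.split(":")[0] != "mps":
        return spec  # CPU/CUDA can hold float64, so nothing needs changing

    def walk(node: SpecTree) -> SpecTree:
        if isinstance(node, dict):
            # recurse into nested entries, mirroring the nested .to() calls
            return {key: walk(value) for key, value in node.items()}
        # leaf: swap unsupported dtypes, keep everything else unchanged
        return _MPS_FALLBACK.get(node, node)

    return walk(spec)


if __name__ == "__main__":
    # illustrative output spec with made-up leaf dtypes
    output_spec = {
        "full_observation_spec": {"observation": "float64"},
        "full_reward_spec": {"reward": "float64"},
        "full_done_spec": {"done": "bool"},
    }
    print(downcast_for_mps(output_spec, "mps"))
    print(downcast_for_mps(output_spec, "cpu"))
```

A silent downcast changes the numerical precision of observations, so raising a clearer error or requiring an explicit opt-in are reasonable alternatives to casting automatically.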
To Reproduce
On a Mac with MPS:
```python
import torch
from torchrl.envs import GymEnv
from torchrl.collectors import SyncDataCollector

# MuJoCo env whose observation/action specs are float64
env = GymEnv("HalfCheetah-v4")

# Raises TypeError on Apple Silicon while moving the env to MPS
collector = SyncDataCollector(
    env,
    policy=env.rand_action,
    frames_per_batch=10,
    total_frames=10,
    device="mps",
    env_device="mps",
)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_base.py", line 1050, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_single.py", line 519, in __init__
    self._apply_env_device()
  File "/Users/bensprenger/repositories/torchRL/torchrl/collectors/_single.py", line 872, in _apply_env_device
    self.env: EnvBase = self.env.to(self.env_device)
  File "/Users/bensprenger/repositories/torchRL/torchrl/envs/common.py", line 76, in wrapper
    result = func(self, *args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/envs/common.py", line 4001, in to
    self.__dict__["_output_spec"] = self.output_spec.to(device)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 5962, in to
    kwargs[key] = value.to(dest)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 5962, in to
    kwargs[key] = value.to(dest)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3165, in to
    return Unbounded(shape=self.shape, device=dest_device, dtype=dest_dtype)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3034, in __call__
    instance = super().__call__(*args, **kwargs)
  File "/Users/bensprenger/repositories/torchRL/torchrl/data/tensor_specs.py", line 3130, in __init__
    torch.full(
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```

Expected behavior
Constructing the collector with env_device="mps" should not raise.
System info
- Installed from source in a micromamba environment
- macOS (Apple Silicon M-series)
- Python version 3.10
```
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.10.0 2.2.6 3.10.18 | packaged by conda-forge | (main, Jun 4 2025, 14:46:00) [Clang 18.1.8 ] darwin
```

Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)