Skip to content

Commit e5f9a9c

Browse files
committed
no TD state_dict
1 parent 6d581f0 commit e5f9a9c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchrl/collectors/_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,10 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
456456

457457
state_dict = inner_collector.state_dict()
458458
# Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility
459-
# CPU and CUDA tensors are already shareable and don't need conversion
459+
# CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth)
460460
state_dict = tree_map(_map_to_cpu_if_needed, state_dict)
461461
state_dict = TensorDict(state_dict)
462-
state_dict = state_dict.clone().apply(_cast, state_dict)
462+
state_dict = state_dict.clone().apply(_cast, state_dict).to_dict()
463463
pipe_child.send((state_dict, "state_dict"))
464464
has_timed_out = False
465465
continue

0 commit comments

Comments
 (0)