
Commit 9b04929

[Test] Use spawn in older pytorch
ghstack-source-id: 0800d3e
Pull-Request: #3285
1 parent d781f9e commit 9b04929

File tree

8 files changed: +135 -27 lines changed

  test/test_collectors.py
  test/test_transforms.py
  torchrl/__init__.py
  torchrl/_utils.py
  torchrl/collectors/_multi_base.py
  torchrl/collectors/_runner.py
  torchrl/collectors/utils.py
  torchrl/envs/transforms/transforms.py

test/test_collectors.py

Lines changed: 5 additions & 0 deletions

@@ -616,6 +616,11 @@ def make_env():
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method with file_system sharing strategy.",
+    )
     @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv")
     @pytest.mark.parametrize("static_seed", [True, False])
     def test_collector_vecnorm_envcreator(self, static_seed):
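The new marker relies on module-level helpers in the test file. As a point of reference, here is a minimal sketch of the definitions the marker assumes; the exact ones in test_collectors.py may differ:

    import torch
    from packaging import version

    # base_version strips pre-release/dev suffixes ("2.8.0a0" -> "2.8.0") so a
    # 2.8 nightly is not skipped as if it predated the 2.8.0 release.
    TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

    assert isinstance(TORCH_VERSION, version.Version)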

test/test_transforms.py

Lines changed: 20 additions & 0 deletions

@@ -9673,6 +9673,11 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
@@ -9785,6 +9790,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
@@ -10051,6 +10061,11 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
@@ -10170,6 +10185,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
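Several of the gated tests also carry @retry(AssertionError, tries=10, delay=0), which re-runs a timing-sensitive test until it passes. A hedged sketch of what such a helper typically looks like (the suite's actual retry may differ):

    import functools
    import time

    def retry(exc_type, tries=3, delay=0.0):
        # Re-invoke the wrapped callable until it stops raising exc_type,
        # up to `tries` attempts, sleeping `delay` seconds between attempts.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for attempt in range(tries):
                    try:
                        return fn(*args, **kwargs)
                    except exc_type:
                        if attempt == tries - 1:
                            raise
                        time.sleep(delay)
            return wrapper
        return decorator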

torchrl/__init__.py

Lines changed: 16 additions & 5 deletions

@@ -3,15 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+import warnings
 import weakref
 from warnings import warn
 
 import torch
 
-from tensordict import set_lazy_legacy
+# Silence noisy dependency warning triggered at import time on older torch stacks.
+# (Emitted by tensordict when registering pytree nodes.)
+warnings.filterwarnings(
+    "ignore",
+    category=UserWarning,
+    message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
+)
+
+from tensordict import set_lazy_legacy  # noqa: E402
 
-from torch import multiprocessing as mp
-from torch.distributions.transforms import _InverseTransform, ComposeTransform
+from torch import multiprocessing as mp  # noqa: E402
+from torch.distributions.transforms import (  # noqa: E402
+    _InverseTransform,
+    ComposeTransform,
+)
 
 torch._C._log_api_usage_once("torchrl")
@@ -61,8 +73,7 @@
 
 logger = logger
 
-# TorchRL's multiprocessing default:
-# We only force "spawn" on newer PyTorch versions (see `_get_default_mp_start_method`).
+# TorchRL's multiprocessing default.
 _preferred_start_method = _get_default_mp_start_method()
 if _preferred_start_method == "spawn":
     try:
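Note that warnings.filterwarnings matches `message` as a regular expression against the start of the warning text, and the filter must be installed before the import that emits the warning, which is why it precedes `from tensordict import set_lazy_legacy` above. A self-contained sketch of the behavior:

    import warnings

    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
    )

    # Anchored regex match on the message start: this one is silenced.
    warnings.warn(
        "torch.utils._pytree._register_pytree_node is deprecated. Use ...",
        UserWarning,
    )
    # Different message: still emitted.
    warnings.warn("unrelated UserWarning", UserWarning)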

torchrl/_utils.py

Lines changed: 13 additions & 15 deletions

@@ -36,7 +36,6 @@
 from torch._dynamo import is_compiling
 
 
-@implement_for("torch", "2.5.0")
 def _get_default_mp_start_method() -> str:
     """Returns TorchRL's preferred multiprocessing start method for this torch version.
@@ -46,20 +45,6 @@ def _get_default_mp_start_method() -> str:
     return "spawn"
 
 
-@implement_for("torch", None, "2.5.0")
-def _get_default_mp_start_method() -> str:  # noqa: F811
-    """Returns TorchRL's preferred multiprocessing start method for this torch version.
-
-    On older PyTorch versions we prefer ``"fork"`` when available to avoid failures
-    when spawning workers with non-CPU storages that must be pickled at process start.
-    """
-    try:
-        mp.get_context("fork")
-    except ValueError:
-        return "spawn"
-    return "fork"
-
-
 def _get_mp_ctx(start_method: str | None = None):
     """Return a multiprocessing context with TorchRL's preferred start method.
@@ -108,6 +93,19 @@ def _set_mp_start_method_if_unset(start_method: str | None = None) -> str | None:
     return current
 
 
+@implement_for("torch", None, "2.8")
+def _mp_sharing_strategy_for_spawn() -> str | None:
+    # On older torch stacks, pickling Process objects for "spawn" can end up
+    # passing file descriptors for shared storages; using "file_system" reduces
+    # FD passing and avoids spawn-time failures on some old Python versions.
+    return "file_system"
+
+
+@implement_for("torch", "2.8")
+def _mp_sharing_strategy_for_spawn() -> str | None:  # noqa: F811
+    return None
+
+
 def strtobool(val: Any) -> bool:
     """Convert a string representation of truth to a boolean.

torchrl/collectors/_multi_base.py

Lines changed: 38 additions & 5 deletions

@@ -4,6 +4,7 @@
 import abc
 
 import contextlib
+import sys
 import warnings
 from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
@@ -20,6 +21,7 @@
     _check_for_faulty_process,
     _get_mp_ctx,
     _make_process_no_warn_cls,
+    _mp_sharing_strategy_for_spawn,
     _set_mp_start_method_if_unset,
     RL_WARNINGS,
 )
@@ -33,7 +35,7 @@
 )
 from torchrl.collectors._runner import _main_async_collector
 from torchrl.collectors._single import Collector
-from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool
+from torchrl.collectors.utils import _make_meta_policy_cm, _TrajectoryPool
 from torchrl.collectors.weight_update import WeightUpdaterBase
 from torchrl.data import ReplayBuffer
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -945,7 +947,14 @@ def _run_processes(self) -> None:
         ctx = _get_mp_ctx()
         # Best-effort global init (only if unset) to keep other mp users consistent.
         _set_mp_start_method_if_unset(ctx.get_start_method())
-
+        if (
+            sys.platform == "linux"
+            and sys.version_info < (3, 10)
+            and ctx.get_start_method() == "spawn"
+        ):
+            strategy = _mp_sharing_strategy_for_spawn()
+            if strategy is not None:
+                mp.set_sharing_strategy(strategy)
         queue_out = ctx.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True)
@@ -1004,9 +1013,11 @@ def _run_processes(self) -> None:
                 policy_to_send = None
                 cm = contextlib.nullcontext()
             elif policy is not None:
-                # Send policy with meta-device parameters (empty structure) - schemes apply weights
+                # Send a stateless policy down to workers: schemes apply weights.
                 policy_to_send = policy
-                cm = _make_meta_policy(policy)
+                cm = _make_meta_policy_cm(
+                    policy, mp_start_method=ctx.get_start_method()
+                )
             else:
                 policy_to_send = None
                 cm = contextlib.nullcontext()
@@ -1037,7 +1048,6 @@ def _run_processes(self) -> None:
             with cm:
                 kwargs = {
                     "policy_factory": policy_factory[i],
-                    "pipe_parent": pipe_parent,
                     "pipe_child": pipe_child,
                     "queue_out": queue_out,
                     "create_env_fn": env_fun,
@@ -1128,6 +1138,29 @@ def _run_processes(self) -> None:
                     ) from err
                 else:
                     raise err
+            except ValueError as err:
+                if "bad value(s) in fds_to_keep" in str(err):
+                    # This error occurs on old Python versions (e.g., 3.9) with old PyTorch (e.g., 2.3)
+                    # when using the spawn multiprocessing start method. The spawn implementation tries to
+                    # preserve file descriptors across exec, but some descriptors may be invalid/closed.
+                    # This is a compatibility issue with old Python multiprocessing implementations.
+                    python_version = (
+                        f"{sys.version_info.major}.{sys.version_info.minor}"
+                    )
+                    raise RuntimeError(
+                        f"Failed to start collector worker process due to file descriptor issues "
+                        f"with spawn multiprocessing on Python {python_version}.\n\n"
+                        f"This is a known compatibility issue with old Python/PyTorch stacks. "
+                        f"Consider upgrading to Python >= 3.10 and PyTorch >= 2.5, or use the 'fork' "
+                        f"multiprocessing start method on Unix systems.\n\n"
+                        f"Workarounds:\n"
+                        f"- Upgrade Python to >= 3.10 and PyTorch to >= 2.5\n"
+                        f"- On Unix systems, force fork start method:\n"
+                        f"  import torch.multiprocessing as mp\n"
+                        f"  if __name__ == '__main__':\n"
+                        f"      mp.set_start_method('fork', force=True)\n\n"
+                        f"Upstream Python issue: https://github.com/python/cpython/issues/87706"
+                    ) from err
             except _pickle.PicklingError as err:
                 if "<lambda>" in str(err):
                     raise RuntimeError(
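For context on the strategy being toggled above: torch.multiprocessing exposes a process-wide sharing strategy. On Linux the default is 'file_descriptor', while 'file_system' backs shared storages with files in /dev/shm instead of descriptors that must survive spawn's fds_to_keep bookkeeping. A quick sketch using the public API:

    import torch.multiprocessing as mp

    # On Linux both strategies are available; other platforms expose fewer.
    print(mp.get_all_sharing_strategies())  # e.g. {'file_descriptor', 'file_system'}
    print(mp.get_sharing_strategy())        # Linux default: 'file_descriptor'

    # Route shared storages through the filesystem rather than passed fds.
    mp.set_sharing_strategy("file_system")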

torchrl/collectors/_runner.py

Lines changed: 0 additions & 2 deletions

@@ -34,7 +34,6 @@
 
 
 def _main_async_collector(
-    pipe_parent: connection.Connection,
     pipe_child: connection.Connection,
     queue_out: queues.Queue,
     create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
@@ -68,7 +67,6 @@ def _main_async_collector(
 ) -> None:
     if collector_class is None:
         collector_class = Collector
-    pipe_parent.close()
     # init variables that will be cleared when closing
     collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None

torchrl/collectors/utils.py

Lines changed: 27 additions & 0 deletions

@@ -298,6 +298,33 @@ def _make_meta_policy(policy: nn.Module):
     return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy)
 
 
+@implement_for("torch", None, "2.8")
+def _make_meta_policy_cm(
+    policy: nn.Module, *, mp_start_method: str
+) -> contextlib.AbstractContextManager:
+    """Return the context manager used to make a policy 'stateless' for worker pickling.
+
+    On older PyTorch versions (<2.8), pickling meta-device storages when using the
+    ``spawn`` start method may fail (e.g., triggering ``_share_filename_: only available on CPU``).
+    In that case, we avoid converting parameters/buffers to meta and simply return a no-op
+    context manager.
+    """
+    if mp_start_method == "spawn":
+        return contextlib.nullcontext()
+    return _make_meta_policy(policy)
+
+
+@implement_for("torch", "2.8")
+def _make_meta_policy_cm(  # noqa: F811
+    policy: nn.Module, *, mp_start_method: str
+) -> contextlib.AbstractContextManager:
+    """Return the context manager used to make a policy 'stateless' for worker pickling.
+
+    On PyTorch >= 2.8, meta-device policy structures can be pickled reliably under ``spawn``.
+    """
+    return _make_meta_policy(policy)
+
+
 @implement_for("torch", None, "2.5.0")
 def _cast(  # noqa
     p: nn.Parameter | torch.Tensor,
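The _make_meta_policy helper these overloads wrap moves the policy's parameters onto the meta device while the worker payload is pickled, so only the module structure (not the weights) crosses the pipe. A simplified, hypothetical illustration of the idea in plain torch (TorchRL's version goes through tensordict and also handles buffers):

    import contextlib

    import torch
    from torch import nn

    @contextlib.contextmanager
    def params_on_meta(module: nn.Module):
        # Temporarily detach real storages so that pickling the module sends
        # structure only; restore the original weights afterwards.
        saved = {name: p.data for name, p in module.named_parameters()}
        try:
            for name, p in module.named_parameters():
                p.data = torch.empty_like(p.data, device="meta")
            yield module
        finally:
            for name, p in module.named_parameters():
                p.data = saved[name]

    policy = nn.Linear(4, 2)
    with params_on_meta(policy):
        assert policy.weight.is_meta  # no real storage to share or pickle here
    assert not policy.weight.is_meta  # weights restored for the parent process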

torchrl/envs/transforms/transforms.py

Lines changed: 16 additions & 0 deletions

@@ -6872,6 +6872,22 @@ def __init__(
                 category=FutureWarning,
             )
 
+        # Warn about shared memory limitations on older PyTorch
+        from packaging.version import parse as parse_version
+
+        if (
+            parse_version(parse_version(torch.__version__).base_version) < parse_version("2.8.0")
+            and shared_td is not None
+        ):
+            warnings.warn(
+                "VecNorm with shared memory (shared_td) may not synchronize correctly "
+                "across processes on PyTorch < 2.8 when using the 'spawn' multiprocessing "
+                "start method. This is due to limitations in PyTorch's shared memory "
+                "implementation with the 'file_system' sharing strategy. "
+                "Consider upgrading to PyTorch >= 2.8 for full shared memory support.",
+                category=UserWarning,
+            )
+
         if lock is None:
             lock = mp.Lock()
         if in_keys is None:
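One subtlety in the version gate above: Version.base_version is a plain string, and version strings compare lexicographically, so it must be re-parsed before the < comparison (otherwise PyTorch 2.10 would sort as "older" than 2.8):

    from packaging.version import parse as parse_version

    # Lexicographic string comparison misorders multi-digit components:
    assert ("2.10.0" < "2.8.0") is True                      # wrong as versions
    assert parse_version("2.10.0") > parse_version("2.8.0")  # correct ordering

    # base_version strips pre-release/dev tags but returns a str, so re-parse:
    v = parse_version(parse_version("2.8.0a0").base_version)
    assert v == parse_version("2.8.0")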
