|
4 | 4 | import abc |
5 | 5 |
|
6 | 6 | import contextlib |
| 7 | +import sys |
7 | 8 | import warnings |
8 | 9 | from collections import OrderedDict |
9 | 10 | from collections.abc import Callable, Mapping, Sequence |
|
20 | 21 | _check_for_faulty_process, |
21 | 22 | _get_mp_ctx, |
22 | 23 | _make_process_no_warn_cls, |
| 24 | + _mp_sharing_strategy_for_spawn, |
23 | 25 | _set_mp_start_method_if_unset, |
24 | 26 | RL_WARNINGS, |
25 | 27 | ) |
|
33 | 35 | ) |
34 | 36 | from torchrl.collectors._runner import _main_async_collector |
35 | 37 | from torchrl.collectors._single import Collector |
36 | | -from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool |
| 38 | +from torchrl.collectors.utils import _make_meta_policy_cm, _TrajectoryPool |
37 | 39 | from torchrl.collectors.weight_update import WeightUpdaterBase |
38 | 40 | from torchrl.data import ReplayBuffer |
39 | 41 | from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING |
@@ -945,7 +947,14 @@ def _run_processes(self) -> None: |
945 | 947 | ctx = _get_mp_ctx() |
946 | 948 | # Best-effort global init (only if unset) to keep other mp users consistent. |
947 | 949 | _set_mp_start_method_if_unset(ctx.get_start_method()) |
948 | | - |
| 950 | + if ( |
| 951 | + sys.platform == "linux" |
| 952 | + and sys.version_info < (3, 10) |
| 953 | + and ctx.get_start_method() == "spawn" |
| 954 | + ): |
| 955 | + strategy = _mp_sharing_strategy_for_spawn() |
| 956 | + if strategy is not None: |
| 957 | + mp.set_sharing_strategy(strategy) |
949 | 958 | queue_out = ctx.Queue(self._queue_len) # sends data from proc to main |
950 | 959 | self.procs = [] |
951 | 960 | self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True) |
@@ -1004,9 +1013,11 @@ def _run_processes(self) -> None: |
1004 | 1013 | policy_to_send = None |
1005 | 1014 | cm = contextlib.nullcontext() |
1006 | 1015 | elif policy is not None: |
1007 | | - # Send policy with meta-device parameters (empty structure) - schemes apply weights |
| 1016 | + # Send a stateless policy down to workers: schemes apply weights. |
1008 | 1017 | policy_to_send = policy |
1009 | | - cm = _make_meta_policy(policy) |
| 1018 | + cm = _make_meta_policy_cm( |
| 1019 | + policy, mp_start_method=ctx.get_start_method() |
| 1020 | + ) |
1010 | 1021 | else: |
1011 | 1022 | policy_to_send = None |
1012 | 1023 | cm = contextlib.nullcontext() |
@@ -1037,7 +1048,6 @@ def _run_processes(self) -> None: |
1037 | 1048 | with cm: |
1038 | 1049 | kwargs = { |
1039 | 1050 | "policy_factory": policy_factory[i], |
1040 | | - "pipe_parent": pipe_parent, |
1041 | 1051 | "pipe_child": pipe_child, |
1042 | 1052 | "queue_out": queue_out, |
1043 | 1053 | "create_env_fn": env_fun, |
@@ -1128,6 +1138,29 @@ def _run_processes(self) -> None: |
1128 | 1138 | ) from err |
1129 | 1139 | else: |
1130 | 1140 | raise err |
| 1141 | + except ValueError as err: |
| 1142 | + if "bad value(s) in fds_to_keep" in str(err): |
| 1143 | + # This error occurs on old Python versions (e.g., 3.9) with old PyTorch (e.g., 2.3) |
| 1144 | + # when using the spawn multiprocessing start method. The spawn implementation tries to |
| 1145 | + # preserve file descriptors across exec, but some descriptors may be invalid/closed. |
| 1146 | + # This is a compatibility issue with old Python multiprocessing implementations. |
| 1147 | + python_version = ( |
| 1148 | + f"{sys.version_info.major}.{sys.version_info.minor}" |
| 1149 | + ) |
| 1150 | + raise RuntimeError( |
| 1151 | + f"Failed to start collector worker process due to file descriptor issues " |
| 1152 | + f"with spawn multiprocessing on Python {python_version}.\n\n" |
| 1153 | + f"This is a known compatibility issue with old Python/PyTorch stacks. " |
| 1154 | + f"Consider upgrading to Python >= 3.10 and PyTorch >= 2.5, or use the 'fork' " |
| 1155 | + f"multiprocessing start method on Unix systems.\n\n" |
| 1156 | + f"Workarounds:\n" |
| 1157 | + f"- Upgrade Python to >= 3.10 and PyTorch to >= 2.5\n" |
| 1158 | + f"- On Unix systems, force fork start method:\n" |
| 1159 | + f" import torch.multiprocessing as mp\n" |
| 1160 | + f" if __name__ == '__main__':\n" |
| 1161 | + f" mp.set_start_method('fork', force=True)\n\n" |
| 1162 | + f"Upstream Python issue: https://github.com/python/cpython/issues/87706" |
| 1163 | + ) from err |
1131 | 1164 | except _pickle.PicklingError as err: |
1132 | 1165 | if "<lambda>" in str(err): |
1133 | 1166 | raise RuntimeError( |
|
0 commit comments