From 4cb3dc3f42270df3d00e2a1d7695376a36314050 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Nov 2025 18:16:19 +0000 Subject: [PATCH 01/17] huge refactor --- test/services/test_python_executor_service.py | 16 +- test/test_collector.py | 125 +- torchrl/collectors/__init__.py | 13 +- torchrl/collectors/_constants.py | 84 + torchrl/collectors/_multi_async.py | 295 + torchrl/collectors/_multi_base.py | 1478 +++++ torchrl/collectors/_multi_sync.py | 430 ++ torchrl/collectors/_runner.py | 504 ++ torchrl/collectors/_single.py | 1779 ++++++ torchrl/collectors/_single_async.py | 248 + torchrl/collectors/base.py | 469 ++ torchrl/collectors/collectors.py | 5005 +---------------- torchrl/collectors/distributed/generic.py | 12 +- torchrl/collectors/distributed/ray.py | 12 +- torchrl/collectors/distributed/rpc.py | 12 +- torchrl/collectors/distributed/sync.py | 12 +- torchrl/collectors/llm/base.py | 2 +- torchrl/collectors/llm/weight_update/vllm.py | 2 +- .../collectors/llm/weight_update/vllm_v2.py | 2 +- torchrl/collectors/utils.py | 124 +- torchrl/envs/batched_envs.py | 8 +- torchrl/envs/llm/transforms/tools.py | 20 +- torchrl/weight_update/weight_sync_schemes.py | 262 +- 23 files changed, 5794 insertions(+), 5120 deletions(-) create mode 100644 torchrl/collectors/_constants.py create mode 100644 torchrl/collectors/_multi_async.py create mode 100644 torchrl/collectors/_multi_base.py create mode 100644 torchrl/collectors/_multi_sync.py create mode 100644 torchrl/collectors/_runner.py create mode 100644 torchrl/collectors/_single.py create mode 100644 torchrl/collectors/_single_async.py create mode 100644 torchrl/collectors/base.py diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index cb55c0a6a10..b18181c573f 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -73,7 +73,7 @@ def test_service_execution(self, ray_init): result = x + y print(f"Result: {result}") """ - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is True assert "Result: 30" in result["stdout"] @@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init): # Execute code with an error code = "raise ValueError('Test error')" - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is False assert "ValueError: Test error" in result["stderr"] @@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init): "python_executor", PythonExecutorService, pool_size=4, - timeout=5.0, + timeout=10.0, num_cpus=4, max_concurrency=4, ) @@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init): code = f"print('Execution {i}')" futures.append(executor.execute.remote(code)) - # Wait for all to complete - results = ray.get(futures, timeout=5) + # Wait for all to complete with longer timeout + results = ray.get(futures, timeout=30) # All should succeed assert len(results) == 8 for i, result in enumerate(results): - assert result["success"] is True - assert f"Execution {i}" in result["stdout"] + assert result["success"] is True, f"Execution {i} failed: {result}" + assert ( + f"Execution {i}" in result["stdout"] + ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}" finally: services.reset() diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..bc99b51c08e 100644 --- a/test/test_collector.py +++ 
b/test/test_collector.py @@ -13,11 +13,14 @@ import subprocess import sys import time +from contextlib import nullcontext from unittest.mock import patch import numpy as np import pytest import torch + +import torchrl.collectors._runner from packaging import version from tensordict import ( assert_allclose_td, @@ -33,7 +36,6 @@ TensorDictSequential, ) from torch import nn - from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -48,7 +50,7 @@ SyncDataCollector, WeightUpdaterBase, ) -from torchrl.collectors.collectors import _Interruptor +from torchrl.collectors._constants import _Interruptor from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -1487,12 +1489,14 @@ def env_fn(seed): assert_allclose_td(data10, data20) @pytest.mark.parametrize("use_async", [False, True]) - @pytest.mark.parametrize("cudagraph", [False, True]) + @pytest.mark.parametrize( + "cudagraph", [False, True] if torch.cuda.is_available() else [False] + ) @pytest.mark.parametrize( "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found") def test_update_weights(self, use_async, cudagraph, weight_sync_scheme): def create_env(): return ContinuousActionVecMockEnv() @@ -1509,11 +1513,12 @@ def create_env(): kwargs = {} if weight_sync_scheme is not None: kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()} + device = "cuda:0" if torch.cuda.is_available() else "cpu" collector = collector_class( [create_env] * 3, policy=policy, - device=[torch.device("cuda:0")] * 3, - storing_device=[torch.device("cuda:0")] * 3, + device=[torch.device(device)] * 3, + storing_device=[torch.device(device)] * 3, frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, @@ -1544,7 +1549,9 @@ def create_env(): # check they don't match for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError + ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( state_dict[f"worker{worker}"]["policy_state_dict"][k], policy_state_dict[k].cpu(), @@ -2401,7 +2408,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, - match=("Arguments to policy.forward are incompatible with entries in"), + match=( + "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." + ), ): collector_class( **self._create_collector_kwargs( @@ -2980,6 +2989,94 @@ def test_param_sync_mixed_device( col.shutdown() del col + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="requires at least 3 CUDA devices", + ) + @pytest.mark.parametrize( + "weight_sync_scheme", + [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme], + ) + def test_shared_device_weight_update(self, weight_sync_scheme): + """Test that weight updates work correctly when multiple workers share the same device. + + This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. 
+ When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through + dedicated queues, preventing race conditions that could occur with a single shared queue. + """ + # Create policy on cuda:0 + policy = TensorDictModule( + nn.Linear(7, 7, device="cuda:0"), + in_keys=["observation"], + out_keys=["action"], + ) + + def make_env(): + return ContinuousActionVecMockEnv() + + # Create collector with workers on cuda:2, cuda:1, cuda:2 + # Workers 0 and 2 share cuda:2 - this is the key test case + collector = MultiaSyncDataCollector( + [make_env, make_env, make_env], + policy=policy, + frames_per_batch=30, + total_frames=300, + device=["cuda:2", "cuda:1", "cuda:2"], + storing_device=["cuda:2", "cuda:1", "cuda:2"], + weight_sync_schemes={"policy": weight_sync_scheme()}, + ) + + try: + # Collect first batch to initialize workers + for _ in collector: + break + + # Get initial weights + old_weight = policy.module.weight.data.clone() + + # Modify policy weights on cuda:0 + for p in policy.parameters(): + p.data += torch.randn_like(p) + + new_weight = policy.module.weight.data.clone() + assert not torch.allclose( + old_weight, new_weight + ), "Weights should have changed" + + # Update weights - this should propagate to all workers via their dedicated queues + collector.update_policy_weights_() + + # Collect more batches to ensure weights are propagated + for i, _ in enumerate(collector): + if i >= 2: + break + + # Get state dict from all workers + state_dict = collector.state_dict() + + # Verify all workers have the new weights, including both workers on cuda:2 + for worker_idx in range(3): + worker_key = f"worker{worker_idx}" + assert ( + "policy_state_dict" in state_dict[worker_key] + ), f"Worker {worker_idx} should have policy_state_dict" + worker_weight = state_dict[worker_key]["policy_state_dict"][ + "module.weight" + ] + torch.testing.assert_close( + worker_weight.cpu(), + new_weight.cpu(), + msg=( + f"Worker {worker_idx} weights don't match expected weights. " + f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. " + f"This test validates that the per-worker queue system correctly " + f"distributes weights even when multiple workers share a device." 
+ ), + ) + finally: + collector.shutdown() + del collector + class TestAggregateReset: def test_aggregate_reset_to_root(self): @@ -3176,11 +3273,11 @@ class TestLibThreading: reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): - from torchrl.collectors import collectors + pass - _main_async_collector_saved = collectors._main_async_collector - collectors._main_async_collector = decorate_thread_sub_func( - collectors._main_async_collector, num_threads=3 + _main_async_collector_saved = torchrl.collectors._runner._main_async_collector + torchrl.collectors._runner._main_async_collector = decorate_thread_sub_func( + torchrl.collectors._runner._main_async_collector, num_threads=3 ) num_threads = torch.get_num_threads() try: @@ -3204,7 +3301,9 @@ def test_num_threads(self): except Exception: torchrl_logger.info("Failed to shut down collector") # reset vals - collectors._main_async_collector = _main_async_collector_saved + torchrl.collectors._runner._main_async_collector = ( + _main_async_collector_saved + ) torch.set_num_threads(num_threads) @pytest.mark.skipif( diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 7f1c812943d..5e2ef63fb69 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -5,13 +5,12 @@ from torchrl.envs.utils import RandomPolicy -from .collectors import ( - aSyncDataCollector, - DataCollectorBase, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from ._multi_async import MultiaSyncDataCollector +from ._multi_sync import MultiSyncDataCollector +from ._single import SyncDataCollector + +from ._single_async import aSyncDataCollector +from .base import DataCollectorBase from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, diff --git a/torchrl/collectors/_constants.py b/torchrl/collectors/_constants.py new file mode 100644 index 00000000000..1587d800166 --- /dev/null +++ b/torchrl/collectors/_constants.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Constants and helper classes for collectors.""" +from __future__ import annotations + +import os +import sys +from multiprocessing.managers import SyncManager + +import torch +from torch import multiprocessing as mp + +from torchrl.envs.utils import ExplorationType + +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + +__all__ = [ + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] + +_TIMEOUT = 1.0 +INSTANTIATE_TIMEOUT = 20 +_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory +# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. +_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) + +DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM + +_is_osx = sys.platform.startswith("darwin") + + +class _Interruptor: + """A class for managing the collection state of a process. 
+ + This class provides methods to start and stop collection, and to check + whether collection has been stopped. The collection state is protected + by a lock to ensure thread-safety. + """ + + # interrupter vs interruptor: google trends seems to indicate that "or" is more + # widely used than "er" even if my IDE complains about that... + def __init__(self): + self._collect = True + self._lock = mp.Lock() + + def start_collection(self): + with self._lock: + self._collect = True + + def stop_collection(self): + with self._lock: + self._collect = False + + def collection_stopped(self): + with self._lock: + return self._collect is False + + +class _InterruptorManager(SyncManager): + """A custom SyncManager for managing the collection state of a process. + + This class extends the SyncManager class and allows to share an Interruptor object + between processes. + """ + + +_InterruptorManager.register("_Interruptor", _Interruptor) diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py new file mode 100644 index 00000000000..6e9b3a55f7b --- /dev/null +++ b/torchrl/collectors/_multi_async.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import time +import warnings +from collections import defaultdict, OrderedDict +from collections.abc import Iterator, Sequence +from copy import deepcopy +from queue import Empty + +import torch + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl._utils import _check_for_faulty_process, accept_remote_rref_udf_invocation +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiaSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes asynchronously. + + .. aafig:: + + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor | "step" | "step" | "actor" | | + | | | | | | + | "yield batch 1" | "actor" | |"collect, train"| + | | | | | + | "step" | "step" | | "yield batch 2" |"collect, train"| + | | | | | | + | | | "yield batch 3" | |"collect, train"| + | | | | | | + +----------------------------------------------------------------------+ + + Environment types can be identical or different. + + The collection keeps on occurring on all processes even between the time + the batch of rollouts is collected and the next call to the iterator. + This class can be safely used with offline RL sota-implementations. + + .. note:: Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.out_tensordicts = defaultdict(lambda: None) + self.running = False + + if self.postprocs is not None and self.replay_buffer is None: + postproc = self.postprocs + self.postprocs = {} + for _device in self.storing_device: + if _device not in self.postprocs: + if hasattr(postproc, "to"): + postproc = deepcopy(postproc).to(_device) + self.postprocs[_device] = postproc + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
+ ) + return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + return self.requested_frames_per_batch + + def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: + new_data, j = self.queue_out.get(timeout=timeout) + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.out_tensordicts[idx] = data + if use_buffers is None and j > 0: + use_buffers = self._use_buffers = False + except TypeError: + if use_buffers is None: + use_buffers = self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + out = self.out_tensordicts[idx] + if not self.replay_buffer and (j == 0 or use_buffers): + # we clone the data to make sure that we'll be working with a fixed copy + out = out.clone() + return idx, j, out + + @property + def _queue_len(self) -> int: + return 1 + + def iterator(self) -> Iterator[TensorDictBase]: + if self.update_at_each_batch: + self.update_policy_weights_() + + for i in range(self.num_workers): + if self.init_random_frames is not None and self.init_random_frames > 0: + self.pipes[i].send((None, "continue_random")) + else: + self.pipes[i].send((None, "continue")) + self.running = True + + workers_frames = [0 for _ in range(self.num_workers)] + while self._frames < self.total_frames: + self._iter += 1 + counter = 0 + while True: + try: + idx, j, out = self._get_from_queue(timeout=_TIMEOUT) + break + except (TimeoutError, Empty): + counter += _TIMEOUT + _check_for_faulty_process(self.procs) + if counter > (_TIMEOUT * _MAX_IDLE_COUNT): + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker() + self._frames += worker_frames + workers_frames[idx] = workers_frames[idx] + worker_frames + if out is not None and self.postprocs: + out = self.postprocs[out.device](out) + + # the function blocks here until the next item is asked, hence we send the message to the + # worker to keep on working in the meantime before the yield statement + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((idx, msg)) + if out is not None and self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + out = out.exclude(*excluded_keys) + yield out + + # We don't want to shutdown yet, the user may want to call state_dict before + # self._shutdown_main() + self.running = False + + def _shutdown_main(self, *args, **kwargs) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + return super()._shutdown_main(*args, **kwargs) + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + super().reset(reset_idx) + if self.queue_out.full(): + time.sleep(_TIMEOUT) # wait until queue is empty + if self.queue_out.full(): + raise Exception("self.queue_out is full") + if self.running: + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.pipes[idx].send((idx, "continue_random")) + else: + self.pipes[idx].send((idx, "continue")) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py new file mode 100644 index 00000000000..f9d7ea7a8bd --- /dev/null +++ b/torchrl/collectors/_multi_base.py @@ -0,0 +1,1478 @@ +from __future__ import annotations + +import _pickle + +import contextlib +import warnings +from collections import OrderedDict +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule +from tensordict.utils import _zip_strict +from torch import multiprocessing as mp, nn +from torchrl import logger as torchrl_logger +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._constants import ( + _InterruptorManager, + _is_osx, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, + INSTANTIATE_TIMEOUT, +) +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + WeightSyncScheme, +) + + +class _MultiDataCollector(DataCollectorBase): + """Runs a given number of DataCollectors on separate processes. + + Args: + create_env_fn (List[Callabled]): list of Callables, each returning an + instance of :class:`~torchrl.envs.EnvBase`. 
+ policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided (default), the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: + ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + Keyword Args: + policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable + (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + .. warning:: `policy_factory` is currently not compatible with multiprocessed data + collectors. + + num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. + Defaults to `None` (workers determined by the `create_env_fn` length). + frames_per_batch (int, Sequence[int]): A keyword-only argument representing the + total number of elements in a batch. If a sequence is provided, represents the number of elements in a + batch per worker. Total number of elements in a batch is then the sum over the sequence. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. 
+ Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be + :class:`~torchrl.collectors.SyncDataCollector`, + :class:`~torchrl.collectors.MultiSyncDataCollector`, + :class:`~torchrl.collectors.MultiaSyncDataCollector` + or a derived class of these. + Defaults to :class:`~torchrl.collectors.SyncDataCollector`. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + Defaults to ``None``. 
+ split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, + ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` + or ``torchrl.envs.utils.ExplorationType.MEAN``. + reset_when_done (bool, optional): if ``True`` (default), an environment + that returns a ``True`` value in its ``"done"`` or ``"truncated"`` + entry will be reset at the corresponding indices. + update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()` + will be called before (sync) or after (async) each data collection. + Defaults to ``False``. + preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finish collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may overload the CPU and harm performance. + cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively). + If ``"stack"``, the data collected from the workers will be stacked along the + first dimension. This is the preferred behavior as it is the most compatible + with the rest of the library. + If ``0``, results will be concatenated along the first dimension + of the outputs, which can be the batched dimension if the environments are + batched or the time dimension if not. + A ``cat_results`` value of ``-1`` will always concatenate results along the + time dimension. This should be preferred over the default. Intermediate values + are also accepted. + Defaults to ``"stack"``. + + .. note:: From v0.5, this argument will default to ``"stack"`` for better + interoperability with the rest of the library. + + set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding + ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of + a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. + Truncated keys can be set through ``env.add_truncated_keys``. + Defaults to ``False``. + use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. + This isn't compatible with environments with dynamic specs. Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. Defaults to ``None``. + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True` for multiprocessed data collectors. + local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior).
If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, + which handles weight synchronization across multiple processes. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. 
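A minimal usage sketch of how these keyword arguments fit together (illustrative only, not part of the patch; it mirrors the example in :class:`~torchrl.collectors.MultiaSyncDataCollector` and assumes a Gym backend is installed):

    >>> from tensordict.nn import TensorDictModule
    >>> from torch import nn
    >>> from torchrl.collectors import MultiSyncDataCollector
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.weight_update import SharedMemWeightSyncScheme
    >>> if __name__ == "__main__":
    ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    ...     policy = TensorDictModule(
    ...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    ...     )
    ...     collector = MultiSyncDataCollector(
    ...         create_env_fn=[env_maker, env_maker],
    ...         policy=policy,
    ...         frames_per_batch=200,  # total across both workers
    ...         total_frames=2000,
    ...         device="cpu",
    ...         cat_results="stack",
    ...         weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
    ...     )
    ...     for data in collector:
    ...         collector.update_policy_weights_()  # push fresh weights to the workers
    ...     collector.shutdown()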
+ + """ + + def __init__( + self, + create_env_fn: Sequence[Callable[[], EnvBase]], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + num_workers: int | None = None, + policy_factory: Callable[[], Callable] + | list[Callable[[], Callable]] + | None = None, + frames_per_batch: int | Sequence[int], + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict] | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + cat_results: str | int | None = None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + ): + self.closed = True + + # Set up workers and environment functions + create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( + create_env_fn, num_workers, frames_per_batch + ) + + # Set up basic configuration + self.set_truncated = set_truncated + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads + self.create_env_fn = create_env_fn + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Set up environment kwargs + self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) + + # Set up devices + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices + self.collector_class = collector_class + del storing_device, env_device, policy_device, device + self.no_cuda_sync = no_cuda_sync + + # Set up replay buffer + self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self._setup_multi_replay_buffer( + local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer + ) + + # Set up policy and weights + if trust_policy is None: + trust_policy = policy is not None and isinstance(policy, CudaGraphModule) + self.trust_policy = trust_policy + + policy_factory = self._setup_policy_factory(policy_factory) + + # Set up weight synchronization + if ( + not any(policy_factory) + and not weight_sync_schemes + and weight_updater is None + ): + weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} + + self._setup_multi_policy_and_weights( + policy, 
policy_factory, weight_updater, weight_sync_schemes + ) + + self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + + # Set up policy version tracking + self._setup_multi_policy_version_tracking(track_policy_version) + + # Store policy and policy_factory + self.policy = policy + self.policy_factory = policy_factory + + # Set up fallback policy for weight extraction + self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + + # Set up total frames and other parameters + self._setup_multi_total_frames( + total_frames, total_frames_per_batch, frames_per_batch + ) + self.reset_at_each_iter = reset_at_each_iter + self.postprocs = postproc + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + + # Set up split trajectories + self.requested_frames_per_batch = total_frames_per_batch + self.reset_when_done = reset_when_done + self._setup_split_trajs(split_trajs, reset_when_done) + + # Set up other parameters + self.init_random_frames = ( + int(init_random_frames) if init_random_frames is not None else 0 + ) + self.update_at_each_batch = update_at_each_batch + self.exploration_type = exploration_type + self.frames_per_worker = np.inf + + # Set up preemptive threshold + self._setup_preemptive_threshold(preemptive_threshold) + + # Run worker processes + try: + self._run_processes() + except Exception as e: + self.shutdown(raise_on_error=False) + raise e + + # Set up frame tracking and other options + self._exclude_private_keys = True + self._frames = 0 + self._iter = -1 + + # Validate cat_results + self._validate_cat_results(cat_results) + + def _setup_workers_and_env_fns( + self, + create_env_fn: Sequence[Callable] | Callable, + num_workers: int | None, + frames_per_batch: int | Sequence[int], + ) -> tuple[list[Callable], int]: + """Set up workers and environment functions.""" + if isinstance(create_env_fn, Sequence): + self.num_workers = len(create_env_fn) + else: + self.num_workers = num_workers + create_env_fn = [create_env_fn] * self.num_workers + + if ( + isinstance(frames_per_batch, Sequence) + and len(frames_per_batch) != self.num_workers + ): + raise ValueError( + "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." + f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
+ ) + + self._frames_per_batch = frames_per_batch + total_frames_per_batch = ( + sum(frames_per_batch) + if isinstance(frames_per_batch, Sequence) + else frames_per_batch + ) + + return create_env_fn, total_frames_per_batch + + def _setup_env_kwargs( + self, create_env_kwargs: Sequence[dict] | dict | None + ) -> list[dict]: + """Set up environment kwargs for each worker.""" + if isinstance(create_env_kwargs, Mapping): + create_env_kwargs = [create_env_kwargs] * self.num_workers + elif create_env_kwargs is None: + create_env_kwargs = [{}] * self.num_workers + elif isinstance(create_env_kwargs, (tuple, list)): + create_env_kwargs = list(create_env_kwargs) + if len(create_env_kwargs) != self.num_workers: + raise ValueError( + f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" + ) + return create_env_kwargs + + def _setup_multi_replay_buffer( + self, + local_init_rb: bool | None, + replay_buffer: ReplayBuffer | None, + replay_buffer_chunk: bool | None, + extend_buffer: bool, + ) -> None: + """Set up replay buffer for multi-process collector.""" + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + self._check_replay_buffer_init() + + if replay_buffer_chunk is not None: + if extend_buffer is None: + replay_buffer_chunk = extend_buffer + warnings.warn( + "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", + DeprecationWarning, + ) + elif extend_buffer != replay_buffer_chunk: + raise ValueError( + "conflicting values for replay_buffer_chunk and extend_buffer." + ) + self.extend_buffer = extend_buffer + + if ( + replay_buffer is not None + and hasattr(replay_buffer, "shared") + and not replay_buffer.shared + ): + torchrl_logger.warning("Replay buffer is not shared. Sharing it.") + replay_buffer.share() + + def _setup_policy_factory( + self, policy_factory: Callable | list[Callable] | None + ) -> list[Callable | None]: + """Set up policy factory for each worker.""" + if not isinstance(policy_factory, Sequence): + policy_factory = [policy_factory] * self.num_workers + return policy_factory + + def _setup_multi_policy_and_weights( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy for multi-process collector. + + With weight sync schemes: validates and stores policy without weight extraction. + With weight updater: extracts weights and creates stateful policies. 
+ """ + if any(policy_factory) and policy is not None: + raise TypeError("policy_factory and policy are mutually exclusive") + + if weight_sync_schemes is not None: + # Weight sync schemes handle all weight distribution + # Extract weights so schemes can access them, but don't do in-place replacement + self._policy_weights_dict = {} + self._fallback_policy = None + + if not any(policy_factory) and policy is not None: + # Extract weights for the first device so schemes can access them + # Use first device as representative + first_device = self.policy_device[0] if self.policy_device else None + + # Validate device types for SharedMemWeightSyncScheme + for scheme in weight_sync_schemes.values(): + if isinstance(scheme, SharedMemWeightSyncScheme): + for policy_device in self.policy_device: + if policy_device and policy_device.type not in ( + "cpu", + "cuda", + ): + raise NotImplementedError( + f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. " + f"Only 'cpu' and 'cuda' are supported." + ) + + # Extract weights from policy + # Use .data to avoid gradient tracking (can't serialize tensors with requires_grad) + weights = ( + TensorDict.from_module(policy, as_module=True).data + if isinstance(policy, nn.Module) + else TensorDict() + ) + + # For SharedMemWeightSyncScheme, share the weights + if any( + isinstance(scheme, SharedMemWeightSyncScheme) + for scheme in weight_sync_schemes.values() + ): + if first_device and first_device.type == "cpu": + weights = weights.share_memory_() + elif first_device and first_device.type == "cuda": + # CUDA tensors maintain shared references through mp.Queue + weights = weights.to(first_device).share_memory_() + + self._policy_weights_dict[first_device] = weights + self._fallback_policy = policy + + self._get_weights_fn = None + else: + # Using legacy weight updater - extract weights and create stateful policies + self._setup_multi_policy_and_weights_legacy( + policy, policy_factory, weight_updater, weight_sync_schemes + ) + + def _setup_multi_policy_and_weights_legacy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy and extract weights for each device. + + Creates stateful policies with weights extracted and placed in shared memory. + Used with weight updater for in-place weight replacement. 
+ """ + self._policy_weights_dict = {} + self._fallback_policy = None # Policy to use for weight extraction fallback + + if not any(policy_factory): + for policy_device, env_maker, env_maker_kwargs in _zip_strict( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + policy_new_device, get_weights_fn = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if type(policy_new_device) is not type(policy): + policy = policy_new_device + weights = ( + TensorDict.from_module(policy_new_device) + if isinstance(policy_new_device, nn.Module) + else TensorDict() + ) + # For multi-process collectors, ensure weights are in shared memory + if policy_device and policy_device.type == "cpu": + weights = weights.share_memory_() + self._policy_weights_dict[policy_device] = weights + # Store the first policy instance for fallback weight extraction + if self._fallback_policy is None: + self._fallback_policy = policy_new_device + self._get_weights_fn = get_weights_fn + if weight_updater is None: + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + elif weight_updater is None: + warnings.warn( + "weight_updater is None, but policy_factory is provided. This means that the server will " + "not know how to send the weights to the workers. If the workers can handle their weight synchronization " + "on their own (via some specialized worker type / constructor) this may well work, but make sure " + "your weight synchronization strategy is properly set. To suppress this warning, you can use " + "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " + "This will work whenever your inference and training policies are nn.Module instances with similar structures." + ) + + def _setup_multi_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization for multi-process collector.""" + if weight_sync_schemes is not None: + # Use weight sync schemes for weight distribution + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # Senders will be created in _run_processes + self.weight_updater = None + else: + # Use weight updater for weight distribution + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + + def _setup_multi_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking for multi-process collector.""" + self.policy_version_tracker = track_policy_version + if PolicyVersion is not None: + if isinstance(track_policy_version, bool) and track_policy_version: + self.policy_version_tracker = PolicyVersion() + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + else: + self.policy_version_tracker = None + else: + if track_policy_version: + raise ImportError( + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." 
+ ) + self.policy_version_tracker = None + + def _setup_fallback_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up fallback policy for weight extraction when using policy_factory.""" + # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided + # If policy_factory was used, create a policy instance to use as fallback + if policy is None and any(policy_factory) and weight_sync_schemes is not None: + if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: + first_factory = ( + policy_factory[0] + if isinstance(policy_factory, list) + else policy_factory + ) + if first_factory is not None: + # Create a policy instance for weight extraction + # This will be a reference to a policy with the same structure + # For shared memory, modifications to any policy will be visible here + self._fallback_policy = first_factory() + + def _setup_multi_total_frames( + self, + total_frames: int, + total_frames_per_batch: int, + frames_per_batch: int | Sequence[int], + ) -> None: + """Validate and set total frames for multi-process collector.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % total_frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " + f"This means {total_frames_per_batch - remainder} additional frames will be collected. " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_split_trajs( + self, split_trajs: bool | None, reset_when_done: bool + ) -> None: + """Set up split trajectories option.""" + if split_trajs is None: + split_trajs = False + elif not reset_when_done and split_trajs: + raise RuntimeError( + "Cannot split trajectories when reset_when_done is False." + ) + self.split_trajs = split_trajs + + def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: + """Set up preemptive threshold for early stopping.""" + if preemptive_threshold is not None: + if _is_osx: + raise NotImplementedError( + "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." + ) + self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) + manager = _InterruptorManager() + manager.start() + self.interruptor = manager._Interruptor() + else: + self.preemptive_threshold = 1.0 + self.interruptor = None + + def _validate_cat_results(self, cat_results: str | int | None) -> None: + """Validate cat_results parameter.""" + if cat_results is not None and ( + not isinstance(cat_results, (int, str)) + or (isinstance(cat_results, str) and cat_results != "stack") + ): + raise ValueError( + "cat_results must be a string ('stack') " + f"or an integer representing the cat dimension. Got {cat_results}." + ) + # Lazy import to avoid circular dependency + from torchrl.collectors._multi_sync import MultiSyncDataCollector + + if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( + "stack", + None, + ): + raise ValueError( + "cat_results can only be used with ``MultiSyncDataCollector``." 
+ ) + self.cat_results = cat_results + + def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) + if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].meta_data.tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) + # Use extend to avoid time-related transforms failing + self.replay_buffer.extend(fake_td.unsqueeze(-1)) + self.replay_buffer.empty() + + @classmethod + def _total_workers_from_env(cls, env_creators): + if isinstance(env_creators, (tuple, list)): + return sum( + cls._total_workers_from_env(env_creator) for env_creator in env_creators + ) + from torchrl.envs import ParallelEnv + + if isinstance(env_creators, ParallelEnv): + return env_creators.num_workers + return 1 + + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"The length of the devices does not match the number of workers: {self.num_workers}."
+ ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + + def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + raise NotImplementedError + + @property + def _queue_len(self) -> int: + raise NotImplementedError + + def _run_processes(self) -> None: + if self.num_threads is None: + total_workers = self._total_workers_from_env(self.create_env_fn) + self.num_threads = max( + 1, torch.get_num_threads() - total_workers + ) # 1 more thread for this proc + + # Set up for worker processes + torch.set_num_threads(self.num_threads) + queue_out = mp.Queue(self._queue_len) # sends data from proc to main + self.procs = [] + self.pipes = [] + self._traj_pool = _TrajectoryPool(lock=True) + + # Initialize weight sync schemes early for SharedMemWeightSyncScheme + # (queue created in __init__ will be pickled with scheme to workers) + # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + # Only initialize SharedMemWeightSyncScheme now (needs queue before workers) + # MultiProcessWeightSyncScheme will be initialized after workers are created + if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( + scheme, "init_on_sender" + ): + scheme.init_on_sender(model_id=model_id, context=self) + self._weight_senders[model_id] = scheme.get_sender() + + # Create a policy on the right device + policy_factory = self.policy_factory + if any(policy_factory): + policy_factory = [ + CloudpickleWrapper(_policy_factory) + for _policy_factory in policy_factory + ] + + for i, (env_fun, env_fun_kwargs) in enumerate( + zip(self.create_env_fn, self.create_env_kwargs) + ): + pipe_parent, pipe_child = mp.Pipe() # send messages to procs + if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( + env_fun, EnvBase + ): # to avoid circular imports + env_fun = CloudpickleWrapper(env_fun) + + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + + # Prepare policy for worker based on weight synchronization method + policy = self.policy + + if self._weight_sync_schemes: + # With weight sync schemes, send stateless policies + # Schemes handle weight distribution on worker side + if any(policy_factory): + policy_to_send = None # Factory will create policy in worker + elif policy is not None: + # Send meta-device policy (empty structure) - schemes apply weights + policy_to_send = _make_meta_policy(policy) + else: + policy_to_send = None + cm = contextlib.nullcontext() + else: + # With weight updater, use in-place weight replacement + # Take the weights and locally dispatch them to the policy before sending. + # This ensures a given set of shared weights for a device are shared + # for all policies that rely on that device. 
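+                # Illustrative sketch of the dispatch performed below (helper names are
+                # hypothetical, not part of the collector API):
+                #     shared = self._policy_weights_dict[policy_device]  # device-local shared weights
+                #     with shared.to_module(policy):                     # temporarily bind them to the policy
+                #         spawn_worker(policy)                           # pickling sees the shared weights
+                # i.e. the policy is serialized while the device-local shared weights are
+                # swapped in, so every worker running on that device aliases the same storage.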
+                policy_weights = self._policy_weights_dict.get(policy_device)
+                policy_to_send = policy
+                if policy is not None and policy_weights is not None:
+                    cm = policy_weights.to_module(policy)
+                else:
+                    cm = contextlib.nullcontext()
+
+            with cm:
+                kwargs = {
+                    "policy_factory": policy_factory[i],
+                    "pipe_parent": pipe_parent,
+                    "pipe_child": pipe_child,
+                    "queue_out": queue_out,
+                    "create_env_fn": env_fun,
+                    "create_env_kwargs": env_fun_kwargs,
+                    "policy": policy_to_send,
+                    "max_frames_per_traj": self.max_frames_per_traj,
+                    "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
+                    "reset_at_each_iter": self.reset_at_each_iter,
+                    "policy_device": policy_device,
+                    "storing_device": storing_device,
+                    "env_device": env_device,
+                    "exploration_type": self.exploration_type,
+                    "reset_when_done": self.reset_when_done,
+                    "idx": i,
+                    "interruptor": self.interruptor,
+                    "set_truncated": self.set_truncated,
+                    "use_buffers": self._use_buffers,
+                    "replay_buffer": self.replay_buffer,
+                    "extend_buffer": self.extend_buffer,
+                    "traj_pool": self._traj_pool,
+                    "trust_policy": self.trust_policy,
+                    "compile_policy": self.compiled_policy_kwargs
+                    if self.compiled_policy
+                    else False,
+                    "cudagraph_policy": self.cudagraphed_policy_kwargs
+                    if self.cudagraphed_policy
+                    else False,
+                    "no_cuda_sync": self.no_cuda_sync,
+                    "collector_class": self.collector_class,
+                    "postproc": self.postprocs
+                    if self.replay_buffer is not None
+                    else None,
+                    "weight_sync_schemes": self._weight_sync_schemes,
+                    "worker_idx": i,  # Worker index for queue-based weight distribution
+                }
+                proc = _ProcessNoWarn(
+                    target=_main_async_collector,
+                    num_threads=self.num_sub_threads,
+                    kwargs=kwargs,
+                )
+                # proc.daemon can't be set as daemonic processes may be launched by the process itself
+                try:
+                    proc.start()
+                except TypeError as err:
+                    if "cannot pickle" in str(err):
+                        raise RuntimeError(
+                            "A non-serializable object was passed to the collector workers."
+                        ) from err
+                except RuntimeError as err:
+                    if "Cowardly refusing to serialize non-leaf tensor" in str(err):
+                        raise RuntimeError(
+                            "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. "
+                            "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n"
+                            "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n"
+                            "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead."
+                        ) from err
+                    else:
+                        raise err
+                except _pickle.PicklingError as err:
+                    if "<lambda>" in str(err):
+                        raise RuntimeError(
+                            """Can't open a process with doubly cloud-pickled lambda function.
+This error is likely due to an attempt to use a ParallelEnv in a
+multiprocessed data collector. To do this, consider wrapping your
+lambda function in a `torchrl.envs.EnvCreator` wrapper as follows:
+`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
+This will not only ensure that your lambda function is cloud-pickled once, but +also that the state dict is synchronised across processes if needed.""" + ) from err + pipe_child.close() + self.procs.append(proc) + self.pipes.append(pipe_parent) + + # Wait for workers to be ready + for i, pipe_parent in enumerate(self.pipes): + pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) + try: + msg = pipe_parent.recv() + except EOFError as e: + raise RuntimeError( + f"Worker {i} failed to initialize and closed the connection before sending status. " + f"This typically indicates that the worker process crashed during initialization. " + f"Check the worker process logs for the actual error." + ) from e + if msg != "instantiated": + # Check if it's an error dict from worker + if isinstance(msg, dict) and msg.get("error"): + # Reconstruct the exception from the worker + exc_type_name = msg["exception_type"] + exc_msg = msg["exception_msg"] + traceback_str = msg["traceback"] + + # Try to get the actual exception class + exc_class = None + exc_module = msg["exception_module"] + + if exc_module == "builtins": + # Get from builtins + import builtins + + exc_class = getattr(builtins, exc_type_name, None) + else: + # Try to import from the module + try: + import importlib + + mod = importlib.import_module(exc_module) + exc_class = getattr(mod, exc_type_name, None) + except Exception: + pass + + # Re-raise with original exception type if possible + if exc_class is not None: + raise exc_class( + f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Fall back to RuntimeError if we can't get the original type + raise RuntimeError( + f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Legacy string error message + raise RuntimeError(msg) + + # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available + # (SharedMemWeightSyncScheme was already initialized before workers) + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + # Only initialize non-SharedMem schemes here (need pipes) + if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( + scheme, "init_on_sender" + ): + scheme.init_on_sender(model_id=model_id, context=self) + # Get the initialized sender + self._weight_senders[model_id] = scheme.get_sender() + + self.queue_out = queue_out + self.closed = False + + _running_free = False + + def start(self): + """Starts the collector(s) for asynchronous data collection. + + The collected data is stored in the provided replay buffer. This method initiates the background collection of + data across multiple processes, allowing for decoupling of data collection and training. + + Raises: + RuntimeError: If no replay buffer is defined during the collector's initialization. + + Example: + >>> import time + >>> from functools import partial + >>> + >>> import tqdm + >>> + >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy + >>> from torchrl.data import LazyTensorStorage, ReplayBuffer + >>> from torchrl.envs import GymEnv, set_gym_backend + >>> import ale_py + >>> + >>> # Set the gym backend to gymnasium + >>> set_gym_backend("gymnasium").set() + >>> + >>> if __name__ == "__main__": + ... # Create a random policy for the Pong environment + ... env_fn = partial(GymEnv, "ALE/Pong-v5") + ... policy = RandomPolicy(env_fn().action_spec) + ... + ... # Initialize a shared replay buffer + ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) + ... + ... 
# Create a multi-async data collector with 16 environments
+            ...     num_envs = 16
+            ...     collector = MultiaSyncDataCollector(
+            ...         [env_fn] * num_envs,
+            ...         policy=policy,
+            ...         replay_buffer=rb,
+            ...         frames_per_batch=num_envs * 16,
+            ...         total_frames=-1,
+            ...     )
+            ...
+            ...     # Progress bar to track the number of collected frames
+            ...     pbar = tqdm.tqdm(total=100_000)
+            ...
+            ...     # Start the collector asynchronously
+            ...     collector.start()
+            ...
+            ...     # Track the write count of the replay buffer
+            ...     prec_wc = 0
+            ...     while True:
+            ...         wc = rb.write_count
+            ...         c = wc - prec_wc
+            ...         prec_wc = wc
+            ...
+            ...         # Update the progress bar
+            ...         pbar.update(c)
+            ...         pbar.set_description(f"Write Count: {rb.write_count}")
+            ...
+            ...         # Check the write count every 0.5 seconds
+            ...         time.sleep(0.5)
+            ...
+            ...         # Stop when the desired number of frames is reached
+            ...         if rb.write_count > 100_000:
+            ...             break
+            ...
+            ...     # Shut down the collector
+            ...     collector.async_shutdown()
+        """
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for execution.")
+        if self.init_random_frames is not None and self.init_random_frames > 0:
+            raise RuntimeError(
+                "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
+            )
+        self._running_free = True
+        for pipe in self.pipes:
+            pipe.send((None, "run_free"))
+
+    @contextlib.contextmanager
+    def pause(self):
+        """Context manager that pauses the collector if it is running free."""
+        if self._running_free:
+            for pipe in self.pipes:
+                pipe.send((None, "pause"))
+            # Make sure all workers are paused
+            for _ in self.pipes:
+                idx, msg = self.queue_out.get()
+                if msg != "paused":
+                    raise ValueError(f"Expected paused, but got {msg=}.")
+                torchrl_logger.info(f"Worker {idx} is paused.")
+            self._running_free = False
+            yield None
+            for pipe in self.pipes:
+                pipe.send((None, "restart"))
+            self._running_free = True
+        else:
+            raise RuntimeError("Collector cannot be paused.")
+
+    def __del__(self):
+        try:
+            self.shutdown()
+        except Exception:
+            # an AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
+
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
+        """Shuts down all processes. This operation is irreversible.
+
+        Args:
+            timeout (float, optional): The timeout for closing pipes between workers.
+            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
+            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
+        """
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+ ) + try: + self._shutdown_main(timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def _shutdown_main(self, timeout: float | None = None) -> None: + if timeout is None: + timeout = 10 + try: + if self.closed: + return + _check_for_faulty_process(self.procs) + all_closed = [False] * self.num_workers + rep = 0 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + continue + self.pipes[idx].send((None, "close")) + + while not all(all_closed) and rep < 1000: + rep += 1 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + all_closed[idx] = True + continue + try: + if self.pipes[idx].poll(timeout / 1000 / self.num_workers): + msg = self.pipes[idx].recv() + if msg != "closed": + raise RuntimeError(f"got {msg} but expected 'close'") + all_closed[idx] = True + else: + continue + except BrokenPipeError: + all_closed[idx] = True + continue + self.closed = True + + self.queue_out.close() + for pipe in self.pipes: + pipe.close() + for proc in self.procs: + proc.join(1.0) + finally: + import torchrl + + num_threads = min( + torchrl._THREAD_POOL_INIT, + torch.get_num_threads() + + self._total_workers_from_env(self.create_env_fn), + ) + torch.set_num_threads(num_threads) + + for proc in self.procs: + if proc.is_alive(): + proc.terminate() + + def async_shutdown(self, timeout: float | None = None): + return self.shutdown(timeout=timeout) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed: integer representing the seed to be used for the environment. + static_seed (bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is + contained in the DataCollector, as the seed will be incremented for + each of these. The resulting seed is the seed of the last + environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + _check_for_faulty_process(self.procs) + for idx in range(self.num_workers): + self.pipes[idx].send(((seed, static_seed), "seed")) + new_seed, msg = self.pipes[idx].recv() + if msg != "seeded": + raise RuntimeError(f"Expected msg='seeded', got {msg}") + seed = new_seed + self.reset() + return seed + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + """Resets the environments to a new initial state. + + Args: + reset_idx: Optional. Sequence indicating which environments have + to be reset. If None, all environments are reset. 
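+
+        Examples:
+            >>> # illustrative: assumes a collector with 4 workers
+            >>> collector.reset()  # reset every worker
+            >>> collector.reset(reset_idx=[True, False, True, False])  # reset workers 0 and 2 only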
+ + """ + _check_for_faulty_process(self.procs) + + if reset_idx is None: + reset_idx = [True for _ in range(self.num_workers)] + for idx in range(self.num_workers): + if reset_idx[idx]: + self.pipes[idx].send((None, "reset")) + for idx in range(self.num_workers): + if reset_idx[idx]: + j, msg = self.pipes[idx].recv() + if msg != "reset": + raise RuntimeError(f"Expected msg='reset', got {msg}") + + def state_dict(self) -> OrderedDict: + """Returns the state_dict of the data collector. + + Each field represents a worker containing its own state_dict. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((None, "state_dict")) + state_dict = OrderedDict() + for idx in range(self.num_workers): + _state_dict, msg = self.pipes[idx].recv() + if msg != "state_dict": + raise RuntimeError(f"Expected msg='state_dict', got {msg}") + state_dict[f"worker{idx}"] = _state_dict + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + """Loads the state_dict on the workers. + + Args: + state_dict (OrderedDict): state_dict of the form + ``{"worker0": state_dict0, "worker1": state_dict1}``. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) + for idx in range(self.num_workers): + _, msg = self.pipes[idx].recv() + if msg != "loaded": + raise RuntimeError(f"Expected msg='loaded', got {msg}") + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy of the first worker. + + Args: + attr (str): The attribute name to retrieve from the policy. + + Returns: + The attribute value from the policy of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the policy. + """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_policy")) + result, msg = self.pipes[0].recv() + if msg != "getattr_policy": + raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_env(self, attr): + """Get an attribute from the environment of the first worker. + + Args: + attr (str): The attribute name to retrieve from the environment. + + Returns: + The attribute value from the environment of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the environment. 
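+
+        Examples:
+            >>> # illustrative: assumes the first worker's environment exposes these attributes
+            >>> collector.getattr_env("batch_size")
+            >>> collector.getattr_env("action_spec")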
+ """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_env")) + result, msg = self.pipes[0].recv() + if msg != "getattr_env": + raise RuntimeError(f"Expected msg='getattr_env', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the fallback policy instance + if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: + return self._fallback_policy + elif hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") + + def get_cached_weights(self, model_id: str): + """Get cached shared memory weights if available (for weight sync schemes). + + Args: + model_id: Model identifier + + Returns: + Cached TensorDict weights or None if not available + """ + if model_id == "policy" and hasattr(self, "_policy_weights_dict"): + # Get the policy device (first device if list) + policy_device = self.policy_device + if isinstance(policy_device, (list, tuple)): + policy_device = policy_device[0] if len(policy_device) > 0 else None + + # Return cached weights for this device + return self._policy_weights_dict.get(policy_device) + return None diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py new file mode 100644 index 00000000000..3f475673a30 --- /dev/null +++ b/torchrl/collectors/_multi_sync.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import collections +import time +import warnings +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from queue import Empty + +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl import logger as torchrl_logger +from torchrl._utils import ( + _check_for_faulty_process, + accept_remote_rref_udf_invocation, + RL_WARNINGS, +) +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes synchronously. + + .. 
aafig:: + + +----------------------------------------------------------------------+ + | "MultiSyncDataCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | Main | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor" | "step" | "step" | "actor" | | + | | | | | | + | | "actor" | | | + | | | | | + | "yield batch of traj 1"------->"collect, train"| + | | | + | "step" | "step" | "step" | "step" | "step" | "step" | | + | | | | | | | | + | "actor" | "actor" | | | | + | | "step" | "step" | "actor" | | + | | | | | | + | "step" | "step" | "actor" | "step" | "step" | | + | | | | | | | + | "actor" | | "actor" | | + | "yield batch of traj 2"------->"collect, train"| + | | | + +----------------------------------------------------------------------+ + + Envs can be identical or different. + + The collection starts when the next item of the collector is queried, + and no environment step is computed in between the reception of a batch of + trajectory and the start of the next collection. + This class can be safely used with online RL sota-implementations. + + .. note:: + Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + ... collector = MultiSyncDataCollector(...) + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... 
del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) + if hasattr(self, "out_buffer"): + del self.out_buffer + if hasattr(self, "buffers"): + del self.buffers + try: + return super().shutdown(timeout=timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, worker_idx: int | None) -> int: + if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): + return self._frames_per_batch[worker_idx] + if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," + f" this results in more frames_per_batch per iteration that requested." + "To silence this message, set the environment variable RL_WARNINGS to False." 
+ ) + frames_per_batch_worker = -( + -self.requested_frames_per_batch // self.num_workers + ) + return frames_per_batch_worker + + @property + def _queue_len(self) -> int: + return self.num_workers + + def iterator(self) -> Iterator[TensorDictBase]: + cat_results = self.cat_results + if cat_results is None: + cat_results = "stack" + + self.buffers = {} + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + same_device = None + self.out_buffer = None + preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 + + while not all(dones) and self._frames < self.total_frames: + _check_for_faulty_process(self.procs) + if self.update_at_each_batch: + self.update_policy_weights_() + + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + # Debug: sending 'continue' + self.pipes[idx].send((None, msg)) + + self._iter += 1 + + if preempt: + self.interruptor.start_collection() + while self.queue_out.qsize() < int( + self.num_workers * self.preemptive_threshold + ): + continue + self.interruptor.stop_collection() + # Now wait for stragglers to return + while self.queue_out.qsize() < int(self.num_workers): + continue + + recv = collections.deque() + t0 = time.time() + while len(recv) < self.num_workers and ( + (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) + ): + for _ in range(self.num_workers): + try: + new_data, j = self.queue_out.get(timeout=_TIMEOUT) + recv.append((new_data, j)) + except (TimeoutError, Empty): + _check_for_faulty_process(self.procs) + if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): + try: + self.shutdown() + finally: + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + + for _ in range(self.num_workers): + new_data, j = recv.popleft() + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = workers_frames[ + idx + ] + self.frames_per_batch_worker(worker_idx=idx) + continue + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.buffers[idx] = data + if use_buffers is None and j > 0: + self._use_buffers = False + except TypeError: + if use_buffers is None: + self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + + if preempt: + # mask buffers if cat, and create a mask if stack + if cat_results != "stack": + buffers = {} + for worker_idx, buffer in self.buffers.items(): + valid = buffer.get(("collector", "traj_ids")) != -1 + if valid.ndim > 2: + valid = valid.flatten(0, -2) + if valid.ndim == 2: + valid = valid.any(0) + buffers[worker_idx] = buffer[..., valid] + else: + for buffer in self.buffers.values(): + with buffer.unlock_(): + buffer.set( + ("collector", "mask"), + buffer.get(("collector", "traj_ids")) != -1, + ) + buffers = self.buffers + else: + buffers = self.buffers + + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() + + if workers_frames[idx] >= self.total_frames: + dones[idx] = True + + if self.replay_buffer is not None: + yield + self._frames += sum( + [ + self.frames_per_batch_worker(worker_idx) + for worker_idx in range(self.num_workers) + ] + ) + continue + + # we have to correct the traj_ids to make sure that they don't overlap + # We can count the number of frames collected for free in this loop + n_collected = 0 + for idx in buffers.keys(): + buffer = buffers[idx] + traj_ids = buffer.get(("collector", "traj_ids")) + if preempt: + if cat_results == "stack": + mask_frames = buffer.get(("collector", "traj_ids")) != -1 + n_collected += mask_frames.sum().cpu() + else: + n_collected += traj_ids.numel() + else: + n_collected += traj_ids.numel() + + if same_device is None: + prev_device = None + same_device = True + for item in self.buffers.values(): + if prev_device is None: + prev_device = item.device + else: + same_device = same_device and (item.device == prev_device) + + if cat_results == "stack": + stack = ( + torch.stack if self._use_buffers else TensorDict.maybe_dense_stack + ) + if same_device: + self.out_buffer = stack(list(buffers.values()), 0) + else: + self.out_buffer = stack( + [item.cpu() for item in buffers.values()], 0 + ) + else: + if self._use_buffers is None: + torchrl_logger.warning( + "use_buffer not specified and not yet inferred from data, assuming `True`." + ) + elif not self._use_buffers: + raise RuntimeError( + "Cannot concatenate results with use_buffers=False" + ) + try: + if same_device: + self.out_buffer = torch.cat(list(buffers.values()), cat_results) + else: + self.out_buffer = torch.cat( + [item.cpu() for item in buffers.values()], cat_results + ) + except RuntimeError as err: + if ( + preempt + and cat_results != -1 + and "Sizes of tensors must match" in str(err) + ): + raise RuntimeError( + "The value provided to cat_results isn't compatible with the collectors outputs. " + "Consider using `cat_results=-1`." + ) + raise + + # TODO: why do we need to do cat inplace and clone? 
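+            # Finalize the assembled batch: optionally split per-trajectory, name the
+            # trailing dim "time" (when stacking or concatenating along -1), apply
+            # `postprocs` if any, drop private ("_"-prefixed) keys, then yield the
+            # batch to the caller.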
+ if self.split_trajs: + out = split_trajectories(self.out_buffer, prefix="collector") + else: + out = self.out_buffer + if cat_results in (-1, "stack"): + out.refine_names(*[None] * (out.ndim - 1) + ["time"]) + + self._frames += n_collected + + if self.postprocs: + self.postprocs = ( + self.postprocs.to(out.device) + if hasattr(self.postprocs, "to") + else self.postprocs + ) + out = self.postprocs(out) + if self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + if excluded_keys: + out = out.exclude(*excluded_keys) + yield out + del out + + del self.buffers + self.out_buffer = None + # We shall not call shutdown just yet as user may want to retrieve state_dict + # self._shutdown_main() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py new file mode 100644 index 00000000000..54e5c823888 --- /dev/null +++ b/torchrl/collectors/_runner.py @@ -0,0 +1,504 @@ +from __future__ import annotations + +import queue +from collections.abc import Callable +from functools import partial +from multiprocessing import connection, queues +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDictBase +from torch import nn as nn + +from torchrl import logger as torchrl_logger +from torchrl._utils import VERBOSE +from torchrl.collectors._constants import ( + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + DEFAULT_EXPLORATION_TYPE, +) +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool +from torchrl.data import ReplayBuffer +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.utils import ExplorationType +from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.weight_sync_schemes import _resolve_model + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx +): + if policy is not None and policy_factory is not None: + raise ValueError("policy cannot be used with policy_factory") + elif policy_factory is not None: + policy = policy_factory() + + if weight_sync_scheme is not None: + weight_sync_scheme.init_on_worker( + model=policy, model_id="policy", worker_idx=worker_idx + ) + return policy + + +def _main_async_collector( + pipe_parent: connection.Connection, + pipe_child: connection.Connection, + queue_out: queues.Queue, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 + create_env_kwargs: dict[str, Any], + policy: Callable[[TensorDictBase], TensorDictBase], + max_frames_per_traj: int, + frames_per_batch: int, + reset_at_each_iter: bool, + storing_device: torch.device | str | int | None, + env_device: torch.device | str | int | None, + policy_device: torch.device | str | int | None, + idx: int = 0, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + verbose: bool = VERBOSE, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + traj_pool: _TrajectoryPool = None, + trust_policy: bool = False, + compile_policy: bool = False, + cudagraph_policy: bool = False, + no_cuda_sync: bool = False, + policy_factory: Callable | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | 
None = None, + worker_idx: int | None = None, +) -> None: + if collector_class is None: + collector_class = SyncDataCollector + pipe_parent.close() + # init variables that will be cleared when closing + collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None + + # Make a policy-factory out of the policy + policy_factory = partial( + _make_policy_factory, + policy=policy, + policy_factory=policy_factory, + weight_sync_scheme=weight_sync_schemes.get("policy"), + worker_idx=worker_idx, + ) + policy = None + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + weight_sync_schemes=weight_sync_schemes, + ) + + # Set up weight receivers for worker process + if weight_sync_schemes: + inner_collector._weight_receivers = {} + inner_collector.pipe = pipe_child # Add pipe attribute for context + inner_collector.worker_idx = ( + worker_idx # Add worker index for queue-based schemes + ) + + for model_id, scheme in weight_sync_schemes.items(): + # Check if scheme has new API or legacy API + if hasattr(scheme, "init_on_worker"): + # For SharedMemWeightSyncScheme, init_on_worker reads from queue + # and applies weights to model - all handled by the receiver + scheme.init_on_worker(model_id=model_id, context=inner_collector) + receiver = scheme.get_receiver() + else: + # Legacy API + receiver = scheme.create_receiver() + receiver.set_context(inner_collector) + receiver.register_worker_transport(pipe_child) + + model = _resolve_model(inner_collector, model_id) + receiver.register_model(model) + + inner_collector._weight_receivers[model_id] = receiver + else: + inner_collector._weight_receivers = {} + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.info("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return + + has_timed_out = False + counter = 0 + run_free = False + while True: + _timeout = _TIMEOUT if not has_timed_out else 1e-3 + if not run_free and pipe_child.poll(_timeout): + counter = 0 + data_in, msg = pipe_child.recv() + if verbose: + torchrl_logger.info(f"worker {idx} received {msg}") + elif not run_free: + if verbose: + torchrl_logger.info(f"poll failed, j={j}, worker={idx}") + # default is "continue" (after first iteration) + # 
this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe + # in that case, the main process probably expects the worker to continue collect data + if has_timed_out: + counter = 0 + # has_timed_out is True if the process failed to send data, which will + # typically occur if main has taken another batch (i.e. the queue is Full). + # In this case, msg is the previous msg sent by main, which will typically be "continue" + # If it's not the case, it is not expected that has_timed_out is True. + if msg not in ("continue", "continue_random"): + raise RuntimeError(f"Unexpected message after time out: msg={msg}") + else: + # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. + # this means that our process has been waiting for a command from main in vain, while main was not + # receiving data. + # This will occur if main is busy doing something else (e.g. computing loss etc). + + counter += _timeout + if verbose: + torchrl_logger.info(f"worker {idx} has counter {counter}") + if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): + raise RuntimeError( + f"This process waited for {counter} seconds " + f"without receiving a command from main. Consider increasing the maximum idle count " + f"if this is expected via the environment variable MAX_IDLE_COUNT " + f"(current value is {_MAX_IDLE_COUNT})." + f"\nIf this occurs at the end of a function or program, it means that your collector has not been " + f"collected, consider calling `collector.shutdown()` before ending the program." + ) + continue + else: + # placeholder, will be checked after + if msg != "continue": + torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") + msg = "continue" + if msg == "run_free": + run_free = True + msg = "continue" + if run_free: + # Capture shutdown / update / seed signal, but continue should not be expected + if pipe_child.poll(1e-4): + data_in, msg = pipe_child.recv() + torchrl_logger.info(f"worker {idx} received {msg} while running free") + if msg == "continue": + # Switch back to run_free = False + run_free = False + if msg == "pause": + queue_out.put((idx, "paused"), timeout=_TIMEOUT) + while not pipe_child.poll(1e-2): + continue + data_in, msg = pipe_child.recv() + if msg != "restart": + raise RuntimeError(f"Expected msg='restart', got {msg=}") + msg = "continue" + else: + data_in = None + # TODO: this does not work with random frames + msg = "continue" + # Note: The "continue" message handling has been moved below after update_weights handling + # to allow falling through from update_weights to continue + + if msg == "update": + torchrl_logger.info(f"worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + if msg == "register_shared_weights": + # Shared memory lazy registration: main process sends buffer reference + if verbose: + torchrl_logger.info( + f"worker {idx} received shared memory buffer registration" + ) + model_id, shared_buffer = data_in + + # Store the shared buffer reference for this model + # The receiver will use this buffer for all future weight accesses + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + # Update receiver's buffer reference + receiver = inner_collector._weight_receivers[model_id] + # Store the shared buffer - the model's parameters should point to this + if hasattr(receiver, "_shared_weights"): + 
receiver._shared_weights[model_id] = shared_buffer + + # Apply the buffer to the model immediately + # Only apply if the model is an nn.Module (has learnable parameters) + try: + model = receiver._resolve_model_ref() + except (ValueError, AttributeError) as e: + # Model not registered or reference is invalid + if verbose: + torchrl_logger.warning( + f"worker {idx} could not resolve model '{model_id}': {e}" + ) + continue + + if isinstance(model, nn.Module): + receiver.apply_weights(shared_buffer) + else: + if verbose: + torchrl_logger.info( + f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" + ) + + if verbose: + torchrl_logger.info( + f"worker {idx} registered shared buffer for model '{model_id}'" + ) + else: + torchrl_logger.warning( + f"worker {idx} received shared buffer for unknown model '{model_id}'" + ) + + # Send acknowledgment back to main process + pipe_child.send((None, "registered")) + has_timed_out = False + continue + + if msg == "update_weights": + # New weight update protocol for simplified weight sync system + if verbose: + torchrl_logger.info( + f"worker {idx} received weight update via new protocol" + ) + model_id, weights = data_in + + # Apply weights using the appropriate receiver for this model + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + inner_collector._weight_receivers[model_id].apply_weights(weights) + else: + torchrl_logger.warning( + f"worker {idx} received weights for unknown model '{model_id}'" + ) + + # After applying weights, we continue collecting immediately as if we received + # a "continue" message. This ensures the worker keeps collecting data without + # waiting for an explicit continue from the main process. + has_timed_out = False + msg = "continue" + # Now check if we should continue collecting + + if msg in ("continue", "continue_random"): + # This block handles both explicit continue messages and implicit ones after weight updates + if msg == "continue_random": + inner_collector.init_random_frames = float("inf") + else: + inner_collector.init_random_frames = -1 + + # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the + # main message loop above (msg == "update_weights" case). The receiver.receive() + # pattern is only used for schemes with separate communication channels like + # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). + # Calling receiver.receive() here would interfere with the pipe-based message protocol. + + next_data = next(dc_iter) + if pipe_child.poll(_MIN_TIMEOUT): + # in this case, main send a message to the worker while it was busy collecting trajectories. + # In that case, we skip the collected trajectory and get the message from main. This is faster than + # sending the trajectory in the queue until timeout when it's never going to be received. 
+                continue
+
+            if replay_buffer is not None:
+                if extend_buffer:
+                    next_data.names = None
+                    replay_buffer.extend(next_data)
+
+                if run_free:
+                    continue
+
+                try:
+                    queue_out.put((idx, j), timeout=_TIMEOUT)
+                    if verbose:
+                        torchrl_logger.info(f"worker {idx} successfully sent data")
+                    j += 1
+                    has_timed_out = False
+                    continue
+                except queue.Full:
+                    if verbose:
+                        torchrl_logger.info(f"worker {idx} has timed out")
+                    has_timed_out = True
+                    continue
+
+            if j == 0 or not use_buffers:
+                collected_tensordict = next_data
+                if (
+                    storing_device is not None
+                    and collected_tensordict.device != storing_device
+                ):
+                    raise RuntimeError(
+                        f"expected device to be {storing_device} but got {collected_tensordict.device}"
+                    )
+                if use_buffers:
+                    # If policy and env are on cpu, we put in shared mem,
+                    # if policy is on cuda and env on cuda, we are fine with this
+                    # If policy is on cuda and env on cpu (or opposite) we put tensors that
+                    # are on cpu in shared mem.
+                    MPS_ERROR = (
+                        "tensors on mps device cannot be put in shared memory. Make sure "
+                        "the shared device (aka storing_device) is set to CPU."
+                    )
+                    if collected_tensordict.device is not None:
+                        # placeholder in case we need different behaviors
+                        if collected_tensordict.device.type in ("cpu",):
+                            collected_tensordict.share_memory_()
+                        elif collected_tensordict.device.type in ("mps",):
+                            raise RuntimeError(MPS_ERROR)
+                        elif collected_tensordict.device.type == "cuda":
+                            collected_tensordict.share_memory_()
+                        else:
+                            raise NotImplementedError(
+                                f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
+                            )
+                    else:
+                        # make sure each cpu tensor is shared - assuming non-cpu devices are shared
+                        def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+                            if x.device.type in ("cpu",):
+                                x.share_memory_()
+                            if x.device.type in ("mps",):
+                                raise RuntimeError(MPS_ERROR)
+
+                        collected_tensordict.apply(cast_tensor, filter_empty=True)
+                data = (collected_tensordict, idx)
+            else:
+                if next_data is not collected_tensordict:
+                    raise RuntimeError(
+                        "SyncDataCollector should return the same tensordict modified in-place."
+                    )
+                data = idx  # flag the worker that has sent its data
+            try:
+                queue_out.put((data, j), timeout=_TIMEOUT)
+                if verbose:
+                    torchrl_logger.info(f"worker {idx} successfully sent data")
+                j += 1
+                has_timed_out = False
+                continue
+            except queue.Full:
+                if verbose:
+                    torchrl_logger.info(f"worker {idx} has timed out")
+                has_timed_out = True
+                continue
+
+        if msg == "seed":
+            data_in, static_seed = data_in
+            new_seed = inner_collector.set_seed(data_in, static_seed=static_seed)
+            torch.manual_seed(data_in)
+            np.random.seed(data_in)
+            pipe_child.send((new_seed, "seeded"))
+            has_timed_out = False
+            continue
+
+        elif msg == "reset":
+            inner_collector.reset()
+            pipe_child.send((j, "reset"))
+            continue
+
+        elif msg == "state_dict":
+            from torch.utils._pytree import tree_map
+
+            state_dict = inner_collector.state_dict()
+            # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility + # CPU and CUDA tensors are already shareable and don't need conversion + state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + pipe_child.send((state_dict, "state_dict")) + has_timed_out = False + continue + + elif msg == "load_state_dict": + state_dict = data_in + inner_collector.load_state_dict(state_dict) + del state_dict + pipe_child.send((j, "loaded")) + has_timed_out = False + continue + + elif msg == "getattr_policy": + attr_name = data_in + try: + result = getattr(inner_collector.policy, attr_name) + pipe_child.send((result, "getattr_policy")) + except AttributeError as e: + pipe_child.send((e, "getattr_policy")) + has_timed_out = False + continue + + elif msg == "getattr_env": + attr_name = data_in + try: + result = getattr(inner_collector.env, attr_name) + pipe_child.send((result, "getattr_env")) + except AttributeError as e: + pipe_child.send((e, "getattr_env")) + has_timed_out = False + continue + + elif msg == "close": + del collected_tensordict, data, next_data, data_in + inner_collector.shutdown() + del inner_collector, dc_iter + pipe_child.send("closed") + if verbose: + torchrl_logger.info(f"collector {idx} closed") + break + + else: + raise Exception(f"Unrecognized message {msg}") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py new file mode 100644 index 00000000000..aee35c4042a --- /dev/null +++ b/torchrl/collectors/_single.py @@ -0,0 +1,1779 @@ +from __future__ import annotations + +import contextlib +import threading +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator, Sequence +from textwrap import indent +from typing import Any + +import torch + +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase +from torch import nn +from torchrl import compile_with_warmup, logger as torchrl_logger +from torchrl._utils import ( + _ends_with, + _make_ordinal_device, + _replace_last, + accept_remote_rref_udf_invocation, + prod, + RL_WARNINGS, +) +from torchrl.collectors._constants import ( + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, +) +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _TrajectoryPool, split_trajectories +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator, RandomPolicy, StepCounter, TransformedEnv +from torchrl.envs.common import _do_nothing +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.envs.utils import ( + _aggregate_end_of_traj, + _make_compatible_policy, + set_exploration_type, +) +from torchrl.weight_update import WeightSyncScheme + + +@accept_remote_rref_udf_invocation +class SyncDataCollector(DataCollectorBase): + """Generic data collector for RL problems. Requires an environment constructor and a policy. + + Args: + create_env_fn (Callable or EnvBase): a callable that returns an instance of + :class:`~torchrl.envs.EnvBase` class, or the env itself. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. 
+            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
+            This is the recommended usage of the collector.
+            Other callables are accepted too:
+            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
+            instance) it will be wrapped in a `nn.Module` first.
+            Then, the collector will try to assess if these
+            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
+            - If the policy forward signature matches any of ``forward(self, tensordict)``,
+              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
+              any typing with a single argument typed as a subclass of ``TensorDictBase``)
+              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
+
+            - In all other cases an attempt to wrap it will be made as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+
+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+              pickled directly), the ``policy_factory`` should be used instead.
+
+    Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+        frames_per_batch (int): A keyword-only argument representing the total
+            number of elements in a batch.
+        total_frames (int): A keyword-only argument representing the total
+            number of frames returned by the collector
+            during its lifespan. If the ``total_frames`` is not divisible by
+            ``frames_per_batch``, an exception is raised.
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (endless collector).
+        device (int, str or torch.device, optional): The generic device of the
+            collector. The ``device`` argument fills any non-specified device: if
+            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
+            ``env_device`` is not specified, its value will be set to ``device``.
+            Defaults to ``None`` (No default device).
+        storing_device (int, str or torch.device, optional): The device on which
+            the output :class:`~tensordict.TensorDict` will be stored.
+            If ``device`` is passed and ``storing_device`` is ``None``, it will
+            default to the value indicated by ``device``.
+            For long trajectories, it may be necessary to store the data on a different
+            device than the one where the policy and env are executed.
+            Defaults to ``None`` (the output tensordict isn't on a specific device,
+            leaf tensors sit on the device where they were created).
+        env_device (int, str or torch.device, optional): The device on which
+            the environment should be cast (or executed if that functionality is
+            supported). If not specified and the env has a non-``None`` device,
+            ``env_device`` will default to that value. If ``device`` is passed
+            and ``env_device=None``, it will default to ``device``. If the value
+            as such specified of ``env_device`` differs from ``policy_device``
+            and one of them is not ``None``, the data will be cast to ``env_device``
+            before being passed to the env (i.e., passing different devices to
+            policy and env is supported). Defaults to ``None``.
+        policy_device (int, str or torch.device, optional): The device on which
+            the policy should be cast.
+            If ``device`` is passed and ``policy_device=None``, it will default
+            to ``device``. 
If the value as such specified of ``policy_device``
+            differs from ``env_device`` and one of them is not ``None``,
+            the data will be cast to ``policy_device`` before being passed to
+            the policy (i.e., passing different devices to policy and env is
+            supported). Defaults to ``None``.
+        create_env_kwargs (dict, optional): Dictionary of kwargs for
+            ``create_env_fn``.
+        max_frames_per_traj (int, optional): Maximum steps per trajectory.
+            Note that a trajectory can span across multiple batches (unless
+            ``reset_at_each_iter`` is set to ``True``, see below).
+            Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
+            If the environment wraps multiple environments together, the number
+            of steps is tracked for each environment independently. Negative
+            values are allowed, in which case this argument is ignored.
+            Defaults to ``None`` (i.e., no maximum number of steps).
+        init_random_frames (int, optional): Number of frames for which the
+            policy is ignored before it is called. This feature is mainly
+            intended to be used in offline/model-based settings, where a
+            batch of random trajectories can be used to initialize training.
+            If provided, it will be rounded up to the closest multiple of frames_per_batch.
+            Defaults to ``None`` (i.e. no random frames).
+        reset_at_each_iter (bool, optional): Whether environments should be reset
+            at the beginning of a batch collection.
+            Defaults to ``False``.
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+
+            .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
+                as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`.
+
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        return_same_td (bool, optional): if ``True``, the same TensorDict
+            will be returned at each iteration, with its values
+            updated. This feature should be used cautiously: if the same
+            tensordict is added to a replay buffer for instance,
+            the whole content of the buffer will be identical.
+            Default is ``False``.
+        interruptor (_Interruptor, optional):
+            An _Interruptor object that can be used from outside the class to control rollout collection.
+            The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow
+            implementing strategies such as preemptively stopping rollout collection.
+            Defaults to ``None``.
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
+            This isn't compatible with environments with dynamic specs.
Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. + Defaults to ``None``. + + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). + + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires + `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. + + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. + Consider using a constructor if the updater needs to be serialized. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector( + ... create_env_fn=env_maker, + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... 
break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + >>> del collector + + The collector delivers batches of data that are marked with a ``"time"`` + dimension. + + Examples: + >>> assert data.names[-1] == "time" + + """ + + _ignore_rb: bool = False + + def __init__( + self, + create_env_fn: ( + EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 + ), # noqa: F821 + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int = -1, + device: DEVICE_TYPING | None = None, + storing_device: DEVICE_TYPING | None = None, + policy_device: DEVICE_TYPING | None = None, + env_device: DEVICE_TYPING | None = None, + create_env_kwargs: dict[str, Any] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + return_same_td: bool = False, + reset_when_done: bool = True, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + **kwargs, + ): + self.closed = True + + # Initialize environment + env = self._init_env(create_env_fn, create_env_kwargs) + + # Initialize policy + policy = self._init_policy(policy, policy_factory, env, trust_policy) + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Handle trajectory pool and validate kwargs + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+ ) + + # Set up devices and synchronization + self._setup_devices( + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + no_cuda_sync=no_cuda_sync, + ) + + self.env: EnvBase = env + del env + + # Set up policy version tracking + self._setup_policy_version_tracking(track_policy_version) + + # Set up replay buffer + self._setup_replay_buffer( + replay_buffer=replay_buffer, + extend_buffer=extend_buffer, + local_init_rb=local_init_rb, + postproc=postproc, + split_trajs=split_trajs, + return_same_td=return_same_td, + use_buffers=use_buffers, + ) + + self.closed = False + + # Validate reset_when_done + if not reset_when_done: + raise ValueError("reset_when_done is deprecated.") + self.reset_when_done = reset_when_done + self.n_env = self.env.batch_size.numel() + + # Register collector with policy and env + if hasattr(policy, "register_collector"): + policy.register_collector(self) + if hasattr(self.env, "register_collector"): + self.env.register_collector(self) + + # Set up policy and weights + self._setup_policy_and_weights(policy) + + # Apply environment device + self._apply_env_device() + + # Set up max frames per trajectory + self._setup_max_frames_per_traj(max_frames_per_traj) + + # Validate and set total frames + self.reset_at_each_iter = reset_at_each_iter + self._setup_total_frames(total_frames, frames_per_batch) + + # Set up init random frames + self._setup_init_random_frames(init_random_frames, frames_per_batch) + + # Set up postproc + self._setup_postproc(postproc) + + # Calculate frames per batch + self._setup_frames_per_batch(frames_per_batch) + + # Set exploration and other options + self.exploration_type = ( + exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE + ) + self.return_same_td = return_same_td + self.set_truncated = set_truncated + + # Create shuttle and rollout buffers + self._make_shuttle() + self._maybe_make_final_rollout(make_rollout=self._use_buffers) + self._set_truncated_keys() + + # Set split trajectories option + if split_trajs is None: + split_trajs = False + self.split_trajs = split_trajs + self._exclude_private_keys = True + + # Set up interruptor and frame tracking + self.interruptor = interruptor + self._frames = 0 + self._iter = -1 + + # Set up weight synchronization + self._setup_weight_sync(weight_updater, weight_sync_schemes) + + def _init_env( + self, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], + create_env_kwargs: dict[str, Any] | None, + ) -> EnvBase: + """Initialize and configure the environment.""" + from torchrl.envs.batched_envs import BatchedEnvBase + + if create_env_kwargs is None: + create_env_kwargs = {} + + if not isinstance(create_env_fn, EnvBase): + env = create_env_fn(**create_env_kwargs) + else: + env = create_env_fn + if create_env_kwargs: + if not isinstance(env, BatchedEnvBase): + raise RuntimeError( + "kwargs were passed to SyncDataCollector but they can't be set " + f"on environment of type {type(create_env_fn)}." 
+ ) + env.update_kwargs(create_env_kwargs) + return env + + def _init_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: Callable[[], Callable] | None, + env: EnvBase, + trust_policy: bool | None, + ) -> TensorDictModule | Callable: + """Initialize and configure the policy.""" + if policy is None: + if policy_factory is not None: + policy = policy_factory() + else: + policy = RandomPolicy(env.full_action_spec) + elif policy_factory is not None: + raise TypeError("policy_factory cannot be used with policy argument.") + + # If the underlying policy has a state_dict, keep a reference to it + if hasattr(policy, "state_dict"): + self._policy_w_state_dict = policy + + if trust_policy is None: + trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) + self.trust_policy = trust_policy + + return policy + + def _setup_devices( + self, + device: DEVICE_TYPING | None, + storing_device: DEVICE_TYPING | None, + policy_device: DEVICE_TYPING | None, + env_device: DEVICE_TYPING | None, + no_cuda_sync: bool, + ) -> None: + """Set up devices and synchronization functions.""" + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self._sync_storage = self._get_sync_fn(storing_device) + + self.env_device = env_device + self._sync_env = self._get_sync_fn(env_device) + + self.policy_device = policy_device + self._sync_policy = self._get_sync_fn(policy_device) + + self.device = device + self.no_cuda_sync = no_cuda_sync + self._cast_to_policy_device = self.policy_device != self.env_device + + def _get_sync_fn(self, device: torch.device | None) -> Callable: + """Get the appropriate synchronization function for a device.""" + if device is not None and device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + return torch.cuda.synchronize + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + return torch.mps.synchronize + elif hasattr(torch, "npu") and torch.npu.is_available(): + return torch.npu.synchronize + elif device.type == "cpu": + return _do_nothing + else: + raise RuntimeError("Non supported device") + else: + return _do_nothing + + def _setup_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking if requested.""" + self.policy_version_tracker = track_policy_version + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." 
+ ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: + self.policy_version_tracker = None + + def _setup_replay_buffer( + self, + replay_buffer: ReplayBuffer | None, + extend_buffer: bool, + local_init_rb: bool | None, + postproc: Callable | None, + split_trajs: bool | None, + return_same_td: bool, + use_buffers: bool | None, + ) -> None: + """Set up replay buffer configuration and validate compatibility.""" + self.replay_buffer = replay_buffer + self.extend_buffer = extend_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + # Validate replay buffer compatibility + if self.replay_buffer is not None and not self._ignore_rb: + if postproc is not None and not self.extend_buffer: + raise TypeError( + "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." + ) + if split_trajs not in (None, False) and not self.extend_buffer: + raise TypeError( + "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if return_same_td: + raise TypeError( + "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") + + if use_buffers is None: + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None + self._use_buffers = use_buffers + + def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: + """Set up policy, wrapped policy, and extract weights.""" + self._original_policy = policy + + # Check if policy has meta-device parameters (sent from weight sync schemes) + # In that case, skip device placement - weights will come from the receiver + has_meta_params = False + if isinstance(policy, nn.Module): + for p in policy.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + if has_meta_params: + # Skip device placement for meta policies - schemes handle weight application + # Policy stays as-is, weights will be applied by the receiver + self.get_weights_fn = lambda: TensorDict.from_module(policy).data + else: + # Normal path: move policy to correct device + policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." 
+ ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Extract policy weights from the uncompiled policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() + + # If policy doesn't have meta params, compile immediately + # Otherwise, defer until first use (after weights are loaded) + if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy): + self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy( + self._wrapped_policy_uncompiled + ) + + def _compile_wrapped_policy(self, policy): + """Apply compilation and/or cudagraph to a policy.""" + if self.compiled_policy: + policy = compile_with_warmup(policy, **self.compiled_policy_kwargs) + if self.cudagraphed_policy: + policy = CudaGraphModule( + policy, + in_keys=[], + out_keys=[], + device=self.policy_device, + **self.cudagraphed_policy_kwargs, + ) + return policy + + @property + def _wrapped_policy(self): + """Returns the compiled policy, compiling it lazily if needed.""" + if (policy := self._wrapped_policy_maybe_compiled) is None: + if self.compiled_policy or self.cudagraphed_policy: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled) + else: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._wrapped_policy_uncompiled + return policy + + @_wrapped_policy.setter + def _wrapped_policy(self, value): + """Allow setting the wrapped policy during initialization.""" + self._wrapped_policy_uncompiled = value + self._wrapped_policy_maybe_compiled = None + + def _apply_env_device(self) -> None: + """Apply device to environment if specified.""" + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # Use the device of the env if none was provided + self.env_device = self.env.device + + # Check if we need to cast to env device + self._cast_to_env_device = self._cast_to_policy_device or ( + self.env.device != self.storing_device + ) + + def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: + """Set up maximum frames per trajectory and add StepCounter if needed.""" + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: + # Check that there is no StepCounter yet + for key in self.env.output_spec.keys(True, True): + if isinstance(key, str): + key = (key,) + if "step_count" in key: + raise ValueError( + "A 'step_count' key is already present in the environment " + "and the 'max_frames_per_traj' argument may conflict with " + "a 'StepCounter' that has already been set. " + "Possible solutions: Set max_frames_per_traj to 0 or " + "remove the StepCounter limit from the environment transforms." 
+ ) + self.env = TransformedEnv( + self.env, StepCounter(max_steps=self.max_frames_per_traj) + ) + + def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: + """Validate and set total frames.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"This means {frames_per_batch - remainder} additional frames will be collected." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_init_random_frames( + self, init_random_frames: int | None, frames_per_batch: int + ) -> None: + """Set up initial random frames.""" + self.init_random_frames = ( + int(init_random_frames) if init_random_frames not in (None, -1) else 0 + ) + if ( + init_random_frames not in (-1, None, 0) + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + + def _setup_postproc(self, postproc: Callable | None) -> None: + """Set up post-processing transform.""" + self.postproc = postproc + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): + postproc = self.postproc.to(self.storing_device) + if postproc is not self.postproc and postproc is not None: + self.postproc = postproc + + def _setup_frames_per_batch(self, frames_per_batch: int) -> None: + """Calculate and validate frames per batch.""" + if frames_per_batch % self.n_env != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env + + def _setup_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization system.""" + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # For single-process collectors, we don't need senders/receivers + # The policy is local and changes are immediately visible + # Senders will be set up in multiprocess collectors during _run_processes + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." 
+ ) + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + else: + # No weight sync needed for single-process collectors + self.weight_updater = None + self._weight_sync_schemes = None + self._weight_senders = {} + + @property + def _traj_pool(self): + pool = getattr(self, "_traj_pool_val", None) + if pool is None: + pool = self._traj_pool_val = _TrajectoryPool() + return pool + + def _make_shuttle(self): + # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env + with torch.no_grad(): + self._shuttle = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._shuttle.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = self._traj_pool.get_traj_and_increment( + self.n_env, device=self.storing_device + ).view(self.env.batch_size) + self._shuttle.set( + ("collector", "traj_ids"), + traj_ids, + ) + + def _maybe_make_final_rollout(self, make_rollout: bool): + if make_rollout: + with torch.no_grad(): + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. + if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + + # Check if policy has meta-device parameters (not yet initialized) + has_meta_params = False + if hasattr(self, "_wrapped_policy_uncompiled") and isinstance( + self._wrapped_policy_uncompiled, nn.Module + ): + for p in self._wrapped_policy_uncompiled.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + # If the policy has a valid spec, we use it + self._policy_output_keys = set() + if ( + make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "spec", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + is not None + and all( + v is not None + for v in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.values(True, True) + ) + ): + if any( + key not in self._final_rollout.keys(isinstance(key, tuple)) + for key in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.keys(True, True) + ): + # if policy spec is non-empty, all the values are not None and the keys + # match the out_keys we assume the user has given all relevant information + # the policy could have more keys than the env: + policy_spec = ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) + for key, spec in policy_spec.items(True, True): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): + continue + self._final_rollout.set(key, spec.zero()) + elif ( + not 
make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "out_keys", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ): + self._policy_output_keys = list( + ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ) + elif has_meta_params: + # Policy has meta params and no spec/out_keys - defer initialization + # Mark that we need to initialize later when weights are loaded + self._policy_output_keys = set() + if make_rollout: + # We'll populate keys on first actual rollout after weights are loaded + self._final_rollout_needs_init = True + else: + if make_rollout: + # otherwise, we perform a small number of steps with the policy to + # determine the relevant keys with which to pre-populate _final_rollout. + # This is the safest thing to do if the spec has None fields or if there is + # no spec at all. + # See #505 for additional context. + self._final_rollout.update(self._shuttle.copy()) + with torch.no_grad(): + policy_input = self._shuttle.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) + + policy_output._fast_apply( + check_exclusive, call_on_nested=True, filter_empty=True + ) + + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. 
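+                # ``filter_policy`` below keeps only the entries the policy actually
+                # produced or modified: an output value is retained when it has no
+                # input counterpart, sits on a different device, or differs
+                # numerically from the cloned input.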
+ def filter_policy(name, value_output, value_input, value_input_clone): + if (value_input is None) or ( + (value_output is not value_input) + and ( + value_output.device != value_input_clone.device + or ~torch.isclose(value_output, value_input_clone).any() + ) + ): + return value_output + + filtered_policy_output = policy_output.apply( + filter_policy, + policy_input_copy, + policy_input_clone, + default=None, + filter_empty=True, + named=True, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + if make_rollout: + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + if make_rollout: + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) + .clone() + .zero_() + ) + + # in addition to outputs of the policy, we add traj_ids to + # _final_rollout which will be collected during rollout + self._final_rollout.set( + ("collector", "traj_ids"), + torch.zeros( + *self._final_rollout.batch_size, + dtype=torch.int64, + device=self.storing_device, + ), + ) + self._final_rollout.refine_names(..., "time") + + def _set_truncated_keys(self): + self._truncated_keys = [] + if self.set_truncated: + if not any(_ends_with(key, "truncated") for key in self.env.done_keys): + raise RuntimeError( + "set_truncated was set to True but no truncated key could be found " + "in the environment. Make sure the truncated keys are properly set using " + "`env.add_truncated_keys()` before passing the env to the collector." + ) + self._truncated_keys = [ + key for key in self.env.done_keys if _ends_with(key, "truncated") + ] + + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = _make_ordinal_device(torch.device(device) if device else device) + storing_device = _make_ordinal_device( + torch.device(storing_device) if storing_device else device + ) + policy_device = _make_ordinal_device( + torch.device(policy_device) if policy_device else device + ) + env_device = _make_ordinal_device( + torch.device(env_device) if env_device else device + ) + if storing_device is None and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + + # for RPC + def next(self): + return super().next() + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed (int): integer representing the seed to be used for the environment. 
+ static_seed(bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is contained in the DataCollector, as the + seed will be incremented for each of these. The resulting seed is the seed of the last environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + out = self.env.set_seed(seed, static_seed=static_seed) + return out + + def _increment_frames(self, numel): + self._frames += numel + completed = self._frames >= self.total_frames + if completed: + self.env.close() + return completed + + def iterator(self) -> Iterator[TensorDictBase]: + """Iterates through the DataCollector. + + Yields: TensorDictBase objects containing (chunks of) trajectories + + """ + if ( + not self.no_cuda_sync + and self.storing_device + and self.storing_device.type == "cuda" + ): + stream = torch.cuda.Stream(self.storing_device, priority=-1) + event = stream.record_event() + streams = [stream] + events = [event] + elif not self.no_cuda_sync and self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + if not self._use_buffers: + # This may be a bit dangerous as `torch.device("cuda")` may not have a precise + # device associated, whereas `tensor.device` always has + for spec in self.env.specs.values(True, True): + if spec.device is not None and spec.device.type == "cuda": + if ":" not in str(spec.device): + raise RuntimeError( + "A cuda spec did not have a device associated. Make sure to " + "pass `'cuda:device_num'` to each spec device." + ) + cuda_devices.add(spec.device) + else: + self._final_rollout.apply(cuda_check, filter_empty=True) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) + else: + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + + while self._frames < self.total_frames: + self._iter += 1 + if self.verbose: + torchrl_logger.info("Collector: rollout.") + tensordict_out = self.rollout() + if tensordict_out is None: + # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out + # frames are updated within the rollout function + if self.verbose: + torchrl_logger.info("Collector: No tensordict_out. Yielding.") + yield + continue + self._increment_frames(tensordict_out.numel()) + tensordict_out = self._postproc(tensordict_out) + if self.verbose: + torchrl_logger.info("Collector: postproc done.") + if self.return_same_td: + # This is used with multiprocessed collectors to use the buffers + # stored in the tensordict. 
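+                    # Record and synchronize the CUDA events before yielding so that
+                    # pending copies on the side streams are finished and the consumer
+                    # reads fully written buffers.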
+                    if events:
+                        for event in events:
+                            event.record()
+                            event.synchronize()
+                    yield tensordict_out
+                elif self.replay_buffer is not None and not self._ignore_rb:
+                    self.replay_buffer.extend(tensordict_out)
+                    if self.verbose:
+                        torchrl_logger.info(
+                            f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
+                            f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
+                        )
+                    yield
+                else:
+                    # we must clone the values, as the tensordict is updated in-place.
+                    # otherwise the following code may break:
+                    # >>> for i, data in enumerate(collector):
+                    # >>>     if i == 0:
+                    # >>>         data0 = data
+                    # >>>     elif i == 1:
+                    # >>>         data1 = data
+                    # >>>     else:
+                    # >>>         break
+                    # >>> assert data0["done"] is not data1["done"]
+                    yield tensordict_out.clone()
+
+    def start(self):
+        """Starts the collector in a separate thread for asynchronous data collection.
+
+        The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
+        collection from training, allowing your training loop to run independently of the data collection process.
+
+        Raises:
+            RuntimeError: If no replay buffer is defined during the collector's initialization.
+
+        Example:
+            >>> import time
+            >>> from functools import partial
+            >>>
+            >>> import tqdm
+            >>>
+            >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
+            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
+            >>> from torchrl.envs import GymEnv, set_gym_backend
+            >>> import ale_py
+            >>>
+            >>> # Set the gym backend to gymnasium
+            >>> set_gym_backend("gymnasium").set()
+            >>>
+            >>> if __name__ == "__main__":
+            ...     # Create a random policy for the Pong environment
+            ...     env = GymEnv("ALE/Pong-v5")
+            ...     policy = RandomPolicy(env.action_spec)
+            ...
+            ...     # Initialize a shared replay buffer
+            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
+            ...
+            ...     # Create a synchronous data collector
+            ...     collector = SyncDataCollector(
+            ...         env,
+            ...         policy=policy,
+            ...         replay_buffer=rb,
+            ...         frames_per_batch=256,
+            ...         total_frames=-1,
+            ...     )
+            ...
+            ...     # Progress bar to track the number of collected frames
+            ...     pbar = tqdm.tqdm(total=100_000)
+            ...
+            ...     # Start the collector asynchronously
+            ...     collector.start()
+            ...
+            ...     # Track the write count of the replay buffer
+            ...     prec_wc = 0
+            ...     while True:
+            ...         wc = rb.write_count
+            ...         c = wc - prec_wc
+            ...         prec_wc = wc
+            ...
+            ...         # Update the progress bar
+            ...         pbar.update(c)
+            ...         pbar.set_description(f"Write Count: {rb.write_count}")
+            ...
+            ...         # Check the write count every 0.5 seconds
+            ...         time.sleep(0.5)
+            ...
+            ...         # Stop when the desired number of frames is reached
+            ...         if rb.write_count > 100_000:
+            ...             break
+            ...
+            ...     # Shut down the collector
+            ...     collector.async_shutdown()
+        """
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for execution.")
+        if not self.is_running():
+            self._stop = False
+            self._thread = threading.Thread(target=self._run_iterator)
+            self._thread.daemon = (
+                True  # So that the thread dies when the main program exits
+            )
+            self._thread.start()
+
+    def _run_iterator(self):
+        for _ in self:
+            if self._stop:
+                return
+
+    def is_running(self):
+        return hasattr(self, "_thread") and self._thread.is_alive()
+
+    def async_shutdown(
+        self, timeout: float | None = None, close_env: bool = True
+    ) -> None:
+        """Stops the background collection thread started by :meth:`start` and shuts down the collector."""
+        self._stop = True
+        if hasattr(self, "_thread") and self._thread.is_alive():
+            self._thread.join(timeout=timeout)
+        self.shutdown(close_env=close_env)
+
+    def _postproc(self, tensordict_out):
+        if self.split_trajs:
+            tensordict_out = split_trajectories(tensordict_out, prefix="collector")
+        if self.postproc is not None:
+            tensordict_out = self.postproc(tensordict_out)
+        if self._exclude_private_keys:
+
+            def is_private(key):
+                if isinstance(key, str) and key.startswith("_"):
+                    return True
+                if isinstance(key, tuple) and any(_key.startswith("_") for _key in key):
+                    return True
+                return False
+
+            excluded_keys = [
+                key for key in tensordict_out.keys(True) if is_private(key)
+            ]
+            tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True)
+        return tensordict_out
+
+    def _update_traj_ids(self, env_output) -> None:
+        # we can't use the reset keys because they're gone
+        traj_sop = _aggregate_end_of_traj(
+            env_output.get("next"), done_keys=self.env.done_keys
+        )
+        if traj_sop.any():
+            device = self.storing_device
+
+            traj_ids = self._shuttle.get(("collector", "traj_ids"))
+            if device is not None:
+                traj_ids = traj_ids.to(device)
+                traj_sop = traj_sop.to(device)
+            elif traj_sop.device != traj_ids.device:
+                traj_sop = traj_sop.to(traj_ids.device)
+
+            pool = self._traj_pool
+            new_traj = pool.get_traj_and_increment(
+                traj_sop.sum(), device=traj_sop.device
+            )
+            traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
+            self._shuttle.set(("collector", "traj_ids"), traj_ids)
+
+    @torch.no_grad()
+    def rollout(self) -> TensorDictBase:
+        """Computes a rollout in the environment using the provided policy.
+
+        Returns:
+            TensorDictBase containing the computed rollout.
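+            When a replay buffer is used with ``extend_buffer=False``, frames are
+            written to the buffer as they are collected and ``None`` is returned
+            instead of a TensorDict.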
+ + """ + if self.reset_at_each_iter: + self._shuttle.update(self.env.reset()) + + # self._shuttle.fill_(("collector", "step_count"), 0) + if self._use_buffers: + self._final_rollout.fill_(("collector", "traj_ids"), -1) + else: + pass + tensordicts = [] + with set_exploration_type(self.exploration_type): + for t in range(self.frames_per_batch): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.env.rand_action(self._shuttle) + if ( + self.policy_device is not None + and self.policy_device != self.env_device + ): + # TODO: This may break with exclusive / ragged lazy stacks + self._shuttle.apply( + lambda name, val: val.to( + device=self.policy_device, non_blocking=True + ) + if name in self._policy_output_keys + else val, + out=self._shuttle, + named=True, + nested_keys=True, + ) + else: + if self._cast_to_policy_device: + if self.policy_device is not None: + # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking + non_blocking = ( + not self.no_cuda_sync + or self.policy_device.type == "cuda" + ) + policy_input = self._shuttle.to( + self.policy_device, + non_blocking=non_blocking, + ) + if not self.no_cuda_sync: + self._sync_policy() + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._shuttle + else: + policy_input = self._shuttle + # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() + if self._shuttle is not policy_output: + # ad-hoc update shuttle + self._shuttle.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_env_device: + if self.env_device is not None: + non_blocking = ( + not self.no_cuda_sync or self.env_device.type == "cuda" + ) + env_input = self._shuttle.to( + self.env_device, non_blocking=non_blocking + ) + if not self.no_cuda_sync: + self._sync_env() + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._shuttle + else: + env_input = self._shuttle + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._shuttle is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._shuttle.set("next", next_data) + + if self.verbose: + torchrl_logger.info( + f"Collector: Rollout step completed {self._iter=}." + ) + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + if self.verbose: + torchrl_logger.info( + f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." + ) + self.replay_buffer.add(self._shuttle) + if self._increment_frames(self._shuttle.numel()): + return + else: + if self.storing_device is not None: + if self.verbose: + torchrl_logger.info( + f"Collector: Moving to {self.storing_device} and adding to queue." 
+ ) + non_blocking = ( + not self.no_cuda_sync or self.storing_device.type == "cuda" + ) + tensordicts.append( + self._shuttle.to( + self.storing_device, non_blocking=non_blocking + ) + ) + if not self.no_cuda_sync: + self._sync_storage() + else: + if self.verbose: + torchrl_logger.info( + "Collector: Adding to queue (no device)." + ) + tensordicts.append(self._shuttle) + + # carry over collector data without messing up devices + collector_data = self._shuttle.get("collector").copy() + self._shuttle = env_next_output + if self._shuttle_has_no_device: + self._shuttle.clear_device_() + self._shuttle.set("collector", collector_data) + self._update_traj_ids(env_output) + + if ( + self.interruptor is not None + and self.interruptor.collection_stopped() + ): + if self.verbose: + torchrl_logger.info("Collector: Interruptor stopped.") + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + result = self._final_rollout + if self._use_buffers: + try: + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + except RuntimeError: + with self._final_rollout.unlock_(): + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + else: + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + break + else: + if self._use_buffers: + torchrl_logger.info("Returning final rollout within buffer.") + result = self._final_rollout + try: + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + + except RuntimeError: + with self._final_rollout.unlock_(): + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + else: + torchrl_logger.info( + "Returning final rollout with NO buffer (maybe_dense_stack)." 
+ ) + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + result.refine_names(..., "time") + + return self._maybe_set_truncated(result) + + def _maybe_set_truncated(self, final_rollout): + last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) + for truncated_key in self._truncated_keys: + truncated = final_rollout["next", truncated_key] + truncated[last_step] = True + final_rollout["next", truncated_key] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) + return final_rollout + + @torch.no_grad() + def reset(self, index=None, **kwargs) -> None: + """Resets the environments to a new initial state.""" + # metadata + collector_metadata = self._shuttle.get("collector").clone() + if index is not None: + # check that the env supports partial reset + if prod(self.env.batch_size) == 0: + raise RuntimeError("resetting unique env with index is not permitted.") + for reset_key, done_keys in zip( + self.env.reset_keys, self.env.done_keys_groups + ): + _reset = torch.zeros( + self.env.full_done_spec[done_keys[0]].shape, + dtype=torch.bool, + device=self.env.device, + ) + _reset[index] = 1 + self._shuttle.set(reset_key, _reset) + else: + _reset = None + self._shuttle.zero_() + + self._shuttle.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._shuttle["collector"] = collector_metadata + + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + """Shuts down all workers and/or closes the local environment. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + No effect for this class. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. + """ + try: + if not self.closed: + self.closed = True + del self._shuttle + if self._use_buffers: + del self._final_rollout + if close_env and not self.env.is_closed: + self.env.close(raise_if_closed=raise_on_error) + del self.env + return + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def __del__(self): + try: + self.shutdown() + except Exception: + # an AttributeError will typically be raised if the collector is deleted when the program ends. + # In the future, insignificant changes to the close method may change the error type. + # We excplicitely assume that any error raised during closure in + # __del__ will not affect the program. + pass + + def state_dict(self) -> OrderedDict: + """Returns the local state_dict of the data collector (environment and policy). + + Returns: + an ordered dictionary with fields :obj:`"policy_state_dict"` and + `"env_state_dict"`. 
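+
+        Examples:
+            >>> # Minimal save/restore round-trip (assuming ``collector`` is an
+            >>> # already-instantiated SyncDataCollector):
+            >>> sd = collector.state_dict()
+            >>> collector.load_state_dict(sd)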
+ + """ + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, TransformedEnv): + env_state_dict = self.env.transform.state_dict() + elif isinstance(self.env, BatchedEnvBase): + env_state_dict = self.env.state_dict() + else: + env_state_dict = OrderedDict() + + if hasattr(self, "_policy_w_state_dict"): + policy_state_dict = self._policy_w_state_dict.state_dict() + state_dict = OrderedDict( + policy_state_dict=policy_state_dict, + env_state_dict=env_state_dict, + ) + else: + state_dict = OrderedDict(env_state_dict=env_state_dict) + + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + """Loads a state_dict on the environment and policy. + + Args: + state_dict (OrderedDict): ordered dictionary containing the fields + `"policy_state_dict"` and :obj:`"env_state_dict"`. + + """ + strict = kwargs.get("strict", True) + if strict or "env_state_dict" in state_dict: + self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) + if strict or "policy_state_dict" in state_dict: + if not hasattr(self, "_policy_w_state_dict"): + raise ValueError( + "Underlying policy does not have state_dict to load policy_state_dict into." + ) + self._policy_w_state_dict.load_state_dict( + state_dict["policy_state_dict"], **kwargs + ) + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def __repr__(self) -> str: + try: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") + td_out_str = repr(getattr(self, "_final_rollout", None)) + if len(td_out_str) > 50: + td_out_str = td_out_str[:50] + "..." + td_out_str = indent(f"td_out={td_out_str}", 4 * " ") + string = ( + f"{self.__class__.__name__}(" + f"\n{env_str}," + f"\n{policy_str}," + f"\n{td_out_str}," + f"\nexploration={self.exploration_type})" + ) + return string + except Exception: + return f"{type(self).__name__}(not_init)" + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy.""" + # send command to policy to return the attr + return getattr(self._wrapped_policy, attr) + + def getattr_env(self, attr): + """Get an attribute from the environment.""" + # send command to env to return the attr + return getattr(self.env, attr) + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + # send command to rb to return the attr + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). 
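+
+        ``"policy"`` resolves to the collector's unwrapped policy; any other
+        ``model_id`` is looked up as an attribute of the collector.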
+ + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the unwrapped policy instance for weight synchronization + # The unwrapped policy has the same parameter structure as what's + # extracted in the main process, avoiding key mismatches when + # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) + if hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") diff --git a/torchrl/collectors/_single_async.py b/torchrl/collectors/_single_async.py new file mode 100644 index 00000000000..131c913b184 --- /dev/null +++ b/torchrl/collectors/_single_async.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Callable, Sequence +from typing import Any + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModule + +from torchrl._utils import accept_remote_rref_udf_invocation +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase + + +@accept_remote_rref_udf_invocation +class aSyncDataCollector(MultiaSyncDataCollector): + """Runs a single DataCollector on a separate process. + + This is mostly useful for offline RL paradigms where the policy being + trained can differ from the policy used to collect data. In online + settings, a regular DataCollector should be preferred. This class is + merely a wrapper around a MultiaSyncDataCollector where a single process + is being created. + + Args: + create_env_fn (Callabled): Callable returning an instance of EnvBase + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. 
+ + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the + total number of elements in a batch. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + max_frames_per_traj (int, optional): Maximum steps per trajectory. 
+            Note that a trajectory can span across multiple batches (unless
+            ``reset_at_each_iter`` is set to ``True``, see below).
+            Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
+            If the environment wraps multiple environments together, the number
+            of steps is tracked for each environment independently. Negative
+            values are allowed, in which case this argument is ignored.
+            Defaults to ``None`` (i.e. no maximum number of steps).
+        init_random_frames (int, optional): Number of frames for which the
+            policy is ignored before it is called. This feature is mainly
+            intended to be used in offline/model-based settings, where a
+            batch of random trajectories can be used to initialize training.
+            If provided, it will be rounded up to the closest multiple of frames_per_batch.
+            Defaults to ``None`` (i.e. no random frames).
+        reset_at_each_iter (bool, optional): Whether environments should be reset
+            at the beginning of a batch collection.
+            Defaults to ``False``.
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        reset_when_done (bool, optional): if ``True`` (default), an environment
+            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
+            entry will be reset at the corresponding indices.
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
+            will be called before (sync) or after (async) each data collection.
+            Defaults to ``False``.
+        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+            that will be allowed to finish collecting their rollout before the rest are forced to end early.
+        num_threads (int, optional): number of threads for this process.
+            Defaults to the number of workers.
+        num_sub_threads (int, optional): number of threads of the subprocesses.
+            Should be equal to one plus the number of processes launched within
+            each subprocess (or one if a single process is launched).
+            Defaults to 1 for safety: if none is indicated, launching multiple
+            workers may overload the CPU and harm performance.
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+ Defaults to `False`. + + """ + + def __init__( + self, + create_env_fn: Callable[[], EnvBase], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict[str, Any]] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + set_truncated: bool = False, + track_policy_version: bool = False, + **kwargs, + ): + super().__init__( + create_env_fn=[create_env_fn], + policy=policy, + policy_factory=policy_factory, + total_frames=total_frames, + create_env_kwargs=[create_env_kwargs] + if create_env_kwargs + else create_env_kwargs, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + init_random_frames=init_random_frames, + postproc=postproc, + split_trajs=split_trajs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, + set_truncated=set_truncated, + track_policy_version=track_policy_version, + **kwargs, + ) + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + return super().shutdown( + timeout=timeout, close_env=close_env, raise_on_error=raise_on_error + ) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) diff --git a/torchrl/collectors/base.py b/torchrl/collectors/base.py new file mode 100644 index 00000000000..1ad97d4056f --- /dev/null +++ b/torchrl/collectors/base.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +import abc +import contextlib +import functools +import typing +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator +from copy import deepcopy +from typing import Any + +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.base import NO_DEFAULT +from tensordict.nn import TensorDictModule, TensorDictModuleBase +from torch import nn as nn +from torch.utils.data import IterableDataset +from torchrl.collectors.utils import _map_weight + +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.weight_update import WeightReceiver, WeightSender, 
WeightSyncScheme + + +class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): + """Base class for data collectors.""" + + _task = None + _iterator = None + total_frames: int + requested_frames_per_batch: int + frames_per_batch: int + trust_policy: bool + compiled_policy: bool + cudagraphed_policy: bool + _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + _weight_senders: dict[str, WeightSender] | None = None + _weight_receivers: dict[str, WeightReceiver] | None = None + verbose: bool = False + + @property + def weight_updater(self) -> WeightUpdaterBase: + return self._weight_updater + + @weight_updater.setter + def weight_updater(self, value: WeightUpdaterBase | None): + if value is not None: + if not isinstance(value, WeightUpdaterBase) and callable( + value + ): # Fall back to default constructor + value = value() + value.register_collector(self) + if value.collector is not self: + raise RuntimeError("Failed to register collector.") + self._weight_updater = value + + def _get_policy_and_device( + self, + policy: Callable[[Any], Any] | None = None, + policy_device: Any = NO_DEFAULT, + env_maker: Any | None = None, + env_maker_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorDictModule, None | Callable[[], dict]]: + """Util method to get a policy and its device given the collector __init__ inputs. + + We want to copy the policy and then move the data there, not call policy.to(device). + + Args: + policy (TensorDictModule, optional): a policy to be used + policy_device (torch.device, optional): the device where the policy should be placed. + Defaults to self.policy_device + env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. + env_maker_kwargs (a dict, optional): the env_maker function kwargs. + + """ + if policy_device is NO_DEFAULT: + policy_device = self.policy_device + + if not policy_device: + return policy, None + + if isinstance(policy, nn.Module): + param_and_buf = TensorDict.from_module(policy, as_module=True) + else: + # Because we want to reach the warning + param_and_buf = TensorDict() + + i = -1 + for p in param_and_buf.values(True, True): + i += 1 + if p.device != policy_device: + # Then we need casting + break + else: + if i == -1 and not self.trust_policy: + # We trust that the policy policy device is adequate + warnings.warn( + "A policy device was provided but no parameter/buffer could be found in " + "the policy. Casting to policy_device is therefore impossible. " + "The collector will trust that the devices match. To suppress this " + "warning, set `trust_policy=True` when building the collector." + ) + return policy, None + + # Create a stateless policy, then populate this copy with params on device + def get_original_weights(policy=policy): + td = TensorDict.from_module(policy) + return td.data + + # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function + with param_and_buf.data.to("meta").to_module(policy): + policy_new_device = deepcopy(policy) + + param_and_buf_new_device = param_and_buf.apply( + functools.partial(_map_weight, policy_device=policy_device), + filter_empty=False, + ) + param_and_buf_new_device.to_module(policy_new_device) + # Sanity check + if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( + get_original_weights().keys(True, True) + ): + raise RuntimeError("Failed to map weights. 
The weight sets mismatch.") + return policy_new_device, get_original_weights + + def start(self): + """Starts the collector for asynchronous data collection. + + This method initiates the background collection of data, allowing for decoupling of data collection and training. + + The collected data is typically stored in a replay buffer passed during the collector's initialization. + + .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` + when you're done with it to free up resources. + + .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. + Ensure you understand the implications for your specific algorithm before using this mode. + + Raises: + NotImplementedError: If not implemented by a subclass. + """ + raise NotImplementedError( + f"Collector start() is not implemented for {type(self).__name__}." + ) + + @contextlib.contextmanager + def pause(self): + """Context manager that pauses the collector if it is running free.""" + raise NotImplementedError( + f"Collector pause() is not implemented for {type(self).__name__}." + ) + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Shuts down the collector when started asynchronously with the `start` method. + + Args: + timeout (float, optional): The maximum time to wait for the collector to shutdown. + close_env (bool, optional): If True, the collector will close the contained environment. + Defaults to `True`. + + .. seealso:: :meth:`~.start` + + """ + return self.shutdown(timeout=timeout, close_env=close_env) + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For the new weight sync scheme system, weight preparation is handled + by the scheme's prepare_weights() method. This method now only handles + legacy weight updater cases. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. + """ + # New weight sync schemes handle preparation themselves + if self._weight_sync_schemes: + # Just pass through - WeightSender will call scheme.prepare_weights() + return weights + + # Legacy weight updater path + return self._legacy_extract_weights(weights, model_id) + + def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: + """Legacy weight extraction for old weight updater system. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier. + + Returns: + Extracted weights. 
+ """ + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + return weights + + @property + def _legacy_weight_updater(self) -> bool: + return self._weight_updater is not None + + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. + + This method ensures that the policy weights used by the data collector are synchronized with the latest + trained weights. It supports both local and remote weight updates, depending on the configuration of the + data collector. The local (download) update is performed before the remote (upload) update, such that weights + can be transferred to the children workers from a server. + + Args: + policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: + - TensorDictModuleBase: A policy module whose weights will be extracted + - TensorDictBase: A TensorDict containing weights + - dict: A regular dict containing weights + - None: Will try to get weights from server using _get_server_weights() + worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the + workers that need to be updated. This is relevant when the collector has more than one worker associated + with it. + model_id (str | None, optional): The model identifier to update. If provided, only updates this specific + model. Cannot be used together with weights_dict. + weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. + Cannot be used together with model_id or policy_or_weights. + + Raises: + TypeError: If `worker_ids` is provided but no `weight_updater` is configured. + ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). + + .. note:: Users should extend the `WeightUpdaterBase` classes to customize + the weight update logic for specific use cases. This method should not be overwritten. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and + :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. 
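+
+        .. note:: A minimal sketch of the supported call patterns; ``collector``, ``policy``
+            and ``policy_weights`` below are assumed to exist already:
+
+            >>> collector.update_policy_weights_()  # let the collector fetch the current weights
+            >>> collector.update_policy_weights_(policy)  # extract and send weights from a module
+            >>> collector.update_policy_weights_(policy, worker_ids=[0, 1])  # restrict to some workers (multi-worker collectors)
+            >>> collector.update_policy_weights_(weights_dict={"policy": policy_weights})  # target registered models explicitly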
+ + """ + if self._legacy_weight_updater: + return self._legacy_weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + else: + return self._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + def _legacy_weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if weights_dict is not None: + raise ValueError("weights_dict is not supported with legacy weight updater") + if model_id is not None: + raise ValueError("model_id is not supported with legacy weight updater") + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + if policy_or_weights is not None: + weights_dict = {"policy": policy_or_weights} + + # Priority: new weight sync schemes > old weight updater system + if self._weight_senders: + if model_id is not None: + # Compose weight_dict + weights_dict = {model_id: policy_or_weights} + if weights_dict is None: + if "policy" in self._weight_senders: + weights_dict = {"policy": policy_or_weights} + elif len(self._weight_senders) == 1: + single_model_id = next(iter(self._weight_senders.keys())) + weights_dict = {single_model_id: policy_or_weights} + else: + raise ValueError( + "Cannot determine the model to update. Please provide a weights_dict." + ) + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_senders: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight senders. 
" + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + # Use new send() API with worker_ids support + self._weight_senders[target_model_id].send( + weights=processed_weights, worker_ids=worker_ids + ) + elif self._weight_updater is not None: + # unreachable + raise RuntimeError + else: + return self.receive_weights(policy_or_weights) + + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + elif ( + hasattr(self, "_original_policy") + and isinstance(self._original_policy, nn.Module) + and hasattr(self, "policy") + and isinstance(self.policy, nn.Module) + ): + # If no weights were provided, mirror weights from the original (trainer) policy + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as="tensordict") + weights = strategy.extract_weights(self._original_policy) + # Cast weights to the policy device before applying + if self.policy_device is not None: + weights = weights.to(self.policy_device) + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible + + def __iter__(self) -> Iterator[TensorDictBase]: + try: + yield from self.iterator() + except Exception: + self.shutdown() + raise + + def next(self): + try: + if self._iterator is None: + self._iterator = iter(self) + out = next(self._iterator) + # if any, we don't want the device ref to be passed in distributed settings + if out is not None and (out.device != "cpu"): + out = out.copy().clear_device_() + return out + except StopIteration: + return None + + @abc.abstractmethod + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def iterator(self) -> Iterator[TensorDictBase]: + raise NotImplementedError + + @abc.abstractmethod + def set_seed(self, seed: int, static_seed: bool = False) -> int: + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self) -> OrderedDict: + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + def _read_compile_kwargs(self, compile_policy, cudagraph_policy): + self.compiled_policy = compile_policy not in (False, None) + self.cudagraphed_policy = cudagraph_policy not in (False, None) + self.compiled_policy_kwargs = ( + {} if not isinstance(compile_policy, typing.Mapping) else compile_policy + ) + self.cudagraphed_policy_kwargs = ( + {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy + ) + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}()" + return string + + def __class_getitem__(self, index): + raise NotImplementedError + + def __len__(self) -> int: + if 
self.total_frames > 0: + return -(self.total_frames // -self.requested_frames_per_batch) + raise RuntimeError("Non-terminating collectors do not have a length") + + def init_updater(self, *args, **kwargs): + """Initialize the weight updater with custom arguments. + + This method passes the arguments to the weight updater's init method. + If no weight updater is set, this is a no-op. + + Args: + *args: Positional arguments for weight updater initialization + **kwargs: Keyword arguments for weight updater initialization + """ + if self.weight_updater is not None: + self.weight_updater.init(*args, **kwargs) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3686368ae71..d0f1c1f765a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,4973 +2,46 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""Re-exports of collector classes for backward compatibility.""" from __future__ import annotations -import _pickle -import abc -import collections -import contextlib -import functools -import os -import queue -import sys -import threading -import time -import typing -import warnings -from collections import defaultdict, OrderedDict -from collections.abc import Callable, Iterator, Mapping, Sequence -from copy import deepcopy -from multiprocessing import connection, queues -from multiprocessing.managers import SyncManager -from queue import Empty -from textwrap import indent -from typing import Any, TypeVar - -import numpy as np -import torch -import torch.nn as nn - -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict.base import NO_DEFAULT -from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase -from tensordict.utils import _zip_strict, Buffer -from torch import multiprocessing as mp -from torch.nn import Parameter -from torch.utils.data import IterableDataset - -from torchrl._utils import ( - _check_for_faulty_process, - _ends_with, - _make_ordinal_device, - _ProcessNoWarn, - _replace_last, - accept_remote_rref_udf_invocation, - compile_with_warmup, - logger as torchrl_logger, - prod, - RL_WARNINGS, - VERBOSE, -) -from torchrl.collectors.utils import split_trajectories -from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.data import ReplayBuffer -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import _do_nothing, EnvBase -from torchrl.envs.env_creator import EnvCreator - -from torchrl.envs.llm.transforms.policy_version import PolicyVersion -from torchrl.envs.transforms import StepCounter, TransformedEnv -from torchrl.envs.utils import ( - _aggregate_end_of_traj, - _make_compatible_policy, - ExplorationType, - RandomPolicy, - set_exploration_type, -) -from torchrl.weight_update import SharedMemWeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - MultiProcessWeightSyncScheme, - WeightReceiver, - WeightSender, - WeightSyncScheme, +# Re-export constants for backward compatibility +from torchrl.collectors._constants import ( + _Interruptor, + _InterruptorManager, + _is_osx, + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + INSTANTIATE_TIMEOUT, ) -try: - from torch.compiler import cudagraph_mark_step_begin -except ImportError: - - def cudagraph_mark_step_begin(): - """Placeholder for missing cudagraph_mark_step_begin method.""" - 
raise NotImplementedError("cudagraph_mark_step_begin not implemented.") - - -_TIMEOUT = 1.0 -INSTANTIATE_TIMEOUT = 20 -_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory -# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. -_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) - -DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM - -_is_osx = sys.platform.startswith("darwin") - -T = TypeVar("T") - - -class _Interruptor: - """A class for managing the collection state of a process. - - This class provides methods to start and stop collection, and to check - whether collection has been stopped. The collection state is protected - by a lock to ensure thread-safety. - """ - - # interrupter vs interruptor: google trends seems to indicate that "or" is more - # widely used than "er" even if my IDE complains about that... - def __init__(self): - self._collect = True - self._lock = mp.Lock() - - def start_collection(self): - with self._lock: - self._collect = True - - def stop_collection(self): - with self._lock: - self._collect = False - - def collection_stopped(self): - with self._lock: - return self._collect is False - - -class _InterruptorManager(SyncManager): - """A custom SyncManager for managing the collection state of a process. - - This class extends the SyncManager class and allows to share an Interruptor object - between processes. - """ - - -_InterruptorManager.register("_Interruptor", _Interruptor) - - -def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: - """Maps the tensors to CPU through a nested dictionary.""" - return OrderedDict( - **{ - k: recursive_map_to_cpu(item) - if isinstance(item, OrderedDict) - else item.cpu() - if isinstance(item, torch.Tensor) - else item - for k, item in dictionary.items() - } - ) - - -class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): - """Base class for data collectors.""" - - _task = None - _iterator = None - total_frames: int - requested_frames_per_batch: int - frames_per_batch: int - trust_policy: bool - compiled_policy: bool - cudagraphed_policy: bool - _weight_updater: WeightUpdaterBase | None = None - _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None - verbose: bool = False - - @property - def weight_updater(self) -> WeightUpdaterBase: - return self._weight_updater - - @weight_updater.setter - def weight_updater(self, value: WeightUpdaterBase | None): - if value is not None: - if not isinstance(value, WeightUpdaterBase) and callable( - value - ): # Fall back to default constructor - value = value() - value.register_collector(self) - if value.collector is not self: - raise RuntimeError("Failed to register collector.") - self._weight_updater = value - - def _get_policy_and_device( - self, - policy: Callable[[Any], Any] | None = None, - policy_device: Any = NO_DEFAULT, - env_maker: Any | None = None, - env_maker_kwargs: dict[str, Any] | None = None, - ) -> tuple[TensorDictModule, None | Callable[[], dict]]: - """Util method to get a policy and its device given the collector __init__ inputs. - - We want to copy the policy and then move the data there, not call policy.to(device). - - Args: - policy (TensorDictModule, optional): a policy to be used - policy_device (torch.device, optional): the device where the policy should be placed. 
- Defaults to self.policy_device - env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. - env_maker_kwargs (a dict, optional): the env_maker function kwargs. - - """ - if policy_device is NO_DEFAULT: - policy_device = self.policy_device - - if not policy_device: - return policy, None - - if isinstance(policy, nn.Module): - param_and_buf = TensorDict.from_module(policy, as_module=True) - else: - # Because we want to reach the warning - param_and_buf = TensorDict() - - i = -1 - for p in param_and_buf.values(True, True): - i += 1 - if p.device != policy_device: - # Then we need casting - break - else: - if i == -1 and not self.trust_policy: - # We trust that the policy policy device is adequate - warnings.warn( - "A policy device was provided but no parameter/buffer could be found in " - "the policy. Casting to policy_device is therefore impossible. " - "The collector will trust that the devices match. To suppress this " - "warning, set `trust_policy=True` when building the collector." - ) - return policy, None - - # Create a stateless policy, then populate this copy with params on device - def get_original_weights(policy=policy): - td = TensorDict.from_module(policy) - return td.data - - # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function - with param_and_buf.data.to("meta").to_module(policy): - policy_new_device = deepcopy(policy) - - param_and_buf_new_device = param_and_buf.apply( - functools.partial(_map_weight, policy_device=policy_device), - filter_empty=False, - ) - param_and_buf_new_device.to_module(policy_new_device) - # Sanity check - if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( - get_original_weights().keys(True, True) - ): - raise RuntimeError("Failed to map weights. The weight sets mismatch.") - return policy_new_device, get_original_weights - - def start(self): - """Starts the collector for asynchronous data collection. - - This method initiates the background collection of data, allowing for decoupling of data collection and training. - - The collected data is typically stored in a replay buffer passed during the collector's initialization. - - .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` - when you're done with it to free up resources. - - .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. - Ensure you understand the implications for your specific algorithm before using this mode. - - Raises: - NotImplementedError: If not implemented by a subclass. - """ - raise NotImplementedError( - f"Collector start() is not implemented for {type(self).__name__}." - ) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - raise NotImplementedError( - f"Collector pause() is not implemented for {type(self).__name__}." - ) - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Shuts down the collector when started asynchronously with the `start` method. - - Args: - timeout (float, optional): The maximum time to wait for the collector to shutdown. - close_env (bool, optional): If True, the collector will close the contained environment. - Defaults to `True`. - - .. 
seealso:: :meth:`~.start` - - """ - return self.shutdown(timeout=timeout, close_env=close_env) - - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For the new weight sync scheme system, weight preparation is handled - by the scheme's prepare_weights() method. This method now only handles - legacy weight updater cases. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier for resolving string paths. - - Returns: - Extracted weights in the appropriate format. - """ - # New weight sync schemes handle preparation themselves - if self._weight_sync_schemes: - # Just pass through - WeightSender will call scheme.prepare_weights() - return weights - - # Legacy weight updater path - return self._legacy_extract_weights(weights, model_id) - - def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: - """Legacy weight extraction for old weight updater system. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier. - - Returns: - Extracted weights. - """ - if weights is None: - if model_id == "policy" and hasattr(self, "policy_weights"): - return self.policy_weights - elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): - policy_device = ( - self.policy_device - if not isinstance(self.policy_device, (list, tuple)) - else self.policy_device[0] - ) - return self._policy_weights_dict.get(policy_device) - return None - - return weights - - @property - def _legacy_weight_updater(self) -> bool: - return self._weight_updater is not None - - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. - - This method ensures that the policy weights used by the data collector are synchronized with the latest - trained weights. It supports both local and remote weight updates, depending on the configuration of the - data collector. The local (download) update is performed before the remote (upload) update, such that weights - can be transferred to the children workers from a server. - - Args: - policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: - - TensorDictModuleBase: A policy module whose weights will be extracted - - TensorDictBase: A TensorDict containing weights - - dict: A regular dict containing weights - - None: Will try to get weights from server using _get_server_weights() - worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the - workers that need to be updated. This is relevant when the collector has more than one worker associated - with it. - model_id (str | None, optional): The model identifier to update. If provided, only updates this specific - model. Cannot be used together with weights_dict. - weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating - multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. - Cannot be used together with model_id or policy_or_weights. 
- - Raises: - TypeError: If `worker_ids` is provided but no `weight_updater` is configured. - ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). - - .. note:: Users should extend the `WeightUpdaterBase` classes to customize - the weight update logic for specific use cases. This method should not be overwritten. - - .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and - :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. - - """ - if self._legacy_weight_updater: - return self._legacy_weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - else: - return self._weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - - def _legacy_weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if weights_dict is not None: - raise ValueError("weights_dict is not supported with legacy weight updater") - if model_id is not None: - raise ValueError("model_id is not supported with legacy weight updater") - # Fall back to old weight updater system - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def _weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - if weights_dict is not None and model_id is not None: - raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") - - if weights_dict is not None and policy_or_weights is not None: - raise ValueError( - "Cannot specify both 'weights_dict' and 'policy_or_weights'" - ) - - if policy_or_weights is not None: - weights_dict = {"policy": policy_or_weights} - - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: - # Compose weight_dict - weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) - for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: - raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. 
" - f"Available models: {list(self._weight_senders.keys())}" - ) - processed_weights = self._extract_weights_if_needed( - weights, target_model_id - ) - # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids - ) - elif self._weight_updater is not None: - # unreachable - raise RuntimeError - else: - return self.receive_weights(policy_or_weights) - - def receive_weights(self, policy_or_weights: TensorDictBase | None = None): - # No weight updater configured - # For single-process collectors, apply weights locally if explicitly provided - if policy_or_weights is not None: - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - # Use WeightStrategy to apply weights properly - strategy = WeightStrategy(extract_as="tensordict") - - # Extract weights if needed - if isinstance(policy_or_weights, nn.Module): - weights = strategy.extract_weights(policy_or_weights) - else: - weights = policy_or_weights - - # Apply to local policy - if hasattr(self, "policy") and isinstance(self.policy, nn.Module): - strategy.apply_weights(self.policy, weights) - elif ( - hasattr(self, "_original_policy") - and isinstance(self._original_policy, nn.Module) - and hasattr(self, "policy") - and isinstance(self.policy, nn.Module) - ): - # If no weights were provided, mirror weights from the original (trainer) policy - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as="tensordict") - weights = strategy.extract_weights(self._original_policy) - # Cast weights to the policy device before applying - if self.policy_device is not None: - weights = weights.to(self.policy_device) - strategy.apply_weights(self.policy, weights) - # Otherwise, no action needed - policy is local and changes are immediately visible - - def __iter__(self) -> Iterator[TensorDictBase]: - try: - yield from self.iterator() - except Exception: - self.shutdown() - raise - - def next(self): - try: - if self._iterator is None: - self._iterator = iter(self) - out = next(self._iterator) - # if any, we don't want the device ref to be passed in distributed settings - if out is not None and (out.device != "cpu"): - out = out.copy().clear_device_() - return out - except StopIteration: - return None - - @abc.abstractmethod - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - raise NotImplementedError - - @abc.abstractmethod - def iterator(self) -> Iterator[TensorDictBase]: - raise NotImplementedError - - @abc.abstractmethod - def set_seed(self, seed: int, static_seed: bool = False) -> int: - raise NotImplementedError - - @abc.abstractmethod - def state_dict(self) -> OrderedDict: - raise NotImplementedError - - @abc.abstractmethod - def load_state_dict(self, state_dict: OrderedDict) -> None: - raise NotImplementedError - - def _read_compile_kwargs(self, compile_policy, cudagraph_policy): - self.compiled_policy = compile_policy not in (False, None) - self.cudagraphed_policy = cudagraph_policy not in (False, None) - self.compiled_policy_kwargs = ( - {} if not isinstance(compile_policy, typing.Mapping) else compile_policy - ) - self.cudagraphed_policy_kwargs = ( - {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy - ) - - def __repr__(self) -> str: - string = f"{self.__class__.__name__}()" - return string - - def __class_getitem__(self, index): - raise NotImplementedError - - def __len__(self) -> int: - if 
self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) - raise RuntimeError("Non-terminating collectors do not have a length") - - def init_updater(self, *args, **kwargs): - """Initialize the weight updater with custom arguments. - - This method passes the arguments to the weight updater's init method. - If no weight updater is set, this is a no-op. - - Args: - *args: Positional arguments for weight updater initialization - **kwargs: Keyword arguments for weight updater initialization - """ - if self.weight_updater is not None: - self.weight_updater.init(*args, **kwargs) - - -@accept_remote_rref_udf_invocation -class SyncDataCollector(DataCollectorBase): - """Generic data collector for RL problems. Requires an environment constructor and a policy. - - Args: - create_env_fn (Callable or EnvBase): a callable that returns an instance of - :class:`~torchrl.envs.EnvBase` class, or the env itself. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the total - number of elements in a batch. - total_frames (int): A keyword-only argument representing the total - number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (endless collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. 
- If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - create_env_kwargs (dict, optional): Dictionary of kwargs for - ``create_env_fn``. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e., no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - - .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer - as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. - - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. 
Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - return_same_td (bool, optional): if ``True``, the same TensorDict - will be returned at each iteration, with its values - updated. This feature should be used cautiously: if the same - tensordict is added to a replay buffer for instance, - the whole content of the buffer will be identical. - Default is ``False``. - interruptor (_Interruptor, optional): - An _Interruptor object that can be used from outside the class to control rollout collection. - The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement - strategies such as preeptively stopping rollout collection. - Default is ``False``. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. - Defaults to ``None``. - - .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. - If the buffer needs to be populated with individual frames as they are collected, - set ``extend_buffer=False`` (deprecated). - - .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires - `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. - - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True`. - - .. note:: Setting this to `False` is deprecated and will be removed in a future version. - Extending the buffer with entire rollouts is the recommended approach for better - compatibility with postprocessing and trajectory splitting. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) cuda synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. 
- This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. - Consider using a constructor if the updater needs to be serialized. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector( - ... create_env_fn=env_maker, - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - >>> del collector - - The collector delivers batches of data that are marked with a ``"time"`` - dimension. 
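The device arguments described above can be mixed within a single collector. As a minimal sketch (assuming a CUDA device is available; ``env_maker`` and ``policy`` are the objects defined in the example above), the policy can run on GPU while the environment and the stored rollouts stay on CPU:

    >>> collector = SyncDataCollector(
    ...     create_env_fn=env_maker,
    ...     policy=policy,
    ...     total_frames=2000,
    ...     frames_per_batch=200,
    ...     policy_device="cuda:0",
    ...     env_device="cpu",
    ...     storing_device="cpu",
    ... )

The batches delivered by the collector carry the ``"time"`` dimension name mentioned above: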
- - Examples: - >>> assert data.names[-1] == "time" - - """ - - _ignore_rb: bool = False - - def __init__( - self, - create_env_fn: ( - EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 - ), # noqa: F821 - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int = -1, - device: DEVICE_TYPING | None = None, - storing_device: DEVICE_TYPING | None = None, - policy_device: DEVICE_TYPING | None = None, - env_device: DEVICE_TYPING | None = None, - create_env_kwargs: dict[str, Any] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - return_same_td: bool = False, - reset_when_done: bool = True, - interruptor=None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - **kwargs, - ): - self.closed = True - - # Initialize environment - env = self._init_env(create_env_fn, create_env_kwargs) - - # Initialize policy - policy = self._init_policy(policy, policy_factory, env, trust_policy) - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Handle trajectory pool and validate kwargs - self._traj_pool_val = kwargs.pop("traj_pool", None) - if kwargs: - raise TypeError( - f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
- ) - - # Set up devices and synchronization - self._setup_devices( - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - no_cuda_sync=no_cuda_sync, - ) - - self.env: EnvBase = env - del env - - # Set up policy version tracking - self._setup_policy_version_tracking(track_policy_version) - - # Set up replay buffer - self._setup_replay_buffer( - replay_buffer=replay_buffer, - extend_buffer=extend_buffer, - local_init_rb=local_init_rb, - postproc=postproc, - split_trajs=split_trajs, - return_same_td=return_same_td, - use_buffers=use_buffers, - ) - - self.closed = False - - # Validate reset_when_done - if not reset_when_done: - raise ValueError("reset_when_done is deprecated.") - self.reset_when_done = reset_when_done - self.n_env = self.env.batch_size.numel() - - # Register collector with policy and env - if hasattr(policy, "register_collector"): - policy.register_collector(self) - if hasattr(self.env, "register_collector"): - self.env.register_collector(self) - - # Set up policy and weights - self._setup_policy_and_weights(policy) - - # Apply environment device - self._apply_env_device() - - # Set up max frames per trajectory - self._setup_max_frames_per_traj(max_frames_per_traj) - - # Validate and set total frames - self.reset_at_each_iter = reset_at_each_iter - self._setup_total_frames(total_frames, frames_per_batch) - - # Set up init random frames - self._setup_init_random_frames(init_random_frames, frames_per_batch) - - # Set up postproc - self._setup_postproc(postproc) - - # Calculate frames per batch - self._setup_frames_per_batch(frames_per_batch) - - # Set exploration and other options - self.exploration_type = ( - exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE - ) - self.return_same_td = return_same_td - self.set_truncated = set_truncated - - # Create shuttle and rollout buffers - self._make_shuttle() - self._maybe_make_final_rollout(make_rollout=self._use_buffers) - self._set_truncated_keys() - - # Set split trajectories option - if split_trajs is None: - split_trajs = False - self.split_trajs = split_trajs - self._exclude_private_keys = True - - # Set up interruptor and frame tracking - self.interruptor = interruptor - self._frames = 0 - self._iter = -1 - - # Set up weight synchronization - self._setup_weight_sync(weight_updater, weight_sync_schemes) - - def _init_env( - self, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], - create_env_kwargs: dict[str, Any] | None, - ) -> EnvBase: - """Initialize and configure the environment.""" - from torchrl.envs.batched_envs import BatchedEnvBase - - if create_env_kwargs is None: - create_env_kwargs = {} - - if not isinstance(create_env_fn, EnvBase): - env = create_env_fn(**create_env_kwargs) - else: - env = create_env_fn - if create_env_kwargs: - if not isinstance(env, BatchedEnvBase): - raise RuntimeError( - "kwargs were passed to SyncDataCollector but they can't be set " - f"on environment of type {type(create_env_fn)}." 
- ) - env.update_kwargs(create_env_kwargs) - return env - - def _init_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: Callable[[], Callable] | None, - env: EnvBase, - trust_policy: bool | None, - ) -> TensorDictModule | Callable: - """Initialize and configure the policy.""" - if policy is None: - if policy_factory is not None: - policy = policy_factory() - else: - policy = RandomPolicy(env.full_action_spec) - elif policy_factory is not None: - raise TypeError("policy_factory cannot be used with policy argument.") - - # If the underlying policy has a state_dict, keep a reference to it - if hasattr(policy, "state_dict"): - self._policy_w_state_dict = policy - - if trust_policy is None: - trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) - self.trust_policy = trust_policy - - return policy - - def _setup_devices( - self, - device: DEVICE_TYPING | None, - storing_device: DEVICE_TYPING | None, - policy_device: DEVICE_TYPING | None, - env_device: DEVICE_TYPING | None, - no_cuda_sync: bool, - ) -> None: - """Set up devices and synchronization functions.""" - storing_device, policy_device, env_device = self._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - - self.storing_device = storing_device - self._sync_storage = self._get_sync_fn(storing_device) - - self.env_device = env_device - self._sync_env = self._get_sync_fn(env_device) - - self.policy_device = policy_device - self._sync_policy = self._get_sync_fn(policy_device) - - self.device = device - self.no_cuda_sync = no_cuda_sync - self._cast_to_policy_device = self.policy_device != self.env_device - - def _get_sync_fn(self, device: torch.device | None) -> Callable: - """Get the appropriate synchronization function for a device.""" - if device is not None and device.type != "cuda": - # Cuda handles sync - if torch.cuda.is_available(): - return torch.cuda.synchronize - elif torch.backends.mps.is_available() and hasattr(torch, "mps"): - return torch.mps.synchronize - elif hasattr(torch, "npu") and torch.npu.is_available(): - return torch.npu.synchronize - elif device.type == "cpu": - return _do_nothing - else: - raise RuntimeError("Non supported device") - else: - return _do_nothing - - def _setup_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking if requested.""" - self.policy_version_tracker = track_policy_version - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." 
- ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - - def _setup_replay_buffer( - self, - replay_buffer: ReplayBuffer | None, - extend_buffer: bool, - local_init_rb: bool | None, - postproc: Callable | None, - split_trajs: bool | None, - return_same_td: bool, - use_buffers: bool | None, - ) -> None: - """Set up replay buffer configuration and validate compatibility.""" - self.replay_buffer = replay_buffer - self.extend_buffer = extend_buffer - - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - # Validate replay buffer compatibility - if self.replay_buffer is not None and not self._ignore_rb: - if postproc is not None and not self.extend_buffer: - raise TypeError( - "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." - ) - if split_trajs not in (None, False) and not self.extend_buffer: - raise TypeError( - "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if return_same_td: - raise TypeError( - "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if use_buffers: - raise TypeError("replay_buffer is exclusive with use_buffers.") - - if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None - self._use_buffers = use_buffers - - def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: - """Set up policy, wrapped policy, and extract weights.""" - self._original_policy = policy - policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." 
- ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights - if isinstance(self._wrapped_policy, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy, as_module=True - ).data - else: - self.policy_weights = TensorDict() - - # Apply compilation/cudagraph - if self.compiled_policy: - self._wrapped_policy = compile_with_warmup( - self._wrapped_policy, **self.compiled_policy_kwargs - ) - if self.cudagraphed_policy: - self._wrapped_policy = CudaGraphModule( - self._wrapped_policy, - in_keys=[], - out_keys=[], - device=self.policy_device, - **self.cudagraphed_policy_kwargs, - ) - - def _apply_env_device(self) -> None: - """Apply device to environment if specified.""" - if self.env_device: - self.env: EnvBase = self.env.to(self.env_device) - elif self.env.device is not None: - # Use the device of the env if none was provided - self.env_device = self.env.device - - # Check if we need to cast to env device - self._cast_to_env_device = self._cast_to_policy_device or ( - self.env.device != self.storing_device - ) - - def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: - """Set up maximum frames per trajectory and add StepCounter if needed.""" - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: - # Check that there is no StepCounter yet - for key in self.env.output_spec.keys(True, True): - if isinstance(key, str): - key = (key,) - if "step_count" in key: - raise ValueError( - "A 'step_count' key is already present in the environment " - "and the 'max_frames_per_traj' argument may conflict with " - "a 'StepCounter' that has already been set. " - "Possible solutions: Set max_frames_per_traj to 0 or " - "remove the StepCounter limit from the environment transforms." - ) - self.env = TransformedEnv( - self.env, StepCounter(max_steps=self.max_frames_per_traj) - ) - - def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: - """Validate and set total frames.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % frames_per_batch - if remainder != 0 and RL_WARNINGS: - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " - f"This means {frames_per_batch - remainder} additional frames will be collected." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_init_random_frames( - self, init_random_frames: int | None, frames_per_batch: int - ) -> None: - """Set up initial random frames.""" - self.init_random_frames = ( - int(init_random_frames) if init_random_frames not in (None, -1) else 0 - ) - if ( - init_random_frames not in (-1, None, 0) - and init_random_frames % frames_per_batch != 0 - and RL_WARNINGS - ): - warnings.warn( - f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " - f" this results in more init_random_frames than requested" - f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." - "To silence this message, set the environment variable RL_WARNINGS to False." 
- ) - - def _setup_postproc(self, postproc: Callable | None) -> None: - """Set up post-processing transform.""" - self.postproc = postproc - if ( - self.postproc is not None - and hasattr(self.postproc, "to") - and self.storing_device - ): - postproc = self.postproc.to(self.storing_device) - if postproc is not self.postproc and postproc is not None: - self.postproc = postproc - - def _setup_frames_per_batch(self, frames_per_batch: int) -> None: - """Calculate and validate frames per batch.""" - if frames_per_batch % self.n_env != 0 and RL_WARNINGS: - warnings.warn( - f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " - f" this results in more frames_per_batch per iteration that requested" - f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.frames_per_batch = -(-frames_per_batch // self.n_env) - self.requested_frames_per_batch = self.frames_per_batch * self.n_env - - def _setup_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization system.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # For single-process collectors, we don't need senders/receivers - # The policy is local and changes are immediately visible - # Senders will be set up in multiprocess collectors during _run_processes - self.weight_updater = None # Don't use legacy system - elif weight_updater is not None: - # Use legacy weight updater system if explicitly provided - if not isinstance(weight_updater, WeightUpdaterBase): - if callable(weight_updater): - weight_updater = weight_updater() - else: - raise TypeError( - f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." - ) - warnings.warn( - "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " - "This will be removed in a future version.", - DeprecationWarning, - stacklevel=2, - ) - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - else: - # No weight sync needed for single-process collectors - self.weight_updater = None - self._weight_sync_schemes = None - self._weight_senders = {} - - @property - def _traj_pool(self): - pool = getattr(self, "_traj_pool_val", None) - if pool is None: - pool = self._traj_pool_val = _TrajectoryPool() - return pool - - def _make_shuttle(self): - # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env - with torch.no_grad(): - self._shuttle = self.env.reset() - if self.policy_device != self.env_device or self.env_device is None: - self._shuttle_has_no_device = True - self._shuttle.clear_device_() - else: - self._shuttle_has_no_device = False - - traj_ids = self._traj_pool.get_traj_and_increment( - self.n_env, device=self.storing_device - ).view(self.env.batch_size) - self._shuttle.set( - ("collector", "traj_ids"), - traj_ids, - ) - - def _maybe_make_final_rollout(self, make_rollout: bool): - if make_rollout: - with torch.no_grad(): - self._final_rollout = self.env.fake_tensordict() - - # If storing device is not None, we use this to cast the storage. 
- # If it is None and the env and policy are on the same device, - # the storing device is already the same as those, so we don't need - # to consider this use case. - # In all other cases, we can't really put a device on the storage, - # since at least one data source has a device that is not clear. - if self.storing_device: - self._final_rollout = self._final_rollout.to( - self.storing_device, non_blocking=True - ) - else: - # erase all devices - self._final_rollout.clear_device_() - - # If the policy has a valid spec, we use it - self._policy_output_keys = set() - if ( - make_rollout - and hasattr(self._wrapped_policy, "spec") - and self._wrapped_policy.spec is not None - and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) - ): - if any( - key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self._wrapped_policy.spec.keys(True, True) - ): - # if policy spec is non-empty, all the values are not None and the keys - # match the out_keys we assume the user has given all relevant information - # the policy could have more keys than the env: - policy_spec = self._wrapped_policy.spec - if policy_spec.ndim < self._final_rollout.ndim: - policy_spec = policy_spec.expand(self._final_rollout.shape) - for key, spec in policy_spec.items(True, True): - self._policy_output_keys.add(key) - if key in self._final_rollout.keys(True): - continue - self._final_rollout.set(key, spec.zero()) - elif ( - not make_rollout - and hasattr(self._wrapped_policy, "out_keys") - and self._wrapped_policy.out_keys - ): - self._policy_output_keys = list(self._wrapped_policy.out_keys) - else: - if make_rollout: - # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _final_rollout. - # This is the safest thing to do if the spec has None fields or if there is - # no spec at all. - # See #505 for additional context. - self._final_rollout.update(self._shuttle.copy()) - with torch.no_grad(): - policy_input = self._shuttle.copy() - if self.policy_device: - policy_input = policy_input.to(self.policy_device) - # we cast to policy device, we'll deal with the device later - policy_input_copy = policy_input.copy() - policy_input_clone = ( - policy_input.clone() - ) # to test if values have changed in-place - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - - # check that we don't have exclusive keys, because they don't appear in keys - def check_exclusive(val): - if ( - isinstance(val, LazyStackedTensorDict) - and val._has_exclusive_keys - ): - raise RuntimeError( - "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " - "Consider using a placeholder for missing keys." - ) - - policy_output._fast_apply( - check_exclusive, call_on_nested=True, filter_empty=True - ) - - # Use apply, because it works well with lazy stacks - # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit - # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has - # changed them here). - # This will cause a failure to update entries when policy and env device mismatch and - # casting is necessary. 
- def filter_policy(name, value_output, value_input, value_input_clone): - if (value_input is None) or ( - (value_output is not value_input) - and ( - value_output.device != value_input_clone.device - or ~torch.isclose(value_output, value_input_clone).any() - ) - ): - return value_output - - filtered_policy_output = policy_output.apply( - filter_policy, - policy_input_copy, - policy_input_clone, - default=None, - filter_empty=True, - named=True, - ) - self._policy_output_keys = list( - self._policy_output_keys.union( - set(filtered_policy_output.keys(True, True)) - ) - ) - if make_rollout: - self._final_rollout.update( - policy_output.select(*self._policy_output_keys) - ) - del filtered_policy_output, policy_output, policy_input - - _env_output_keys = [] - for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: - _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) - self._env_output_keys = _env_output_keys - if make_rollout: - self._final_rollout = ( - self._final_rollout.unsqueeze(-1) - .expand(*self.env.batch_size, self.frames_per_batch) - .clone() - .zero_() - ) - - # in addition to outputs of the policy, we add traj_ids to - # _final_rollout which will be collected during rollout - self._final_rollout.set( - ("collector", "traj_ids"), - torch.zeros( - *self._final_rollout.batch_size, - dtype=torch.int64, - device=self.storing_device, - ), - ) - self._final_rollout.refine_names(..., "time") - - def _set_truncated_keys(self): - self._truncated_keys = [] - if self.set_truncated: - if not any(_ends_with(key, "truncated") for key in self.env.done_keys): - raise RuntimeError( - "set_truncated was set to True but no truncated key could be found " - "in the environment. Make sure the truncated keys are properly set using " - "`env.add_truncated_keys()` before passing the env to the collector." - ) - self._truncated_keys = [ - key for key in self.env.done_keys if _ends_with(key, "truncated") - ] - - @classmethod - def _get_devices( - cls, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - device = _make_ordinal_device(torch.device(device) if device else device) - storing_device = _make_ordinal_device( - torch.device(storing_device) if storing_device else device - ) - policy_device = _make_ordinal_device( - torch.device(policy_device) if policy_device else device - ) - env_device = _make_ordinal_device( - torch.device(env_device) if env_device else device - ) - if storing_device is None and (env_device == policy_device): - storing_device = env_device - return storing_device, policy_device, env_device - - # for RPC - def next(self): - return super().next() - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed (int): integer representing the seed to be used for the environment. 
- static_seed(bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is contained in the DataCollector, as the - seed will be incremented for each of these. The resulting seed is the seed of the last environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - out = self.env.set_seed(seed, static_seed=static_seed) - return out - - def _increment_frames(self, numel): - self._frames += numel - completed = self._frames >= self.total_frames - if completed: - self.env.close() - return completed - - def iterator(self) -> Iterator[TensorDictBase]: - """Iterates through the DataCollector. - - Yields: TensorDictBase objects containing (chunks of) trajectories - - """ - if ( - not self.no_cuda_sync - and self.storing_device - and self.storing_device.type == "cuda" - ): - stream = torch.cuda.Stream(self.storing_device, priority=-1) - event = stream.record_event() - streams = [stream] - events = [event] - elif not self.no_cuda_sync and self.storing_device is None: - streams = [] - events = [] - # this way of checking cuda is robust to lazy stacks with mismatching shapes - cuda_devices = set() - - def cuda_check(tensor: torch.Tensor): - if tensor.is_cuda: - cuda_devices.add(tensor.device) - - if not self._use_buffers: - # This may be a bit dangerous as `torch.device("cuda")` may not have a precise - # device associated, whereas `tensor.device` always has - for spec in self.env.specs.values(True, True): - if spec.device is not None and spec.device.type == "cuda": - if ":" not in str(spec.device): - raise RuntimeError( - "A cuda spec did not have a device associated. Make sure to " - "pass `'cuda:device_num'` to each spec device." - ) - cuda_devices.add(spec.device) - else: - self._final_rollout.apply(cuda_check, filter_empty=True) - for device in cuda_devices: - streams.append(torch.cuda.Stream(device, priority=-1)) - events.append(streams[-1].record_event()) - else: - streams = [] - events = [] - with contextlib.ExitStack() as stack: - for stream in streams: - stack.enter_context(torch.cuda.stream(stream)) - - while self._frames < self.total_frames: - self._iter += 1 - if self.verbose: - torchrl_logger.info("Collector: rollout.") - tensordict_out = self.rollout() - if tensordict_out is None: - # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out - # frames are updated within the rollout function - if self.verbose: - torchrl_logger.info("Collector: No tensordict_out. Yielding.") - yield - continue - self._increment_frames(tensordict_out.numel()) - tensordict_out = self._postproc(tensordict_out) - if self.verbose: - torchrl_logger.info("Collector: postproc done.") - if self.return_same_td: - # This is used with multiprocessed collectors to use the buffers - # stored in the tensordict. 
- if events: - for event in events: - event.record() - event.synchronize() - yield tensordict_out - elif self.replay_buffer is not None and not self._ignore_rb: - self.replay_buffer.extend(tensordict_out) - if self.verbose: - torchrl_logger.info( - f"Collector: Added {tensordict_out.numel()} frames to replay buffer. " - f"Buffer write count: {self.replay_buffer.write_count}. Yielding." - ) - yield - else: - # we must clone the values, as the tensordict is updated in-place. - # otherwise the following code may break: - # >>> for i, data in enumerate(collector): - # >>> if i == 0: - # >>> data0 = data - # >>> elif i == 1: - # >>> data1 = data - # >>> else: - # >>> break - # >>> assert data0["done"] is not data1["done"] - yield tensordict_out.clone() - - def start(self): - """Starts the collector in a separate thread for asynchronous data collection. - - The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data - collection from training, allowing your training loop to run independently of the data collection process. - - Raises: - RuntimeError: If no replay buffer is defined during the collector's initialization. - - Example: - >>> import time - >>> from functools import partial - >>> - >>> import tqdm - >>> - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy - >>> from torchrl.data import LazyTensorStorage, ReplayBuffer - >>> from torchrl.envs import GymEnv, set_gym_backend - >>> import ale_py - >>> - >>> # Set the gym backend to gymnasium - >>> set_gym_backend("gymnasium").set() - >>> - >>> if __name__ == "__main__": - ... # Create a random policy for the Pong environment - ... env = GymEnv("ALE/Pong-v5") - ... policy = RandomPolicy(env.action_spec) - ... - ... # Initialize a shared replay buffer - ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True) - ... - ... # Create a synchronous data collector - ... collector = SyncDataCollector( - ... env, - ... policy=policy, - ... replay_buffer=rb, - ... frames_per_batch=256, - ... total_frames=-1, - ... ) - ... - ... # Progress bar to track the number of collected frames - ... pbar = tqdm.tqdm(total=100_000) - ... - ... # Start the collector asynchronously - ... collector.start() - ... - ... # Track the write count of the replay buffer - ... prec_wc = 0 - ... while True: - ... wc = rb.write_count - ... c = wc - prec_wc - ... prec_wc = wc - ... - ... # Update the progress bar - ... pbar.update(c) - ... pbar.set_description(f"Write Count: {rb.write_count}") - ... - ... # Check the write count every 0.5 seconds - ... time.sleep(0.5) - ... - ... # Stop when the desired number of frames is reached - ... if rb.write_count > 100_000: - ... break - ... - ... # Shut down the collector - ...
collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if not self.is_running(): - self._stop = False - self._thread = threading.Thread(target=self._run_iterator) - self._thread.daemon = ( - True # So that the thread dies when the main program exits - ) - self._thread.start() - - def _run_iterator(self): - for _ in self: - if self._stop: - return - - def is_running(self): - return hasattr(self, "_thread") and self._thread.is_alive() - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Finishes processes started by ray.init() during async execution.""" - self._stop = True - if hasattr(self, "_thread") and self._thread.is_alive(): - self._thread.join(timeout=timeout) - self.shutdown(close_env=close_env) - - def _postproc(self, tensordict_out): - if self.split_trajs: - tensordict_out = split_trajectories(tensordict_out, prefix="collector") - if self.postproc is not None: - tensordict_out = self.postproc(tensordict_out) - if self._exclude_private_keys: - - def is_private(key): - if isinstance(key, str) and key.startswith("_"): - return True - if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): - return True - return False - - excluded_keys = [ - key for key in tensordict_out.keys(True) if is_private(key) - ] - tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) - return tensordict_out - - def _update_traj_ids(self, env_output) -> None: - # we can't use the reset keys because they're gone - traj_sop = _aggregate_end_of_traj( - env_output.get("next"), done_keys=self.env.done_keys - ) - if traj_sop.any(): - device = self.storing_device - - traj_ids = self._shuttle.get(("collector", "traj_ids")) - if device is not None: - traj_ids = traj_ids.to(device) - traj_sop = traj_sop.to(device) - elif traj_sop.device != traj_ids.device: - traj_sop = traj_sop.to(traj_ids.device) - - pool = self._traj_pool - new_traj = pool.get_traj_and_increment( - traj_sop.sum(), device=traj_sop.device - ) - traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) - self._shuttle.set(("collector", "traj_ids"), traj_ids) - - @torch.no_grad() - def rollout(self) -> TensorDictBase: - """Computes a rollout in the environment using the provided policy. - - Returns: - TensorDictBase containing the computed rollout. 
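As a minimal sketch (assuming ``collector`` is an already-initialized :class:`SyncDataCollector`), the method can also be called directly; note that when a replay buffer is passed with ``extend_buffer=False``, frames are written straight to the buffer and ``None`` is returned instead:

    >>> td = collector.rollout()
    >>> assert td.names[-1] == "time"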
- - """ - if self.reset_at_each_iter: - self._shuttle.update(self.env.reset()) - - # self._shuttle.fill_(("collector", "step_count"), 0) - if self._use_buffers: - self._final_rollout.fill_(("collector", "traj_ids"), -1) - else: - pass - tensordicts = [] - with set_exploration_type(self.exploration_type): - for t in range(self.frames_per_batch): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.env.rand_action(self._shuttle) - if ( - self.policy_device is not None - and self.policy_device != self.env_device - ): - # TODO: This may break with exclusive / ragged lazy stacks - self._shuttle.apply( - lambda name, val: val.to( - device=self.policy_device, non_blocking=True - ) - if name in self._policy_output_keys - else val, - out=self._shuttle, - named=True, - nested_keys=True, - ) - else: - if self._cast_to_policy_device: - if self.policy_device is not None: - # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking - non_blocking = ( - not self.no_cuda_sync - or self.policy_device.type == "cuda" - ) - policy_input = self._shuttle.to( - self.policy_device, - non_blocking=non_blocking, - ) - if not self.no_cuda_sync: - self._sync_policy() - elif self.policy_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # policy_input = self._shuttle.clear_device_() - policy_input = self._shuttle - else: - policy_input = self._shuttle - # we still do the assignment for security - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - if self.compiled_policy: - policy_output = policy_output.clone() - if self._shuttle is not policy_output: - # ad-hoc update shuttle - self._shuttle.update( - policy_output, keys_to_update=self._policy_output_keys - ) - - if self._cast_to_env_device: - if self.env_device is not None: - non_blocking = ( - not self.no_cuda_sync or self.env_device.type == "cuda" - ) - env_input = self._shuttle.to( - self.env_device, non_blocking=non_blocking - ) - if not self.no_cuda_sync: - self._sync_env() - elif self.env_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # env_input = self._shuttle.clear_device_() - env_input = self._shuttle - else: - env_input = self._shuttle - env_output, env_next_output = self.env.step_and_maybe_reset(env_input) - - if self._shuttle is not env_output: - # ad-hoc update shuttle - next_data = env_output.get("next") - if self._shuttle_has_no_device: - # Make sure - next_data.clear_device_() - self._shuttle.set("next", next_data) - - if self.verbose: - torchrl_logger.info( - f"Collector: Rollout step completed {self._iter=}." - ) - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - if self.verbose: - torchrl_logger.info( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) - self.replay_buffer.add(self._shuttle) - if self._increment_frames(self._shuttle.numel()): - return - else: - if self.storing_device is not None: - if self.verbose: - torchrl_logger.info( - f"Collector: Moving to {self.storing_device} and adding to queue." 
- ) - non_blocking = ( - not self.no_cuda_sync or self.storing_device.type == "cuda" - ) - tensordicts.append( - self._shuttle.to( - self.storing_device, non_blocking=non_blocking - ) - ) - if not self.no_cuda_sync: - self._sync_storage() - else: - if self.verbose: - torchrl_logger.info( - "Collector: Adding to queue (no device)." - ) - tensordicts.append(self._shuttle) - - # carry over collector data without messing up devices - collector_data = self._shuttle.get("collector").copy() - self._shuttle = env_next_output - if self._shuttle_has_no_device: - self._shuttle.clear_device_() - self._shuttle.set("collector", collector_data) - self._update_traj_ids(env_output) - - if ( - self.interruptor is not None - and self.interruptor.collection_stopped() - ): - if self.verbose: - torchrl_logger.info("Collector: Interruptor stopped.") - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - result = self._final_rollout - if self._use_buffers: - try: - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - except RuntimeError: - with self._final_rollout.unlock_(): - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - else: - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - break - else: - if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") - result = self._final_rollout - try: - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - - except RuntimeError: - with self._final_rollout.unlock_(): - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - elif ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - else: - torchrl_logger.info( - "Returning final rollout with NO buffer (maybe_dense_stack)." 
- ) - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - result.refine_names(..., "time") - - return self._maybe_set_truncated(result) - - def _maybe_set_truncated(self, final_rollout): - last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) - for truncated_key in self._truncated_keys: - truncated = final_rollout["next", truncated_key] - truncated[last_step] = True - final_rollout["next", truncated_key] = truncated - done = final_rollout["next", _replace_last(truncated_key, "done")] - final_rollout["next", _replace_last(truncated_key, "done")] = ( - done | truncated - ) - return final_rollout - - @torch.no_grad() - def reset(self, index=None, **kwargs) -> None: - """Resets the environments to a new initial state.""" - # metadata - collector_metadata = self._shuttle.get("collector").clone() - if index is not None: - # check that the env supports partial reset - if prod(self.env.batch_size) == 0: - raise RuntimeError("resetting unique env with index is not permitted.") - for reset_key, done_keys in zip( - self.env.reset_keys, self.env.done_keys_groups - ): - _reset = torch.zeros( - self.env.full_done_spec[done_keys[0]].shape, - dtype=torch.bool, - device=self.env.device, - ) - _reset[index] = 1 - self._shuttle.set(reset_key, _reset) - else: - _reset = None - self._shuttle.zero_() - - self._shuttle.update(self.env.reset(**kwargs), inplace=True) - collector_metadata["traj_ids"] = ( - collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() - ) - self._shuttle["collector"] = collector_metadata - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all workers and/or closes the local environment. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - No effect for this class. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - try: - if not self.closed: - self.closed = True - del self._shuttle - if self._use_buffers: - del self._final_rollout - if close_env and not self.env.is_closed: - self.env.close(raise_if_closed=raise_on_error) - del self.env - return - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We excplicitely assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def state_dict(self) -> OrderedDict: - """Returns the local state_dict of the data collector (environment and policy). - - Returns: - an ordered dictionary with fields :obj:`"policy_state_dict"` and - `"env_state_dict"`. 
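Besides these two fields, the dictionary also carries the ``"frames"`` and ``"iter"`` counters, so a collector can be checkpointed and resumed. A minimal sketch (assuming ``collector`` is an initialized collector whose policy exposes a ``state_dict``):

    >>> sd = collector.state_dict()
    >>> sorted(sd)
    ['env_state_dict', 'frames', 'iter', 'policy_state_dict']
    >>> collector.load_state_dict(sd)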
- - """ - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, TransformedEnv): - env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, BatchedEnvBase): - env_state_dict = self.env.state_dict() - else: - env_state_dict = OrderedDict() - - if hasattr(self, "_policy_w_state_dict"): - policy_state_dict = self._policy_w_state_dict.state_dict() - state_dict = OrderedDict( - policy_state_dict=policy_state_dict, - env_state_dict=env_state_dict, - ) - else: - state_dict = OrderedDict(env_state_dict=env_state_dict) - - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: - """Loads a state_dict on the environment and policy. - - Args: - state_dict (OrderedDict): ordered dictionary containing the fields - `"policy_state_dict"` and :obj:`"env_state_dict"`. - - """ - strict = kwargs.get("strict", True) - if strict or "env_state_dict" in state_dict: - self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) - if strict or "policy_state_dict" in state_dict: - if not hasattr(self, "_policy_w_state_dict"): - raise ValueError( - "Underlying policy does not have state_dict to load policy_state_dict into." - ) - self._policy_w_state_dict.load_state_dict( - state_dict["policy_state_dict"], **kwargs - ) - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def __repr__(self) -> str: - try: - env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") - td_out_str = repr(getattr(self, "_final_rollout", None)) - if len(td_out_str) > 50: - td_out_str = td_out_str[:50] + "..." - td_out_str = indent(f"td_out={td_out_str}", 4 * " ") - string = ( - f"{self.__class__.__name__}(" - f"\n{env_str}," - f"\n{policy_str}," - f"\n{td_out_str}," - f"\nexploration={self.exploration_type})" - ) - return string - except Exception: - return f"{type(self).__name__}(not_init)" - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy.""" - # send command to policy to return the attr - return getattr(self._wrapped_policy, attr) - - def getattr_env(self, attr): - """Get an attribute from the environment.""" - # send command to env to return the attr - return getattr(self.env, attr) - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - # send command to rb to return the attr - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). 
- - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the unwrapped policy instance for weight synchronization - # The unwrapped policy has the same parameter structure as what's - # extracted in the main process, avoiding key mismatches when - # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) - if hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - -class _MultiDataCollector(DataCollectorBase): - """Runs a given number of DataCollectors on separate processes. - - Args: - create_env_fn (List[Callabled]): list of Callables, each returning an - instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided (default), the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: - ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable - (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - .. warning:: `policy_factory` is currently not compatible with multiprocessed data - collectors. - - num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. - Defaults to `None` (workers determined by the `create_env_fn` length). - frames_per_batch (int, Sequence[int]): A keyword-only argument representing the - total number of elements in a batch. If a sequence is provided, represents the number of elements in a - batch per worker. Total number of elements in a batch is then the sum over the sequence. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. 
- Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. If a list is - provided, each of its elements will be assigned to a sub-collector. - collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be - :class:`~torchrl.collectors.SyncDataCollector`, - :class:`~torchrl.collectors.MultiSyncDataCollector`, - :class:`~torchrl.collectors.MultiaSyncDataCollector` - or a derived class of these. - Defaults to :class:`~torchrl.collectors.SyncDataCollector`. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. 
- If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that returns a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. - update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()` - will be called before (sync) or after (async) each data collection. - Defaults to ``False``. - preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers - that will be allowed to finish collecting their rollout before the rest are forced to end early. - num_threads (int, optional): number of threads for this process. - Defaults to the number of workers. - num_sub_threads (int, optional): number of threads of the subprocesses. - Should be equal to one plus the number of processes launched within - each subprocess (or one if a single process is launched). - Defaults to 1 for safety: if none is indicated, launching multiple - workers may overload the CPU and harm performance. - cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively). - If ``"stack"``, the data collected from the workers will be stacked along the - first dimension. This is the preferred behavior as it is the most compatible - with the rest of the library. - If ``0``, results will be concatenated along the first dimension - of the outputs, which can be the batched dimension if the environments are - batched or the time dimension if not. - A ``cat_results`` value of ``-1`` will always concatenate results along the - time dimension. This should be preferred over the default. Intermediate values - are also accepted. - Defaults to ``"stack"``. - - .. note:: From v0.5, this argument will default to ``"stack"`` for a better - interoperability with the rest of the library.
-
-        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
-            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
-            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
-            Truncated keys can be set through ``env.add_truncated_keys``.
-            Defaults to ``False``.
-        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
-            This isn't compatible with environments with dynamic specs. Defaults to ``True``
-            for envs without dynamic specs, ``False`` for others.
-        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
-            but populate the buffer instead. Defaults to ``None``.
-        extend_buffer (bool, optional): if ``True``, the replay buffer is extended with entire rollouts and not
-            with single steps. Defaults to ``True`` for multiprocessed data collectors.
-        local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize
-            the replay buffer in the main process (legacy behavior). If ``True``, the storage-level
-            coordination will handle initialization with real data from worker processes.
-            Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning.
-            This parameter is deprecated and will be removed in v0.12.
-        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
-            compatible with the collector. This defaults to ``True`` for CudaGraphModules
-            and ``False`` otherwise.
-        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
-            using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
-            will be used to compile the policy.
-        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
-            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
-            If a dictionary of kwargs is passed, it will be used to wrap the policy.
-        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
-            For environments running directly on CUDA (`IsaacLab `_
-            or `ManiSkills `_) CUDA synchronization may cause unexpected
-            crashes.
-            Defaults to ``False``.
-        weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
-            or its subclass, responsible for updating the policy weights on remote inference workers.
-            If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
-            which handles weight synchronization across multiple processes.
-            Consider using a constructor if the updater needs to be serialized.
-        weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models.
-            If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default.
-        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
-            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
-            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
-            the policy version.
-            Defaults to ``False``.
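-
-    .. note::
-        The following sketch is purely illustrative: it assumes two identical
-        ``GymEnv("Pendulum-v1")`` workers on a CPU-only machine and a simple
-        ``TensorDictModule`` policy, and only shows how the per-worker list
-        arguments documented above are laid out, with one entry per worker in
-        the same order as ``create_env_fn`` (here a
-        :class:`~torchrl.collectors.MultiSyncDataCollector`, but the same layout
-        applies to the other multiprocessed collectors):
-
-            >>> from torchrl.envs.libs.gym import GymEnv
-            >>> from tensordict.nn import TensorDictModule
-            >>> from torch import nn
-            >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
-            >>> policy = TensorDictModule(
-            ...     nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
-            ... )
-            >>> collector = MultiSyncDataCollector(
-            ...     create_env_fn=[env_maker, env_maker],
-            ...     policy=policy,
-            ...     frames_per_batch=[100, 100],  # one value per worker
-            ...     device=["cpu", "cpu"],  # one device per worker
-            ...     total_frames=2000,
-            ... )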
- - """ - - def __init__( - self, - create_env_fn: Sequence[Callable[[], EnvBase]], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - num_workers: int | None = None, - policy_factory: Callable[[], Callable] - | list[Callable[[], Callable]] - | None = None, - frames_per_batch: int | Sequence[int], - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict] | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - cat_results: str | int | None = None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - replay_buffer_chunk: bool | None = None, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - ): - self.closed = True - - # Set up workers and environment functions - create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( - create_env_fn, num_workers, frames_per_batch - ) - - # Set up basic configuration - self.set_truncated = set_truncated - self.num_sub_threads = num_sub_threads - self.num_threads = num_threads - self.create_env_fn = create_env_fn - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Set up environment kwargs - self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) - - # Set up devices - storing_devices, policy_devices, env_devices = self._get_devices( - storing_device=storing_device, - env_device=env_device, - policy_device=policy_device, - device=device, - ) - self.storing_device = storing_devices - self.policy_device = policy_devices - self.env_device = env_devices - self.collector_class = collector_class - del storing_device, env_device, policy_device, device - self.no_cuda_sync = no_cuda_sync - - # Set up replay buffer - self._use_buffers = use_buffers - self.replay_buffer = replay_buffer - self._setup_multi_replay_buffer( - local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer - ) - - # Set up policy and weights - if trust_policy is None: - trust_policy = policy is not None and isinstance(policy, CudaGraphModule) - self.trust_policy = trust_policy - - policy_factory = self._setup_policy_factory(policy_factory) - - # Set up weight synchronization - if ( - not any(policy_factory) - and not weight_sync_schemes - and weight_updater is None - ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} - - self._setup_multi_policy_and_weights( - policy, 
policy_factory, weight_updater, weight_sync_schemes - ) - - self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) - - # Set up policy version tracking - self._setup_multi_policy_version_tracking(track_policy_version) - - # Store policy and policy_factory - self.policy = policy - self.policy_factory = policy_factory - - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) - - # Set up total frames and other parameters - self._setup_multi_total_frames( - total_frames, total_frames_per_batch, frames_per_batch - ) - self.reset_at_each_iter = reset_at_each_iter - self.postprocs = postproc - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - - # Set up split trajectories - self.requested_frames_per_batch = total_frames_per_batch - self.reset_when_done = reset_when_done - self._setup_split_trajs(split_trajs, reset_when_done) - - # Set up other parameters - self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 - ) - self.update_at_each_batch = update_at_each_batch - self.exploration_type = exploration_type - self.frames_per_worker = np.inf - - # Set up preemptive threshold - self._setup_preemptive_threshold(preemptive_threshold) - - # Run worker processes - try: - self._run_processes() - except Exception as e: - self.shutdown(raise_on_error=False) - raise e - - # Set up frame tracking and other options - self._exclude_private_keys = True - self._frames = 0 - self._iter = -1 - - # Validate cat_results - self._validate_cat_results(cat_results) - - def _setup_workers_and_env_fns( - self, - create_env_fn: Sequence[Callable] | Callable, - num_workers: int | None, - frames_per_batch: int | Sequence[int], - ) -> tuple[list[Callable], int]: - """Set up workers and environment functions.""" - if isinstance(create_env_fn, Sequence): - self.num_workers = len(create_env_fn) - else: - self.num_workers = num_workers - create_env_fn = [create_env_fn] * self.num_workers - - if ( - isinstance(frames_per_batch, Sequence) - and len(frames_per_batch) != self.num_workers - ): - raise ValueError( - "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." - f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
- ) - - self._frames_per_batch = frames_per_batch - total_frames_per_batch = ( - sum(frames_per_batch) - if isinstance(frames_per_batch, Sequence) - else frames_per_batch - ) - - return create_env_fn, total_frames_per_batch - - def _setup_env_kwargs( - self, create_env_kwargs: Sequence[dict] | dict | None - ) -> list[dict]: - """Set up environment kwargs for each worker.""" - if isinstance(create_env_kwargs, Mapping): - create_env_kwargs = [create_env_kwargs] * self.num_workers - elif create_env_kwargs is None: - create_env_kwargs = [{}] * self.num_workers - elif isinstance(create_env_kwargs, (tuple, list)): - create_env_kwargs = list(create_env_kwargs) - if len(create_env_kwargs) != self.num_workers: - raise ValueError( - f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" - ) - return create_env_kwargs - - def _setup_multi_replay_buffer( - self, - local_init_rb: bool | None, - replay_buffer: ReplayBuffer | None, - replay_buffer_chunk: bool | None, - extend_buffer: bool, - ) -> None: - """Set up replay buffer for multi-process collector.""" - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - self._check_replay_buffer_init() - - if replay_buffer_chunk is not None: - if extend_buffer is None: - replay_buffer_chunk = extend_buffer - warnings.warn( - "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", - DeprecationWarning, - ) - elif extend_buffer != replay_buffer_chunk: - raise ValueError( - "conflicting values for replay_buffer_chunk and extend_buffer." - ) - self.extend_buffer = extend_buffer - - if ( - replay_buffer is not None - and hasattr(replay_buffer, "shared") - and not replay_buffer.shared - ): - torchrl_logger.warning("Replay buffer is not shared. 
Sharing it.") - replay_buffer.share() - - def _setup_policy_factory( - self, policy_factory: Callable | list[Callable] | None - ) -> list[Callable | None]: - """Set up policy factory for each worker.""" - if not isinstance(policy_factory, Sequence): - policy_factory = [policy_factory] * self.num_workers - return policy_factory - - def _setup_multi_policy_and_weights( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up policy and extract weights for each device.""" - self._policy_weights_dict = {} - self._fallback_policy = None # Policy to use for weight extraction fallback - - if any(policy_factory) and policy is not None: - raise TypeError("policy_factory and policy are mutually exclusive") - elif not any(policy_factory): - for policy_device, env_maker, env_maker_kwargs in _zip_strict( - self.policy_device, self.create_env_fn, self.create_env_kwargs - ): - policy_new_device, get_weights_fn = self._get_policy_and_device( - policy=policy, - policy_device=policy_device, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) - if type(policy_new_device) is not type(policy): - policy = policy_new_device - weights = ( - TensorDict.from_module(policy_new_device) - if isinstance(policy_new_device, nn.Module) - else TensorDict() - ) - # For multi-process collectors, ensure weights are in shared memory - if policy_device and policy_device.type == "cpu": - weights = weights.share_memory_() - self._policy_weights_dict[policy_device] = weights - # Store the first policy instance for fallback weight extraction - if self._fallback_policy is None: - self._fallback_policy = policy_new_device - self._get_weights_fn = get_weights_fn - if weight_updater is None: - # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default - if weight_sync_schemes is None: - weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} - elif weight_updater is None: - warnings.warn( - "weight_updater is None, but policy_factory is provided. This means that the server will " - "not know how to send the weights to the workers. If the workers can handle their weight synchronization " - "on their own (via some specialized worker type / constructor) this may well work, but make sure " - "your weight synchronization strategy is properly set. To suppress this warning, you can use " - "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " - "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
- ) - - def _setup_multi_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization for multi-process collector.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Senders will be created in _run_processes when pipes are available - self.weight_updater = None # Don't use legacy system - else: - # Fall back to legacy weight updater system - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - - def _setup_multi_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking for multi-process collector.""" - self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - self.policy_version_tracker = PolicyVersion() - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." - ) - self.policy_version_tracker = None - - def _setup_fallback_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up fallback policy for weight extraction when using policy_factory.""" - # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided - # If policy_factory was used, create a policy instance to use as fallback - if policy is None and any(policy_factory) and weight_sync_schemes is not None: - if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: - first_factory = ( - policy_factory[0] - if isinstance(policy_factory, list) - else policy_factory - ) - if first_factory is not None: - # Create a policy instance for weight extraction - # This will be a reference to a policy with the same structure - # For shared memory, modifications to any policy will be visible here - self._fallback_policy = first_factory() - - def _setup_multi_total_frames( - self, - total_frames: int, - total_frames_per_batch: int, - frames_per_batch: int | Sequence[int], - ) -> None: - """Validate and set total frames for multi-process collector.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % total_frames_per_batch - if remainder != 0 and RL_WARNINGS: - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " - f"This means {total_frames_per_batch - remainder} additional frames will be collected. " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_split_trajs( - self, split_trajs: bool | None, reset_when_done: bool - ) -> None: - """Set up split trajectories option.""" - if split_trajs is None: - split_trajs = False - elif not reset_when_done and split_trajs: - raise RuntimeError( - "Cannot split trajectories when reset_when_done is False." 
-            )
-        self.split_trajs = split_trajs
-
-    def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None:
-        """Set up preemptive threshold for early stopping."""
-        if preemptive_threshold is not None:
-            if _is_osx:
-                raise NotImplementedError(
-                    "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform."
-                )
-            self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0)
-            manager = _InterruptorManager()
-            manager.start()
-            self.interruptor = manager._Interruptor()
-        else:
-            self.preemptive_threshold = 1.0
-            self.interruptor = None
-
-    def _validate_cat_results(self, cat_results: str | int | None) -> None:
-        """Validate cat_results parameter."""
-        if cat_results is not None and (
-            not isinstance(cat_results, (int, str))
-            or (isinstance(cat_results, str) and cat_results != "stack")
-        ):
-            raise ValueError(
-                "cat_results must be a string ('stack') "
-                f"or an integer representing the cat dimension. Got {cat_results}."
-            )
-        if not isinstance(self, MultiSyncDataCollector) and cat_results not in (
-            "stack",
-            None,
-        ):
-            raise ValueError(
-                "cat_results can only be used with ``MultiSyncDataCollector``."
-            )
-        self.cat_results = cat_results
-
-    def _check_replay_buffer_init(self):
-        if self.replay_buffer is None:
-            return
-        is_init = hasattr(self.replay_buffer, "_storage") and getattr(
-            self.replay_buffer._storage, "initialized", True
-        )
-        if not is_init:
-            if self.local_init_rb:
-                # New behavior: storage handles all coordination itself
-                # Nothing to do here - the storage will coordinate during first write
-                self.replay_buffer.share()
-                return
-
-            # Legacy behavior: fake tensordict initialization
-            if isinstance(self.create_env_fn[0], EnvCreator):
-                fake_td = self.create_env_fn[0].meta_data.tensordict
-            elif isinstance(self.create_env_fn[0], EnvBase):
-                fake_td = self.create_env_fn[0].fake_tensordict()
-            else:
-                fake_td = self.create_env_fn[0](
-                    **self.create_env_kwargs[0]
-                ).fake_tensordict()
-            fake_td["collector", "traj_ids"] = torch.zeros(
-                fake_td.shape, dtype=torch.long
-            )
-            # Use extend to avoid time-related transforms failing
-            self.replay_buffer.extend(fake_td.unsqueeze(-1))
-            self.replay_buffer.empty()
-
-    @classmethod
-    def _total_workers_from_env(cls, env_creators):
-        if isinstance(env_creators, (tuple, list)):
-            return sum(
-                cls._total_workers_from_env(env_creator) for env_creator in env_creators
-            )
-        from torchrl.envs import ParallelEnv
-
-        if isinstance(env_creators, ParallelEnv):
-            return env_creators.num_workers
-        return 1
-
-    def _get_devices(
-        self,
-        *,
-        storing_device: torch.device,
-        policy_device: torch.device,
-        env_device: torch.device,
-        device: torch.device,
-    ):
-        # convert all devices to lists
-        if not isinstance(storing_device, (list, tuple)):
-            storing_device = [
-                storing_device,
-            ] * self.num_workers
-        if not isinstance(policy_device, (list, tuple)):
-            policy_device = [
-                policy_device,
-            ] * self.num_workers
-        if not isinstance(env_device, (list, tuple)):
-            env_device = [
-                env_device,
-            ] * self.num_workers
-        if not isinstance(device, (list, tuple)):
-            device = [
-                device,
-            ] * self.num_workers
-        if not (
-            len(device)
-            == len(storing_device)
-            == len(policy_device)
-            == len(env_device)
-            == self.num_workers
-        ):
-            raise RuntimeError(
-                f"The length of the devices does not match the number of workers: {self.num_workers}."
- ) - storing_device, policy_device, env_device = zip( - *[ - SyncDataCollector._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - for (storing_device, policy_device, env_device, device) in zip( - storing_device, policy_device, env_device, device - ) - ] - ) - return storing_device, policy_device, env_device - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - raise NotImplementedError - - @property - def _queue_len(self) -> int: - raise NotImplementedError - - def _run_processes(self) -> None: - if self.num_threads is None: - total_workers = self._total_workers_from_env(self.create_env_fn) - self.num_threads = max( - 1, torch.get_num_threads() - total_workers - ) # 1 more thread for this proc - - # Weight senders will be initialized after workers are ready (via init_on_sender) - torch.set_num_threads(self.num_threads) - queue_out = mp.Queue(self._queue_len) # sends data from proc to main - self.procs = [] - self.pipes = [] - self._traj_pool = _TrajectoryPool(lock=True) - # Create a policy on the right device - policy_factory = self.policy_factory - if any(policy_factory): - policy_factory = [ - CloudpickleWrapper(_policy_factory) - for _policy_factory in policy_factory - ] - - for i, (env_fun, env_fun_kwargs) in enumerate( - zip(self.create_env_fn, self.create_env_kwargs) - ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs - if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( - env_fun, EnvBase - ): # to avoid circular imports - env_fun = CloudpickleWrapper(env_fun) - - policy_device = self.policy_device[i] - storing_device = self.storing_device[i] - env_device = self.env_device[i] - # We take the weights, the policy, and locally dispatch the weights to the policy - # while we send the policy to the remote process. - # This makes sure that a given set of shared weights for a given device are - # shared for all policies that rely on that device. 
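-            # (The per-device weights in self._policy_weights_dict were extracted in
-            # _setup_multi_policy_and_weights; CPU weights were placed in shared memory
-            # there, so every worker relying on that device can read the same storage.)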
-            policy = self.policy
-            policy_weights = self._policy_weights_dict.get(policy_device)
-            if policy is not None and policy_weights is not None:
-                cm = policy_weights.to_module(policy)
-            else:
-                cm = contextlib.nullcontext()
-            with cm:
-                kwargs = {
-                    "policy_factory": policy_factory[i],
-                    "pipe_parent": pipe_parent,
-                    "pipe_child": pipe_child,
-                    "queue_out": queue_out,
-                    "create_env_fn": env_fun,
-                    "create_env_kwargs": env_fun_kwargs,
-                    "policy": policy,
-                    "max_frames_per_traj": self.max_frames_per_traj,
-                    "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
-                    "reset_at_each_iter": self.reset_at_each_iter,
-                    "policy_device": policy_device,
-                    "storing_device": storing_device,
-                    "env_device": env_device,
-                    "exploration_type": self.exploration_type,
-                    "reset_when_done": self.reset_when_done,
-                    "idx": i,
-                    "interruptor": self.interruptor,
-                    "set_truncated": self.set_truncated,
-                    "use_buffers": self._use_buffers,
-                    "replay_buffer": self.replay_buffer,
-                    "extend_buffer": self.extend_buffer,
-                    "traj_pool": self._traj_pool,
-                    "trust_policy": self.trust_policy,
-                    "compile_policy": self.compiled_policy_kwargs
-                    if self.compiled_policy
-                    else False,
-                    "cudagraph_policy": self.cudagraphed_policy_kwargs
-                    if self.cudagraphed_policy
-                    else False,
-                    "no_cuda_sync": self.no_cuda_sync,
-                    "collector_class": self.collector_class,
-                    "postproc": self.postprocs
-                    if self.replay_buffer is not None
-                    else None,
-                    "weight_sync_schemes": self._weight_sync_schemes,
-                }
-                proc = _ProcessNoWarn(
-                    target=_main_async_collector,
-                    num_threads=self.num_sub_threads,
-                    kwargs=kwargs,
-                )
-                # proc.daemon can't be set as daemonic processes may be launched by the process itself
-                try:
-                    proc.start()
-                except TypeError as err:
-                    if "cannot pickle" in str(err):
-                        raise RuntimeError(
-                            "A non-serializable object was passed to the collector workers."
-                        ) from err
-                except RuntimeError as err:
-                    if "Cowardly refusing to serialize non-leaf tensor" in str(err):
-                        raise RuntimeError(
-                            "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. "
-                            "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n"
-                            "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n"
-                            "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead."
-                        ) from err
-                    else:
-                        raise err
-                except _pickle.PicklingError as err:
-                    if "<lambda>" in str(err):
-                        raise RuntimeError(
-                            """Can't open a process with doubly cloud-pickled lambda function.
-This error is likely due to an attempt to use a ParallelEnv in a
-multiprocessed data collector. To do this, consider wrapping your
-lambda function in a `torchrl.envs.EnvCreator` wrapper as follows:
-`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
-This will not only ensure that your lambda function is cloud-pickled once, but -also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) - - # Worker registration now handled by init_on_sender() after workers are ready - for i, pipe_parent in enumerate(self.pipes): - pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - try: - msg = pipe_parent.recv() - except EOFError as e: - raise RuntimeError( - f"Worker {i} failed to initialize and closed the connection before sending status. " - f"This typically indicates that the worker process crashed during initialization. " - f"Check the worker process logs for the actual error." - ) from e - if msg != "instantiated": - # Check if it's an error dict from worker - if isinstance(msg, dict) and msg.get("error"): - # Reconstruct the exception from the worker - exc_type_name = msg["exception_type"] - exc_msg = msg["exception_msg"] - traceback_str = msg["traceback"] - - # Try to get the actual exception class - exc_class = None - exc_module = msg["exception_module"] - - if exc_module == "builtins": - # Get from builtins - import builtins - - exc_class = getattr(builtins, exc_type_name, None) - else: - # Try to import from the module - try: - import importlib - - mod = importlib.import_module(exc_module) - exc_class = getattr(mod, exc_type_name, None) - except Exception: - pass - - # Re-raise with original exception type if possible - if exc_class is not None: - raise exc_class( - f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Fall back to RuntimeError if we can't get the original type - raise RuntimeError( - f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Legacy string error message - raise RuntimeError(msg) - - # Initialize all weight sync schemes now that workers are ready - # This calls init_on_sender() for each scheme which: - # 1. Creates transports for all workers - # 2. Creates and configures the sender - # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_sender"): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - # else: keep using legacy _weight_senders initialization from before - - self.queue_out = queue_out - self.closed = False - - _running_free = False - - def start(self): - """Starts the collector(s) for asynchronous data collection. - - The collected data is stored in the provided replay buffer. This method initiates the background collection of - data across multiple processes, allowing for decoupling of data collection and training. - - Raises: - RuntimeError: If no replay buffer is defined during the collector's initialization. - - Example: - >>> import time - >>> from functools import partial - >>> - >>> import tqdm - >>> - >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy - >>> from torchrl.data import LazyTensorStorage, ReplayBuffer - >>> from torchrl.envs import GymEnv, set_gym_backend - >>> import ale_py - >>> - >>> # Set the gym backend to gymnasium - >>> set_gym_backend("gymnasium").set() - >>> - >>> if __name__ == "__main__": - ... # Create a random policy for the Pong environment - ... env_fn = partial(GymEnv, "ALE/Pong-v5") - ... 
policy = RandomPolicy(env_fn().action_spec)
-            ...
-            ...     # Initialize a shared replay buffer
-            ...     rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True)
-            ...
-            ...     # Create a multi-async data collector with 16 environments
-            ...     num_envs = 16
-            ...     collector = MultiaSyncDataCollector(
-            ...         [env_fn] * num_envs,
-            ...         policy=policy,
-            ...         replay_buffer=rb,
-            ...         frames_per_batch=num_envs * 16,
-            ...         total_frames=-1,
-            ...     )
-            ...
-            ...     # Progress bar to track the number of collected frames
-            ...     pbar = tqdm.tqdm(total=100_000)
-            ...
-            ...     # Start the collector asynchronously
-            ...     collector.start()
-            ...
-            ...     # Track the write count of the replay buffer
-            ...     prec_wc = 0
-            ...     while True:
-            ...         wc = rb.write_count
-            ...         c = wc - prec_wc
-            ...         prec_wc = wc
-            ...
-            ...         # Update the progress bar
-            ...         pbar.update(c)
-            ...         pbar.set_description(f"Write Count: {rb.write_count}")
-            ...
-            ...         # Check the write count every 0.5 seconds
-            ...         time.sleep(0.5)
-            ...
-            ...         # Stop when the desired number of frames is reached
-            ...         if rb.write_count > 100_000:
-            ...             break
-            ...
-            ...     # Shut down the collector
-            ...     collector.async_shutdown()
-        """
-        if self.replay_buffer is None:
-            raise RuntimeError("Replay buffer must be defined for execution.")
-        if self.init_random_frames is not None and self.init_random_frames > 0:
-            raise RuntimeError(
-                "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
-            )
-        self._running_free = True
-        for pipe in self.pipes:
-            pipe.send((None, "run_free"))
-
-    @contextlib.contextmanager
-    def pause(self):
-        """Context manager that pauses the collector if it is running free."""
-        if self._running_free:
-            for pipe in self.pipes:
-                pipe.send((None, "pause"))
-            # Make sure all workers are paused
-            for _ in self.pipes:
-                idx, msg = self.queue_out.get()
-                if msg != "paused":
-                    raise ValueError(f"Expected paused, but got {msg=}.")
-                torchrl_logger.info(f"Worker {idx} is paused.")
-            self._running_free = False
-            yield None
-            for pipe in self.pipes:
-                pipe.send((None, "restart"))
-            self._running_free = True
-        else:
-            raise RuntimeError("Collector cannot be paused.")
-
-    def __del__(self):
-        try:
-            self.shutdown()
-        except Exception:
-            # an AttributeError will typically be raised if the collector is deleted when the program ends.
-            # In the future, insignificant changes to the close method may change the error type.
-            # We explicitly assume that any error raised during closure in
-            # __del__ will not affect the program.
-            pass
-
-    def shutdown(
-        self,
-        timeout: float | None = None,
-        close_env: bool = True,
-        raise_on_error: bool = True,
-    ) -> None:
-        """Shuts down all processes. This operation is irreversible.
-
-        Args:
-            timeout (float, optional): The timeout for closing pipes between workers.
-            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
-            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
-        """
-        if not close_env:
-            raise RuntimeError(
-                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
- ) - try: - self._shutdown_main(timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def _shutdown_main(self, timeout: float | None = None) -> None: - if timeout is None: - timeout = 10 - try: - if self.closed: - return - _check_for_faulty_process(self.procs) - all_closed = [False] * self.num_workers - rep = 0 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - continue - self.pipes[idx].send((None, "close")) - - while not all(all_closed) and rep < 1000: - rep += 1 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - all_closed[idx] = True - continue - try: - if self.pipes[idx].poll(timeout / 1000 / self.num_workers): - msg = self.pipes[idx].recv() - if msg != "closed": - raise RuntimeError(f"got {msg} but expected 'close'") - all_closed[idx] = True - else: - continue - except BrokenPipeError: - all_closed[idx] = True - continue - self.closed = True - - self.queue_out.close() - for pipe in self.pipes: - pipe.close() - for proc in self.procs: - proc.join(1.0) - finally: - import torchrl - - num_threads = min( - torchrl._THREAD_POOL_INIT, - torch.get_num_threads() - + self._total_workers_from_env(self.create_env_fn), - ) - torch.set_num_threads(num_threads) - - for proc in self.procs: - if proc.is_alive(): - proc.terminate() - - def async_shutdown(self, timeout: float | None = None): - return self.shutdown(timeout=timeout) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed: integer representing the seed to be used for the environment. - static_seed (bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is - contained in the DataCollector, as the seed will be incremented for - each of these. The resulting seed is the seed of the last - environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - _check_for_faulty_process(self.procs) - for idx in range(self.num_workers): - self.pipes[idx].send(((seed, static_seed), "seed")) - new_seed, msg = self.pipes[idx].recv() - if msg != "seeded": - raise RuntimeError(f"Expected msg='seeded', got {msg}") - seed = new_seed - self.reset() - return seed - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - """Resets the environments to a new initial state. - - Args: - reset_idx: Optional. Sequence indicating which environments have - to be reset. If None, all environments are reset. 
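-
-        Examples:
-            >>> # illustrative only: reset only the first of two workers
-            >>> collector.reset(reset_idx=[True, False])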
- - """ - _check_for_faulty_process(self.procs) - - if reset_idx is None: - reset_idx = [True for _ in range(self.num_workers)] - for idx in range(self.num_workers): - if reset_idx[idx]: - self.pipes[idx].send((None, "reset")) - for idx in range(self.num_workers): - if reset_idx[idx]: - j, msg = self.pipes[idx].recv() - if msg != "reset": - raise RuntimeError(f"Expected msg='reset', got {msg}") - - def state_dict(self) -> OrderedDict: - """Returns the state_dict of the data collector. - - Each field represents a worker containing its own state_dict. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((None, "state_dict")) - state_dict = OrderedDict() - for idx in range(self.num_workers): - _state_dict, msg = self.pipes[idx].recv() - if msg != "state_dict": - raise RuntimeError(f"Expected msg='state_dict', got {msg}") - state_dict[f"worker{idx}"] = _state_dict - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict) -> None: - """Loads the state_dict on the workers. - - Args: - state_dict (OrderedDict): state_dict of the form - ``{"worker0": state_dict0, "worker1": state_dict1}``. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) - for idx in range(self.num_workers): - _, msg = self.pipes[idx].recv() - if msg != "loaded": - raise RuntimeError(f"Expected msg='loaded', got {msg}") - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy of the first worker. - - Args: - attr (str): The attribute name to retrieve from the policy. - - Returns: - The attribute value from the policy of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the policy. - """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_policy")) - result, msg = self.pipes[0].recv() - if msg != "getattr_policy": - raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_env(self, attr): - """Get an attribute from the environment of the first worker. - - Args: - attr (str): The attribute name to retrieve from the environment. - - Returns: - The attribute value from the environment of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the environment. 
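-
-        Examples:
-            >>> # illustrative only: query the first worker's env for its batch size
-            >>> batch_size = collector.getattr_env("batch_size")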
- """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_env")) - result, msg = self.pipes[0].recv() - if msg != "getattr_env": - raise RuntimeError(f"Expected msg='getattr_env', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). - - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy - elif hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - def get_cached_weights(self, model_id: str): - """Get cached shared memory weights if available (for weight sync schemes). - - Args: - model_id: Model identifier - - Returns: - Cached TensorDict weights or None if not available - """ - if model_id == "policy" and hasattr(self, "_policy_weights_dict"): - # Get the policy device (first device if list) - policy_device = self.policy_device - if isinstance(policy_device, (list, tuple)): - policy_device = policy_device[0] if len(policy_device) > 0 else None - - # Return cached weights for this device - return self._policy_weights_dict.get(policy_device) - return None - - -@accept_remote_rref_udf_invocation -class MultiSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes synchronously. - - .. aafig:: - - +----------------------------------------------------------------------+ - | "MultiSyncDataCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | Main | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor" | "step" | "step" | "actor" | | - | | | | | | - | | "actor" | | | - | | | | | - | "yield batch of traj 1"------->"collect, train"| - | | | - | "step" | "step" | "step" | "step" | "step" | "step" | | - | | | | | | | | - | "actor" | "actor" | | | | - | | "step" | "step" | "actor" | | - | | | | | | - | "step" | "step" | "actor" | "step" | "step" | | - | | | | | | | - | "actor" | | "actor" | | - | "yield batch of traj 2"------->"collect, train"| - | | | - +----------------------------------------------------------------------+ - - Envs can be identical or different. - - The collection starts when the next item of the collector is queried, - and no environment step is computed in between the reception of a batch of - trajectory and the start of the next collection. 
- This class can be safely used with online RL sota-implementations. - - .. note:: - Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - ... collector = MultiSyncDataCollector(...) - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
-            )
-        if hasattr(self, "out_buffer"):
-            del self.out_buffer
-        if hasattr(self, "buffers"):
-            del self.buffers
-        try:
-            return super().shutdown(timeout=timeout)
-        except Exception as e:
-            if raise_on_error:
-                raise e
-            else:
-                pass
-
-    # for RPC
-    def set_seed(self, seed: int, static_seed: bool = False) -> int:
-        return super().set_seed(seed, static_seed)
-
-    # for RPC
-    def state_dict(self) -> OrderedDict:
-        return super().state_dict()
-
-    # for RPC
-    def load_state_dict(self, state_dict: OrderedDict) -> None:
-        return super().load_state_dict(state_dict)
-
-    # for RPC
-    def update_policy_weights_(
-        self,
-        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
-        *,
-        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
-        **kwargs,
-    ) -> None:
-        if "policy_weights" in kwargs:
-            warnings.warn(
-                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
-                DeprecationWarning,
-            )
-            policy_or_weights = kwargs.pop("policy_weights")
-
-        super().update_policy_weights_(
-            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
-        )
-
-    def frames_per_batch_worker(self, worker_idx: int | None) -> int:
-        if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
-            return self._frames_per_batch[worker_idx]
-        if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
-            warnings.warn(
-                f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
-                f" this results in more frames per iteration than requested. "
-                "To silence this message, set the environment variable RL_WARNINGS to False."
-            )
-        frames_per_batch_worker = -(
-            -self.requested_frames_per_batch // self.num_workers
-        )
-        return frames_per_batch_worker
-
-    @property
-    def _queue_len(self) -> int:
-        return self.num_workers
-
-    def iterator(self) -> Iterator[TensorDictBase]:
-        cat_results = self.cat_results
-        if cat_results is None:
-            cat_results = "stack"
-
-        self.buffers = {}
-        dones = [False for _ in range(self.num_workers)]
-        workers_frames = [0 for _ in range(self.num_workers)]
-        same_device = None
-        self.out_buffer = None
-        preempt = self.interruptor is not None and self.preemptive_threshold < 1.0
-
-        while not all(dones) and self._frames < self.total_frames:
-            _check_for_faulty_process(self.procs)
-            if self.update_at_each_batch:
-                self.update_policy_weights_()
-
-            for idx in range(self.num_workers):
-                if (
-                    self.init_random_frames is not None
-                    and self._frames < self.init_random_frames
-                ):
-                    msg = "continue_random"
-                else:
-                    msg = "continue"
-                # Debug: sending 'continue'
-                self.pipes[idx].send((None, msg))
-
-            self._iter += 1
-
-            if preempt:
-                self.interruptor.start_collection()
-                while self.queue_out.qsize() < int(
-                    self.num_workers * self.preemptive_threshold
-                ):
-                    continue
-                self.interruptor.stop_collection()
-                # Now wait for stragglers to return
-                while self.queue_out.qsize() < int(self.num_workers):
-                    continue
-
-            recv = collections.deque()
-            t0 = time.time()
-            while len(recv) < self.num_workers and (
-                (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT)
-            ):
-                for _ in range(self.num_workers):
-                    try:
-                        new_data, j = self.queue_out.get(timeout=_TIMEOUT)
-                        recv.append((new_data, j))
-                    except (TimeoutError, Empty):
-                        _check_for_faulty_process(self.procs)
-                        if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT):
-                            try:
-                                self.shutdown()
-                            finally:
-                                raise RuntimeError(
-                                    f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT}
seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." - ) - - for _ in range(self.num_workers): - new_data, j = recv.popleft() - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - workers_frames[idx] = workers_frames[ - idx - ] + self.frames_per_batch_worker(worker_idx=idx) - continue - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.buffers[idx] = data - if use_buffers is None and j > 0: - self._use_buffers = False - except TypeError: - if use_buffers is None: - self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - - if preempt: - # mask buffers if cat, and create a mask if stack - if cat_results != "stack": - buffers = {} - for worker_idx, buffer in self.buffers.items(): - valid = buffer.get(("collector", "traj_ids")) != -1 - if valid.ndim > 2: - valid = valid.flatten(0, -2) - if valid.ndim == 2: - valid = valid.any(0) - buffers[worker_idx] = buffer[..., valid] - else: - for buffer in self.buffers.values(): - with buffer.unlock_(): - buffer.set( - ("collector", "mask"), - buffer.get(("collector", "traj_ids")) != -1, - ) - buffers = self.buffers - else: - buffers = self.buffers - - # Skip frame counting if this worker didn't send data this iteration - # (happens when reusing buffers or on first iteration with some workers) - if idx not in buffers: - continue - - workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() - - if workers_frames[idx] >= self.total_frames: - dones[idx] = True - - if self.replay_buffer is not None: - yield - self._frames += sum( - [ - self.frames_per_batch_worker(worker_idx) - for worker_idx in range(self.num_workers) - ] - ) - continue - - # we have to correct the traj_ids to make sure that they don't overlap - # We can count the number of frames collected for free in this loop - n_collected = 0 - for idx in buffers.keys(): - buffer = buffers[idx] - traj_ids = buffer.get(("collector", "traj_ids")) - if preempt: - if cat_results == "stack": - mask_frames = buffer.get(("collector", "traj_ids")) != -1 - n_collected += mask_frames.sum().cpu() - else: - n_collected += traj_ids.numel() - else: - n_collected += traj_ids.numel() - - if same_device is None: - prev_device = None - same_device = True - for item in self.buffers.values(): - if prev_device is None: - prev_device = item.device - else: - same_device = same_device and (item.device == prev_device) - - if cat_results == "stack": - stack = ( - torch.stack if self._use_buffers else TensorDict.maybe_dense_stack - ) - if same_device: - self.out_buffer = stack(list(buffers.values()), 0) - else: - self.out_buffer = stack( - [item.cpu() for item in buffers.values()], 0 - ) - else: - if self._use_buffers is None: - torchrl_logger.warning( - "use_buffer not specified and not yet inferred from data, assuming `True`." - ) - elif not self._use_buffers: - raise RuntimeError( - "Cannot concatenate results with use_buffers=False" - ) - try: - if same_device: - self.out_buffer = torch.cat(list(buffers.values()), cat_results) - else: - self.out_buffer = torch.cat( - [item.cpu() for item in buffers.values()], cat_results - ) - except RuntimeError as err: - if ( - preempt - and cat_results != -1 - and "Sizes of tensors must match" in str(err) - ): - raise RuntimeError( - "The value provided to cat_results isn't compatible with the collectors outputs. " - "Consider using `cat_results=-1`." - ) - raise - - # TODO: why do we need to do cat inplace and clone? 
- if self.split_trajs: - out = split_trajectories(self.out_buffer, prefix="collector") - else: - out = self.out_buffer - if cat_results in (-1, "stack"): - out.refine_names(*[None] * (out.ndim - 1) + ["time"]) - - self._frames += n_collected - - if self.postprocs: - self.postprocs = ( - self.postprocs.to(out.device) - if hasattr(self.postprocs, "to") - else self.postprocs - ) - out = self.postprocs(out) - if self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - if excluded_keys: - out = out.exclude(*excluded_keys) - yield out - del out - - del self.buffers - self.out_buffer = None - # We shall not call shutdown just yet as user may want to retrieve state_dict - # self._shutdown_main() - - -@accept_remote_rref_udf_invocation -class MultiaSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes asynchronously. - - .. aafig:: - - - +----------------------------------------------------------------------+ - | "MultiConcurrentCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor | "step" | "step" | "actor" | | - | | | | | | - | "yield batch 1" | "actor" | |"collect, train"| - | | | | | - | "step" | "step" | | "yield batch 2" |"collect, train"| - | | | | | | - | | | "yield batch 3" | |"collect, train"| - | | | | | | - +----------------------------------------------------------------------+ - - Environment types can be identical or different. - - The collection keeps on occurring on all processes even between the time - the batch of rollouts is collected and the next call to the iterator. - This class can be safely used with offline RL sota-implementations. - - .. note:: Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... 
del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.out_tensordicts = defaultdict(lambda: None) - self.running = False - - if self.postprocs is not None and self.replay_buffer is None: - postproc = self.postprocs - self.postprocs = {} - for _device in self.storing_device: - if _device not in self.postprocs: - if hasattr(postproc, "to"): - postproc = deepcopy(postproc).to(_device) - self.postprocs[_device] = postproc - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." - ) - return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - return self.requested_frames_per_batch - - def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: - new_data, j = self.queue_out.get(timeout=timeout) - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.out_tensordicts[idx] = data - if use_buffers is None and j > 0: - use_buffers = self._use_buffers = False - except TypeError: - if use_buffers is None: - use_buffers = self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - out = self.out_tensordicts[idx] - if not self.replay_buffer and (j == 0 or use_buffers): - # we clone the data to make sure that we'll be working with a fixed copy - out = out.clone() - return idx, j, out - - @property - def _queue_len(self) -> int: - return 1 - - def iterator(self) -> Iterator[TensorDictBase]: - if self.update_at_each_batch: - self.update_policy_weights_() - - for i in range(self.num_workers): - if self.init_random_frames is not None and self.init_random_frames > 0: - self.pipes[i].send((None, "continue_random")) - else: - self.pipes[i].send((None, "continue")) - self.running = True - - workers_frames = [0 for _ in range(self.num_workers)] - while self._frames < self.total_frames: - self._iter += 1 - counter = 0 - while True: - try: - idx, j, out = self._get_from_queue(timeout=_TIMEOUT) - break - except (TimeoutError, Empty): - counter += _TIMEOUT - _check_for_faulty_process(self.procs) - if counter > (_TIMEOUT * _MAX_IDLE_COUNT): - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
- ) - if self.replay_buffer is None: - worker_frames = out.numel() - if self.split_trajs: - out = split_trajectories(out, prefix="collector") - else: - worker_frames = self.frames_per_batch_worker() - self._frames += worker_frames - workers_frames[idx] = workers_frames[idx] + worker_frames - if out is not None and self.postprocs: - out = self.postprocs[out.device](out) - - # the function blocks here until the next item is asked, hence we send the message to the - # worker to keep on working in the meantime before the yield statement - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - self.pipes[idx].send((idx, msg)) - if out is not None and self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - out = out.exclude(*excluded_keys) - yield out - - # We don't want to shutdown yet, the user may want to call state_dict before - # self._shutdown_main() - self.running = False - - def _shutdown_main(self, *args, **kwargs) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - return super()._shutdown_main(*args, **kwargs) - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - super().reset(reset_idx) - if self.queue_out.full(): - time.sleep(_TIMEOUT) # wait until queue is empty - if self.queue_out.full(): - raise Exception("self.queue_out is full") - if self.running: - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.pipes[idx].send((idx, "continue_random")) - else: - self.pipes[idx].send((idx, "continue")) - - -@accept_remote_rref_udf_invocation -class aSyncDataCollector(MultiaSyncDataCollector): - """Runs a single DataCollector on a separate process. - - This is mostly useful for offline RL paradigms where the policy being - trained can differ from the policy used to collect data. In online - settings, a regular DataCollector should be preferred. This class is - merely a wrapper around a MultiaSyncDataCollector where a single process - is being created. - - Args: - create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. 
note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the - total number of elements in a batch. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. 
If a list is - provided, each of its elements will be assigned to a sub-collector. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. - update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()` - will be called before (sync) or after (async) each data collection. - Defaults to ``False``. - preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers - that will be allowed to finished collecting their rollout before the rest are forced to end early. - num_threads (int, optional): number of threads for this process. - Defaults to the number of workers. - num_sub_threads (int, optional): number of threads of the subprocesses. - Should be equal to one plus the number of processes launched within - each subprocess (or one if a single process is launched). - Defaults to 1 for safety: if none is indicated, launching multiple - workers may charge the cpu load too much and harm performance. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
- Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - """ - - def __init__( - self, - create_env_fn: Callable[[], EnvBase], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict[str, Any]] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - set_truncated: bool = False, - track_policy_version: bool = False, - **kwargs, - ): - super().__init__( - create_env_fn=[create_env_fn], - policy=policy, - policy_factory=policy_factory, - total_frames=total_frames, - create_env_kwargs=[create_env_kwargs] - if create_env_kwargs - else create_env_kwargs, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - init_random_frames=init_random_frames, - postproc=postproc, - split_trajs=split_trajs, - device=device, - policy_device=policy_device, - env_device=env_device, - storing_device=storing_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - update_at_each_batch=update_at_each_batch, - preemptive_threshold=preemptive_threshold, - num_threads=num_threads, - num_sub_threads=num_sub_threads, - set_truncated=set_truncated, - track_policy_version=track_policy_version, - **kwargs, - ) - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - return super().shutdown( - timeout=timeout, close_env=close_env, raise_on_error=raise_on_error - ) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - -def _main_async_collector( - pipe_parent: connection.Connection, - pipe_child: connection.Connection, - queue_out: queues.Queue, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 - create_env_kwargs: dict[str, Any], - policy: Callable[[TensorDictBase], TensorDictBase], - max_frames_per_traj: int, - frames_per_batch: int, - reset_at_each_iter: bool, - storing_device: torch.device | str | int | None, - env_device: torch.device | str | int | None, - policy_device: torch.device | str | int | None, - idx: int = 0, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - verbose: bool = VERBOSE, - interruptor=None, - set_truncated: bool = 
False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - traj_pool: _TrajectoryPool = None, - trust_policy: bool = False, - compile_policy: bool = False, - cudagraph_policy: bool = False, - no_cuda_sync: bool = False, - policy_factory: Callable | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, -) -> None: - if collector_class is None: - collector_class = SyncDataCollector - pipe_parent.close() - # init variables that will be cleared when closing - collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - - try: - collector_class._ignore_rb = extend_buffer - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, - ) - - # Set up weight receivers for worker process - if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - scheme.init_on_worker(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} - - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") - except Exception as e: - # Send error information to main process - # We send a dict with the exception info so we can recreate it in the main process - import traceback - - error_info = { - "error": True, - "exception_type": type(e).__name__, - "exception_module": type(e).__module__, - "exception_msg": str(e), - "traceback": traceback.format_exc(), - } - try: - pipe_child.send(error_info) - except Exception: - # If pipe is broken, nothing we can do - pass - return - - has_timed_out = False - counter = 0 - run_free = False - while True: - _timeout = _TIMEOUT if not has_timed_out else 1e-3 - if not run_free and pipe_child.poll(_timeout): - counter = 0 - data_in, msg = pipe_child.recv() - if verbose: - torchrl_logger.info(f"worker {idx} received {msg}") - elif not run_free: - if verbose: - torchrl_logger.info(f"poll failed, j={j}, worker={idx}") - # default is "continue" (after first 
iteration) - # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe - # in that case, the main process probably expects the worker to continue collect data - if has_timed_out: - counter = 0 - # has_timed_out is True if the process failed to send data, which will - # typically occur if main has taken another batch (i.e. the queue is Full). - # In this case, msg is the previous msg sent by main, which will typically be "continue" - # If it's not the case, it is not expected that has_timed_out is True. - if msg not in ("continue", "continue_random"): - raise RuntimeError(f"Unexpected message after time out: msg={msg}") - else: - # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. - # this means that our process has been waiting for a command from main in vain, while main was not - # receiving data. - # This will occur if main is busy doing something else (e.g. computing loss etc). - - counter += _timeout - if verbose: - torchrl_logger.info(f"worker {idx} has counter {counter}") - if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): - raise RuntimeError( - f"This process waited for {counter} seconds " - f"without receiving a command from main. Consider increasing the maximum idle count " - f"if this is expected via the environment variable MAX_IDLE_COUNT " - f"(current value is {_MAX_IDLE_COUNT})." - f"\nIf this occurs at the end of a function or program, it means that your collector has not been " - f"collected, consider calling `collector.shutdown()` before ending the program." - ) - continue - else: - # placeholder, will be checked after - if msg != "continue": - torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") - msg = "continue" - if msg == "run_free": - run_free = True - msg = "continue" - if run_free: - # Capture shutdown / update / seed signal, but continue should not be expected - if pipe_child.poll(1e-4): - data_in, msg = pipe_child.recv() - torchrl_logger.info(f"worker {idx} received {msg} while running free") - if msg == "continue": - # Switch back to run_free = False - run_free = False - if msg == "pause": - queue_out.put((idx, "paused"), timeout=_TIMEOUT) - while not pipe_child.poll(1e-2): - continue - data_in, msg = pipe_child.recv() - if msg != "restart": - raise RuntimeError(f"Expected msg='restart', got {msg=}") - msg = "continue" - else: - data_in = None - # TODO: this does not work with random frames - msg = "continue" - # Note: The "continue" message handling has been moved below after update_weights handling - # to allow falling through from update_weights to continue - - if msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - if msg == "register_shared_weights": - # Shared memory lazy registration: main process sends buffer reference - if verbose: - torchrl_logger.info( - f"worker {idx} received shared memory buffer registration" - ) - model_id, shared_buffer = data_in - - # Store the shared buffer reference for this model - # The receiver will use this buffer for all future weight accesses - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - # Update receiver's buffer reference - receiver = inner_collector._weight_receivers[model_id] - # Store the shared buffer - the model's parameters should point to this - if hasattr(receiver, "_shared_weights"): - 
receiver._shared_weights[model_id] = shared_buffer - - # Apply the buffer to the model immediately - # Only apply if the model is an nn.Module (has learnable parameters) - try: - model = receiver._resolve_model_ref() - except (ValueError, AttributeError) as e: - # Model not registered or reference is invalid - if verbose: - torchrl_logger.warning( - f"worker {idx} could not resolve model '{model_id}': {e}" - ) - continue - - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - else: - if verbose: - torchrl_logger.info( - f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" - ) - - if verbose: - torchrl_logger.info( - f"worker {idx} registered shared buffer for model '{model_id}'" - ) - else: - torchrl_logger.warning( - f"worker {idx} received shared buffer for unknown model '{model_id}'" - ) - - # Send acknowledgment back to main process - pipe_child.send((None, "registered")) - has_timed_out = False - continue - - if msg == "update_weights": - # New weight update protocol for simplified weight sync system - if verbose: - torchrl_logger.info( - f"worker {idx} received weight update via new protocol" - ) - model_id, weights = data_in - - # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) - - # After applying weights, we continue collecting immediately as if we received - # a "continue" message. This ensures the worker keeps collecting data without - # waiting for an explicit continue from the main process. - has_timed_out = False - msg = "continue" - # Now check if we should continue collecting - - if msg in ("continue", "continue_random"): - # This block handles both explicit continue messages and implicit ones after weight updates - if msg == "continue_random": - inner_collector.init_random_frames = float("inf") - else: - inner_collector.init_random_frames = -1 - - # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the - # main message loop above (msg == "update_weights" case). The receiver.receive() - # pattern is only used for schemes with separate communication channels like - # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). - # Calling receiver.receive() here would interfere with the pipe-based message protocol. - - next_data = next(dc_iter) - if pipe_child.poll(_MIN_TIMEOUT): - # in this case, main send a message to the worker while it was busy collecting trajectories. - # In that case, we skip the collected trajectory and get the message from main. This is faster than - # sending the trajectory in the queue until timeout when it's never going to be received. 
- continue - - if replay_buffer is not None: - if extend_buffer: - next_data.names = None - replay_buffer.extend(next_data) - - if run_free: - continue - - try: - queue_out.put((idx, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if j == 0 or not use_buffers: - collected_tensordict = next_data - if ( - storing_device is not None - and collected_tensordict.device != storing_device - ): - raise RuntimeError( - f"expected device to be {storing_device} but got {collected_tensordict.device}" - ) - if use_buffers: - # If policy and env are on cpu, we put in shared mem, - # if policy is on cuda and env on cuda, we are fine with this - # If policy is on cuda and env on cpu (or opposite) we put tensors that - # are on cpu in shared mem. - MPS_ERROR = ( - "tensors on mps device cannot be put in shared memory. Make sure " - "the shared device (aka storing_device) is set to CPU." - ) - if collected_tensordict.device is not None: - # placeholder in case we need different behaviors - if collected_tensordict.device.type in ("cpu",): - collected_tensordict.share_memory_() - elif collected_tensordict.device.type in ("mps",): - raise RuntimeError(MPS_ERROR) - elif collected_tensordict.device.type == "cuda": - collected_tensordict.share_memory_() - else: - raise NotImplementedError( - f"Device {collected_tensordict.device} is not supported in multi-collectors yet." - ) - else: - # make sure each cpu tensor is shared - assuming non-cpu devices are shared - def cast_tensor(x, MPS_ERROR=MPS_ERROR): - if x.device.type in ("cpu",): - x.share_memory_() - if x.device.type in ("mps",): - RuntimeError(MPS_ERROR) - - collected_tensordict.apply(cast_tensor, filter_empty=True) - data = (collected_tensordict, idx) - else: - if next_data is not collected_tensordict: - raise RuntimeError( - "SyncDataCollector should return the same tensordict modified in-place." 
- ) - data = idx # flag the worker that has sent its data - try: - queue_out.put((data, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if msg == "seed": - data_in, static_seed = data_in - new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) - torch.manual_seed(data_in) - np.random.seed(data_in) - pipe_child.send((new_seed, "seeded")) - has_timed_out = False - continue - - elif msg == "reset": - inner_collector.reset() - pipe_child.send((j, "reset")) - continue - - elif msg == "state_dict": - state_dict = inner_collector.state_dict() - # send state_dict to cpu first - state_dict = recursive_map_to_cpu(state_dict) - pipe_child.send((state_dict, "state_dict")) - has_timed_out = False - continue - - elif msg == "load_state_dict": - state_dict = data_in - inner_collector.load_state_dict(state_dict) - del state_dict - pipe_child.send((j, "loaded")) - has_timed_out = False - continue - - elif msg == "getattr_policy": - attr_name = data_in - try: - result = getattr(inner_collector.policy, attr_name) - pipe_child.send((result, "getattr_policy")) - except AttributeError as e: - pipe_child.send((e, "getattr_policy")) - has_timed_out = False - continue - - elif msg == "getattr_env": - attr_name = data_in - try: - result = getattr(inner_collector.env, attr_name) - pipe_child.send((result, "getattr_env")) - except AttributeError as e: - pipe_child.send((e, "getattr_env")) - has_timed_out = False - continue - - elif msg == "close": - del collected_tensordict, data, next_data, data_in - inner_collector.shutdown() - del inner_collector, dc_iter - pipe_child.send("closed") - if verbose: - torchrl_logger.info(f"collector {idx} closed") - break - - else: - raise Exception(f"Unrecognized message {msg}") - - -def _make_meta_params(param): - is_param = isinstance(param, Parameter) - - pd = param.detach().to("meta") - - if is_param: - pd = Parameter(pd, requires_grad=False) - return pd - - -class _TrajectoryPool: - def __init__(self, ctx=None, lock: bool = False): - self.ctx = ctx - self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() - if ctx is None: - self.lock = contextlib.nullcontext() if not lock else mp.RLock() - else: - self.lock = contextlib.nullcontext() if not lock else ctx.RLock() - - def get_traj_and_increment(self, n=1, device=None): - with self.lock: - v = self._traj_id.item() - out = torch.arange(v, v + n).to(device) - self._traj_id.copy_(1 + out[-1].item()) - return out - - -def _map_weight( - weight, - policy_device, -): - - is_param = isinstance(weight, Parameter) - is_buffer = isinstance(weight, Buffer) - weight = weight.data - if weight.device != policy_device: - weight = weight.to(policy_device) - elif weight.device.type in ("cpu",): - weight = weight.share_memory_() - if is_param: - weight = Parameter(weight, requires_grad=False) - elif is_buffer: - weight = Buffer(weight) - return weight +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors._single_async import aSyncDataCollector +from torchrl.collectors.base import DataCollectorBase + +__all__ 
= [ + "MultiSyncDataCollector", + "MultiaSyncDataCollector", + "_MultiDataCollector", + "SyncDataCollector", + "_main_async_collector", + "aSyncDataCollector", + "DataCollectorBase", + # Constants + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 4839259e4ca..ff15aa63d67 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,13 +20,11 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index b8b28345872..a88e1aa7fcb 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,13 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 3d86bbc5422..bdf28942e0f 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -24,13 +24,11 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 980b3a4b489..f81a5efce0a 100644 --- 
a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,13 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py index e9ba6e9bcdf..8e4a9578859 100644 --- a/torchrl/collectors/llm/base.py +++ b/torchrl/collectors/llm/base.py @@ -14,7 +14,7 @@ from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors import SyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.llm.utils import _QueueAsRB from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer diff --git a/torchrl/collectors/llm/weight_update/vllm.py b/torchrl/collectors/llm/weight_update/vllm.py index 9b2fe144b0f..15c6e169457 100644 --- a/torchrl/collectors/llm/weight_update/vllm.py +++ b/torchrl/collectors/llm/weight_update/vllm.py @@ -17,7 +17,7 @@ from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends import stateless_init_process_group _has_vllm = importlib.util.find_spec("vllm") is not None diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index 0792d7e7de6..f97746ecb25 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends.vllm import RLvLLMEngine try: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 1f8b2668938..4a9470f708d 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -4,12 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import contextlib from collections.abc import Callable +from copy import deepcopy import torch +from pyvers import implement_for -from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase - +from tensordict import NestedKey, pad, set_lazy_legacy, TensorDict, TensorDictBase +from tensordict.utils import Buffer +from torch import multiprocessing as mp, nn as nn +from torch.nn import Parameter _NON_NN_POLICY_WEIGHTS = ( "The policy is not an nn.Module. 
TorchRL will assume that the parameter set is empty and " @@ -257,3 +262,118 @@ def nest(*x): [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0 ) return td + + +@implement_for("torch", "2.5.0") +def _make_meta_policy(policy: nn.Module) -> nn.Module: + """Create policy structure with parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to extract structure from. + + Returns: + A copy of the policy with all parameters on meta device and requires_grad=False. + """ + + def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + return p + + param_and_buf = TensorDict.from_module(policy, as_module=True) + with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): + meta_policy = deepcopy(policy) + return meta_policy + + +@implement_for("torch", None, "2.5.0") +def _make_meta_policy(policy: nn.Module) -> nn.Module: # noqa: F811 + """Create policy structure with parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to extract structure from. + + Returns: + A copy of the policy with all parameters on meta device and requires_grad=False. + """ + + def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + return p + + param_and_buf = TensorDict.from_module(policy, as_module=True) + with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): + meta_policy = deepcopy(policy) + return meta_policy + + +def _map_to_cpu_if_needed(x): + """Map tensors on exotic devices (MPS, NPU, etc.) to CPU. + + CPU and CUDA tensors are kept as-is since they can be shared across processes. + Only exotic devices that don't support multiprocessing are mapped to CPU. + """ + if isinstance(x, torch.Tensor): + # CPU and CUDA can be shared across processes + if x.device.type in ("cpu", "cuda"): + return x + # Exotic devices (MPS, NPU, etc.) 
need to be mapped to CPU + return x.cpu() + return x + + +def _make_meta_params(param): + is_param = isinstance(param, Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = Parameter(pd, requires_grad=False) + return pd + + +class _TrajectoryPool: + def __init__(self, ctx=None, lock: bool = False): + self.ctx = ctx + self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() + if ctx is None: + self.lock = contextlib.nullcontext() if not lock else mp.RLock() + else: + self.lock = contextlib.nullcontext() if not lock else ctx.RLock() + + def get_traj_and_increment(self, n=1, device=None): + with self.lock: + v = self._traj_id.item() + out = torch.arange(v, v + n).to(device) + self._traj_id.copy_(1 + out[-1].item()) + return out + + +def _map_weight( + weight, + policy_device, +): + + is_param = isinstance(weight, Parameter) + is_buffer = isinstance(weight, Buffer) + weight = weight.data + if weight.device != policy_device: + weight = weight.to(policy_device) + elif weight.device.type in ("cpu",): + weight = weight.share_memory_() + if is_param: + weight = Parameter(weight, requires_grad=False) + elif is_buffer: + weight = Buffer(weight) + return weight diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b0993c12242..0ba2c019303 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2701,7 +2701,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: child_pipe.send( @@ -2713,6 +2712,9 @@ def _run_worker_pipe_direct( raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err else: child_pipe.send(cur_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del cur_td @@ -2726,7 +2728,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: next_td = next_td.consolidate( @@ -2735,6 +2736,9 @@ def _run_worker_pipe_direct( except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err child_pipe.send(next_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del next_td diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 6a17125b1d4..94c9bfa2aed 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -906,9 +906,9 @@ def execute(self, prompt: str) -> dict[str, Any]: except queue.Empty: pass - if not start_found: - timeout_val -= 0.1 - time.sleep(0.1) + # Always sleep a bit to avoid busy-waiting and give subprocess time + timeout_val -= 0.01 + time.sleep(0.01) except Exception as e: return { @@ -1007,8 +1007,10 @@ def __init__(self, pool_size: int = 32, timeout: float = 10.0): self.processes = [ PersistentPythonProcess(timeout=timeout) for _ in range(pool_size) ] + # Create a lock for each process to prevent concurrent access + self.process_locks = [threading.Lock() for _ in range(pool_size)] self.next_idx = 0 - self._lock = threading.Lock() + self._selection_lock = threading.Lock() def execute(self, code: str) -> dict: """Execute Python code using next available process (round-robin). @@ -1019,12 +1021,14 @@ def execute(self, code: str) -> dict: Returns: dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'. 
""" - # Simple round-robin - Ray handles the queuing via max_concurrency - with self._lock: - process = self.processes[self.next_idx] + # Select a process using round-robin + with self._selection_lock: + process_idx = self.next_idx self.next_idx = (self.next_idx + 1) % self.pool_size - return process.execute(code) + # Lock the selected process for the duration of execution + with self.process_locks[process_idx]: + return self.processes[process_idx].execute(code) def cleanup(self): """Cleanup all processes in the pool.""" diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 42d13108a0f..ad84c855757 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -8,11 +8,15 @@ import weakref from collections.abc import Iterator +from queue import Empty from typing import Any, Literal, Protocol +import torch +import torch.distributed + from tensordict import TensorDict, TensorDictBase -from torch import nn +from torch import multiprocessing as mp, nn __all__ = [ "TransportBackend", @@ -136,13 +140,12 @@ class SharedMemTransport: This transport updates shared memory tensors directly without message passing. Workers automatically see weight updates without explicit communication. - The transport supports lazy registration with pipe-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via pipes + The transport supports lazy registration with queue-based buffer distribution: + - On first weight send for a model, creates shared memory and sends buffer via queue - Workers receive the buffer reference and update their local references - Subsequent updates are pure in-place shared memory (zero-copy) - This hybrid approach solves the chicken-and-egg problem: workers can start before - weights are available, and they'll receive the shared buffer references when ready. + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. Args: policy_weights: Dictionary mapping model_id to shared TensorDict weights. @@ -159,18 +162,21 @@ def __init__( ): self._policy_weights = policy_weights if policy_weights is not None else {} self._auto_register = auto_register - self._pipes = [] # List of pipes to send initial buffer references + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + self._device_to_workers = {} # Maps device -> list of worker indices # Track which model_ids have been sent to workers self._registered_with_workers = set() - def register_pipe(self, pipe: Any) -> None: - """Register a pipe for sending buffer references on first weight send. + def set_worker_info(self, device_to_workers: dict) -> None: + """Set worker device mapping for distributing weights. Args: - pipe: Pipe connection to a worker process. + device_to_workers: Dict mapping device -> list of worker indices on that device. + Example: {torch.device('cuda:1'): [0, 2], torch.device('cuda:2'): [1, 3]} """ - if pipe not in self._pipes: - self._pipes.append(pipe) + self._device_to_workers = device_to_workers def register_weights(self, model_id: str, weights: TensorDictBase) -> None: """Register a shared memory weights TensorDict for a model. @@ -178,10 +184,7 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None: This method allows explicit registration of shared weights. It's optional when auto_register=True (the default), but required when auto_register=False. 
- If pipes are registered and this model hasn't been sent to workers yet, - this will trigger sending the buffer reference to all workers. If pipes - aren't registered yet, weights are stored and will be sent when pipes - become available (during init_on_sender). + Weights are stored and will be sent to workers during init_on_sender. """ if not isinstance(weights, TensorDictBase): raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") @@ -192,40 +195,60 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None: else: raise RuntimeError("Re-registering weights is not supported.") - # If this is a new registration and we have pipes, send buffer to workers - # If pipes aren't available yet, defer sending until init_on_sender is called - if self._pipes: - if model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, weights) - else: - raise RuntimeError( - f"Model '{model_id}' has already been registered with workers." - ) + def _infer_device(self, td: TensorDictBase): + """Infer the device from a TensorDict by checking its tensors. - def _send_buffer_to_workers( - self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0 - ) -> None: - """Send shared memory buffer reference to all workers via pipes. + Returns: + torch.device or None if no tensors found or all on different devices. + """ + for value in td.values(True, True): + if isinstance(value, torch.Tensor): + return value.device + return None + + def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. - This is called once per model_id when lazy registration occurs. - Workers receive the buffer and update their local references. + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, eliminating race conditions. Note: We send buffer.data to avoid gradient tracking issues when crossing process boundaries. The .data attribute gives us the underlying tensors without autograd metadata. """ - for pipe in self._pipes: - # Send special registration message with the shared buffer - # Use .data to strip gradient information (can't serialize non-leaf tensors with requires_grad) - pipe.send(((model_id, buffer.data), "register_shared_weights")) + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + # Validate device + device = buffer.device or self._infer_device(buffer) + if device is not None and device.type not in ("cpu", "cuda"): + raise NotImplementedError( + f"Device type '{device.type}' not supported for shared memory. " + f"Only 'cpu' and 'cuda' are supported." + ) - # Wait for acknowledgments from all workers - for pipe in self._pipes: - if not pipe.poll(timeout): - raise TimeoutError("Timeout waiting for acknowledgment from worker") - _, msg = pipe.recv() - if msg != "registered": - raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'") + # Send weights to each worker's dedicated queue + device = buffer.device or self._infer_device(buffer) + if device in self._device_to_workers: + worker_indices = self._device_to_workers[device] + for worker_idx in worker_indices: + # Each worker has its own queue - no race conditions + # Message format: (model_id, weights) + if worker_idx not in self._weight_queues: + raise RuntimeError( + f"Worker {worker_idx} queue not created. 
" + f"Available queues: {list(self._weight_queues.keys())}" + ) + self._weight_queues[worker_idx].put((model_id, buffer.data)) + else: + # Fallback: send to all workers (for CPU or unknown device) + # Calculate total workers from device_to_workers mapping + all_workers = set() + for workers in self._device_to_workers.values(): + all_workers.update(workers) + for worker_idx in sorted(all_workers): + if worker_idx in self._weight_queues: + self._weight_queues[worker_idx].put((model_id, buffer.data)) self._registered_with_workers.add(model_id) @@ -234,8 +257,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: If the model is not registered and auto_register=True, it will be automatically registered by creating a shared memory copy of the provided weights. The shared - buffer reference is sent to all workers via pipes on first registration, then - subsequent updates are pure in-place shared memory. + buffer reference will be sent to workers via queue during the next init_on_sender call. Args: model_id: Identifier for the model whose weights to update. @@ -272,9 +294,8 @@ def send_weights(self, model_id: str, weights: Any) -> None: self._policy_weights[model_id] = shared_buffer - # Send buffer reference to all workers if we have pipes - if self._pipes and model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, shared_buffer) + # Note: Buffer will be sent to workers during init_on_sender + # when the queue is available shared_weights = self._policy_weights[model_id] @@ -677,8 +698,6 @@ def send_ack(self, message: str = "updated") -> None: def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" - import torch.distributed - return torch.distributed.is_initialized() @@ -1591,6 +1610,11 @@ def __init__( self._shared_transport = SharedMemTransport( self.policy_weights, auto_register=auto_register ) + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + # General message queue for coordination (if needed in future) + self._message_queue = mp.Queue() def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: """Register shared memory weights for a model. @@ -1614,38 +1638,52 @@ def init_on_sender( ) -> None: """Initialize on the main process (sender side). - For SharedMemWeightSyncScheme, this handles: - 1. Getting cached shared memory weights from context - 2. Pre-registering the weights with the transport - 3. Distributing buffer references to all workers (avoiding later deadlock) + Creates per-worker queues and distributes any pre-registered weights. Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipes, cached_weights - **kwargs: Alternative to context (pipes, cached_weights, etc.) + context: Optional context object providing device_to_workers mapping, cached_weights + **kwargs: Alternative to context (device_to_workers, cached_weights, etc.) 
""" - # Extract parameters from context or kwargs + # Extract device_to_workers mapping from context if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) + # Build device_to_workers from policy_device list + if hasattr(context, "policy_device"): + device_to_workers = {} + for idx, device in enumerate(context.policy_device): + if device not in device_to_workers: + device_to_workers[device] = [] + device_to_workers[device].append(idx) + else: + device_to_workers = kwargs.get("device_to_workers", {}) + # Try to get cached shared memory weights if hasattr(context, "get_cached_weights"): cached_weights = context.get_cached_weights(model_id) else: cached_weights = None else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") + device_to_workers = kwargs.get("device_to_workers", {}) cached_weights = kwargs.get("cached_weights") - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 + if not device_to_workers: + raise ValueError( + "device_to_workers mapping must be provided via context or kwargs" + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = set() + for workers in device_to_workers.values(): + all_workers.update(workers) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() - # Register pipes with shared transport for lazy buffer distribution - for pipe in pipes: - self._shared_transport.register_pipe(pipe) + # Set worker info in transport + self._shared_transport.set_worker_info(device_to_workers) + self._shared_transport._weight_queues = self._weight_init_queues # If we have cached shared memory weights, pre-register them if cached_weights is not None: @@ -1653,8 +1691,7 @@ def init_on_sender( if model_id not in self.policy_weights: self.register_shared_weights(model_id, cached_weights) - # Send buffer references for any weights that were pre-registered - # before pipes were available (e.g., via explicit register_shared_weights call) + # Distribute any pre-registered weights to workers if model_id in self.policy_weights: if model_id not in self._shared_transport._registered_with_workers: self._shared_transport._send_buffer_to_workers( @@ -1675,33 +1712,72 @@ def init_on_worker( self, model_id: str, context: Any = None, + model: Any = None, + worker_idx: int | None = None, **kwargs, ) -> None: """Initialize on worker process (receiver side). + Reads from the worker's dedicated queue to receive shared weights, + then registers them in the transport. The receiver then applies these weights + to the model. + Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + context: Optional context object providing model and worker_idx + model: Model being synchronized + worker_idx: Worker index + **kwargs: Alternative to context (model, worker_idx, timeout, etc.) 
""" # Extract parameters from context or kwargs if context is not None: - getattr(context, "pipe", None) if hasattr(context, "get_model"): model = context.get_model(model_id) - else: + elif model is not None: model = None - else: - model = kwargs.get("model") - - # For shared memory, we don't need the pipe in the receiver - # The transport is shared and workers see updates automatically + worker_idx = getattr(context, "worker_idx", worker_idx) + + # Receive weights from this worker's dedicated queue if available + if self._weight_init_queues and worker_idx is not None: + # Each worker has its own queue - no race conditions! + if worker_idx in self._weight_init_queues: + worker_queue = self._weight_init_queues[worker_idx] + timeout = kwargs.get("timeout", 10.0) + try: + # Read from our dedicated queue - only messages for this worker are here + while True: + msg_model_id, shared_weights = worker_queue.get(timeout=timeout) + + # Register the shared weights in the transport + self._shared_transport._policy_weights[ + msg_model_id + ] = shared_weights + + # If this is the model we're initializing, apply weights + if msg_model_id == model_id and model is not None: + shared_weights.to_module(model) + self._shared_transport._registered_with_workers.add( + msg_model_id + ) + break + elif msg_model_id == model_id: + # Model will be applied later when it's available + self._shared_transport._registered_with_workers.add( + msg_model_id + ) + break + # If not the model we're looking for, still register it but keep looking + except Empty: + # No weights pre-registered for this model (will use auto-register or policy_factory) + pass # Create receiver with the shared transport receiver = WeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) receiver._transport = self._shared_transport # Use shared transport + + # Register the model - this will apply the shared weights to it if model is not None: receiver._register_model(model) else: @@ -1711,18 +1787,36 @@ def init_on_worker( self._receiver = receiver self._initialized_on_worker = True - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport and register pipe for lazy buffer distribution (legacy). + def get_weight_queues(self): + """Get the per-worker weight initialization queues. + + Returns: + Dict mapping worker_idx to Queue for receiving shared weight references. + + Raises: + RuntimeError: If init_on_sender() hasn't been called yet. + """ + if not self._weight_init_queues: + raise RuntimeError("Queues not created. Call init_on_sender() first.") + return self._weight_init_queues - For lazy registration to work, we register each worker's pipe with the transport. - On first weight send, the transport will send buffer references via these pipes. + def get_message_queue(self): + """Get the general message queue for coordination. + + Returns: + The message queue for general coordination messages. + """ + return self._message_queue + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create shared memory transport (legacy). Returns the shared transport instance that all workers will use. Since this is shared memory, there's only one transport shared by all workers. + + Note: This is a legacy method. The new init_on_sender/init_on_worker API + is the preferred way to set up the transport. 
""" - # Register the pipe for lazy buffer distribution - if pipe_or_context is not None: - self._shared_transport.register_pipe(pipe_or_context) return self._shared_transport def prepare_weights( From 78f064bbb21022480aed908b757cad86e364b13c Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Nov 2025 18:21:26 +0000 Subject: [PATCH 02/17] fix test --- test/test_collector.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index bc99b51c08e..58727d1550f 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2993,16 +2993,15 @@ def test_param_sync_mixed_device( not torch.cuda.is_available() or torch.cuda.device_count() < 3, reason="requires at least 3 CUDA devices", ) - @pytest.mark.parametrize( - "weight_sync_scheme", - [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme], - ) - def test_shared_device_weight_update(self, weight_sync_scheme): + def test_shared_device_weight_update(self): """Test that weight updates work correctly when multiple workers share the same device. This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through dedicated queues, preventing race conditions that could occur with a single shared queue. + + Note: This test only uses SharedMemWeightSyncScheme (not MultiProcessWeightSyncScheme) because + the latter sends tensors through pipes, which we want to avoid. """ # Create policy on cuda:0 policy = TensorDictModule( @@ -3023,7 +3022,7 @@ def make_env(): total_frames=300, device=["cuda:2", "cuda:1", "cuda:2"], storing_device=["cuda:2", "cuda:1", "cuda:2"], - weight_sync_schemes={"policy": weight_sync_scheme()}, + weight_sync_schemes={"policy": SharedMemWeightSyncScheme()}, ) try: From 18ff77f753b7b0901ec40c81ecdd58dd80d795a9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:39:27 +0000 Subject: [PATCH 03/17] refactor --- .../reference/collectors_weightsync.rst | 6 +- examples/collectors/weight_sync_collectors.py | 2 +- examples/collectors/weight_sync_standalone.py | 4 +- test/test_collector.py | 32 +- test/test_weightsync.py | 23 +- torchrl/collectors/_multi_base.py | 80 +- torchrl/collectors/_runner.py | 30 +- torchrl/collectors/_single.py | 77 +- torchrl/collectors/utils.py | 59 +- .../algorithms/configs/weight_sync_schemes.py | 5 - torchrl/weight_update/weight_sync_schemes.py | 758 ++++++++++-------- 11 files changed, 568 insertions(+), 508 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 0fcf174f3c1..b6c2257e28f 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -93,8 +93,8 @@ Here's a basic example: # Example 2: Shared memory weight synchronization # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + # Create shared memory scheme + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict") # Initialize with pipes for lazy registration parent_pipe2, child_pipe2 = mp.Pipe() @@ -159,7 +159,7 @@ across multiple inference workers: # Example 2: Multiple collectors with shared memory # -------------------------------------------------- # Shared memory is more efficient for frequent updates - shared_scheme = 
SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict") collector = MultiSyncDataCollector( create_env_fn=[ diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index a3962966c8c..020ad0b8a61 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -90,7 +90,7 @@ def example_multi_collector_shared_memory(): env.close() # Shared memory is more efficient for frequent updates - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py index 2d918cb10a2..2899febd06b 100644 --- a/examples/collectors/weight_sync_standalone.py +++ b/examples/collectors/weight_sync_standalone.py @@ -141,8 +141,8 @@ def example_shared_memory_sync(): # Create a simple policy policy = nn.Linear(4, 2) - # Create shared memory scheme with auto-registration - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + # Create shared memory scheme + scheme = SharedMemWeightSyncScheme(strategy="tensordict") sender = scheme.create_sender() # Create pipe for lazy registration diff --git a/test/test_collector.py b/test/test_collector.py index 58727d1550f..8ce8a055091 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1132,40 +1132,20 @@ def make_and_test_policy( policy, policy_device=original_device, env_device=original_device ) - # a deepcopy must occur when the policy_device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + # Test that we DON'T raise deepcopy errors anymore even when policy_device differs + # These scenarios previously would have triggered deepcopy, but now use meta device context manager + if collector_type is not SyncDataCollector: + # policy_device differs from the actual device - previously required deepcopy, now works! policy = make_policy(device=original_device) make_and_test_policy( policy, policy_device=shared_device, env_device=shared_device ) - # a deepcopy must occur when device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + if collector_type is not SyncDataCollector: + # device differs from the actual device - previously required deepcopy, now works! 
policy = make_policy(device=original_device) make_and_test_policy(policy, device=shared_device) - # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device - # is there to inform us - substitute_device = ( - original_device if torch.cuda.is_available() else torch.device("cpu") - ) - policy = make_policy(substitute_device, nn_module=False) - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=substitute_device - ) - # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=shared_device - ) - make_and_test_policy( - policy, - policy_device=substitute_device, - env_device=shared_device, - trust_policy=True, - ) - # If there is no policy_device, we assume that the user is doing things right too but don't warn if collector_type is SyncDataCollector or original_device.type != "mps": policy = make_policy(original_device, nn_module=False) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 2ccd4308ccf..82992b14ca4 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -244,9 +244,7 @@ def test_shared_mem_scheme(self): ).share_memory_() scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, strategy="tensordict", - auto_register=False, ) transport = scheme.create_transport(None) @@ -260,21 +258,6 @@ def test_shared_mem_scheme(self): assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - def test_shared_mem_scheme_auto_register(self): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - transport = scheme.create_transport(None) - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights("policy", weights) - - assert "policy" in scheme.policy_weights - assert torch.allclose( - scheme.policy_weights["policy"]["weight"], torch.ones(2, 4) - ) - def test_no_weight_sync_scheme(self): scheme = NoWeightSyncScheme() transport = scheme.create_transport(None) @@ -396,7 +379,7 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): collector.shutdown() def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") collector = MultiSyncDataCollector( create_env_fn=[ @@ -677,7 +660,7 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): def test_shared_mem_scheme_serialize_before_init(self): """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") # Serialize and deserialize pickled = pickle.dumps(scheme) @@ -698,9 +681,7 @@ def test_shared_mem_scheme_serialize_after_init(self): ).share_memory_() scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, strategy="tensordict", - auto_register=False, ) def init_on_sender(scheme, child_pipe): diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index f9d7ea7a8bd..44efecc58ec 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -334,12 +334,14 @@ def __init__( 
policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization + weight_sync_schemes = {} if ( not any(policy_factory) and not weight_sync_schemes and weight_updater is None + and isinstance(policy, nn.Module) ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} + weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() self._setup_multi_policy_and_weights( policy, policy_factory, weight_updater, weight_sync_schemes @@ -511,52 +513,16 @@ def _setup_multi_policy_and_weights( raise TypeError("policy_factory and policy are mutually exclusive") if weight_sync_schemes is not None: - # Weight sync schemes handle all weight distribution - # Extract weights so schemes can access them, but don't do in-place replacement - self._policy_weights_dict = {} - self._fallback_policy = None - - if not any(policy_factory) and policy is not None: - # Extract weights for the first device so schemes can access them - # Use first device as representative - first_device = self.policy_device[0] if self.policy_device else None - - # Validate device types for SharedMemWeightSyncScheme - for scheme in weight_sync_schemes.values(): - if isinstance(scheme, SharedMemWeightSyncScheme): - for policy_device in self.policy_device: - if policy_device and policy_device.type not in ( - "cpu", - "cuda", - ): - raise NotImplementedError( - f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. " - f"Only 'cpu' and 'cuda' are supported." - ) - - # Extract weights from policy - # Use .data to avoid gradient tracking (can't serialize tensors with requires_grad) - weights = ( - TensorDict.from_module(policy, as_module=True).data - if isinstance(policy, nn.Module) - else TensorDict() + weight_sync_policy = weight_sync_schemes.get("policy") + if weight_sync_policy is None: + return + if weight_sync_policy._initialized_on_sender: + return + if any(p is not None for p in policy_factory): + raise RuntimeError( + f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" ) - - # For SharedMemWeightSyncScheme, share the weights - if any( - isinstance(scheme, SharedMemWeightSyncScheme) - for scheme in weight_sync_schemes.values() - ): - if first_device and first_device.type == "cpu": - weights = weights.share_memory_() - elif first_device and first_device.type == "cuda": - # CUDA tensors maintain shared references through mp.Queue - weights = weights.to(first_device).share_memory_() - - self._policy_weights_dict[first_device] = weights - self._fallback_policy = policy - - self._get_weights_fn = None + weight_sync_policy.init_on_sender(model=policy, devices=self.policy_device) else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( @@ -900,13 +866,16 @@ def _run_processes(self) -> None: # Schemes handle weight distribution on worker side if any(policy_factory): policy_to_send = None # Factory will create policy in worker + cm = contextlib.nullcontext() elif policy is not None: - # Send meta-device policy (empty structure) - schemes apply weights - policy_to_send = _make_meta_policy(policy) + # Send policy with meta-device parameters (empty structure) - schemes apply weights + policy_to_send = policy + cm = _make_meta_policy(policy) else: policy_to_send = None - cm = contextlib.nullcontext() - else: + cm = contextlib.nullcontext() + elif hasattr(self, "_policy_weights_dict"): + # LEGACY: # With weight updater, use in-place weight replacement # Take the weights and locally dispatch them to the policy before sending. # This ensures a given set of shared weights for a device are shared @@ -917,6 +886,10 @@ def _run_processes(self) -> None: cm = policy_weights.to_module(policy) else: cm = contextlib.nullcontext() + else: + # Parameter-less policy + cm = contextlib.nullcontext() + policy_to_send = policy with cm: kwargs = { @@ -995,6 +968,13 @@ def _run_processes(self) -> None: self.procs.append(proc) self.pipes.append(pipe_parent) + # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" + # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: + # Workers will call receiver.synchronize_weights() during init and may block waiting for data + if self._weight_senders: + for model_id, sender in self._weight_senders.items(): + sender.synchronize_weights() + # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 54e5c823888..14ceb8f86d8 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -26,7 +26,6 @@ from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType from torchrl.weight_update import WeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import _resolve_model def _make_policy_factory( @@ -38,9 +37,13 @@ def _make_policy_factory( policy = policy_factory() if weight_sync_scheme is not None: + # Initialize the receiver on the worker side weight_sync_scheme.init_on_worker( model=policy, model_id="policy", worker_idx=worker_idx ) + # Get the receiver and synchronize initial weights + receiver = weight_sync_scheme.get_receiver() + receiver.synchronize_weights(worker_idx=worker_idx) return policy @@ -123,8 +126,11 @@ def _main_async_collector( no_cuda_sync=no_cuda_sync, weight_sync_schemes=weight_sync_schemes, ) + print("Inner collector created") # Set up weight 
receivers for worker process + # Note: For the "policy" model, initialization is done in _make_policy_factory + # This section only handles additional models (not "policy") if weight_sync_schemes: inner_collector._weight_receivers = {} inner_collector.pipe = pipe_child # Add pipe attribute for context @@ -133,22 +139,16 @@ def _main_async_collector( ) for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - # For SharedMemWeightSyncScheme, init_on_worker reads from queue - # and applies weights to model - all handled by the receiver - scheme.init_on_worker(model_id=model_id, context=inner_collector) + if model_id == "policy": + # Policy receiver was already initialized in _make_policy_factory receiver = scheme.get_receiver() + inner_collector._weight_receivers[model_id] = receiver else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver + # Initialize receivers for other models + scheme.init_on_worker(model_id=model_id, context=inner_collector) + receiver = scheme.get_receiver() + receiver.synchronize_weights(worker_idx=worker_idx) + inner_collector._weight_receivers[model_id] = receiver else: inner_collector._weight_receivers = {} diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index aee35c4042a..7a78cf41605 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -452,6 +452,7 @@ def _init_policy( if policy is None: if policy_factory is not None: policy = policy_factory() + print(f"Policy factory created: {policy}") else: policy = RandomPolicy(env.full_action_spec) elif policy_factory is not None: @@ -594,38 +595,58 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None break if has_meta_params: - # Skip device placement for meta policies - schemes handle weight application - # Policy stays as-is, weights will be applied by the receiver - self.get_weights_fn = lambda: TensorDict.from_module(policy).data + # Policy has meta params - sent from weight sync schemes + # Skip device placement, weights will come from receiver + # Keep policy on meta device until weights are loaded + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Don't extract weights yet - they're on meta device (empty) + self.policy_weights = TensorDict() + self.get_weights_fn = None else: # Normal path: move policy to correct device policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. 
If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." - ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights from the uncompiled policy - # Access _wrapped_policy_uncompiled directly to avoid triggering compilation - if isinstance(self._wrapped_policy_uncompiled, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy_uncompiled, as_module=True - ).data - else: - self.policy_weights = TensorDict() + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Extract policy weights from the uncompiled policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() # If policy doesn't have meta params, compile immediately # Otherwise, defer until first use (after weights are loaded) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 4a9470f708d..8492a52041e 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -6,7 +6,6 @@ import contextlib from collections.abc import Callable -from copy import deepcopy import torch from pyvers import implement_for @@ -265,57 +264,39 @@ def nest(*x): @implement_for("torch", "2.5.0") -def _make_meta_policy(policy: nn.Module) -> nn.Module: - """Create policy structure with parameters on meta device. +def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + return p + + +def _make_meta_policy(policy: nn.Module): + """Create context manager that temporarily puts policy parameters on meta device. This is used with weight sync schemes to send policy structure without weights. The actual weights are distributed by the schemes. Args: - policy: Policy module to extract structure from. + policy: Policy module to temporarily modify. Returns: - A copy of the policy with all parameters on meta device and requires_grad=False. + A context manager that temporarily replaces policy parameters with meta device versions. + On exit, the original parameters are restored to the policy. 
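+
+    Example:
+        A minimal sketch, assuming ``pipe`` is a multiprocessing connection the
+        policy structure should travel through:
+
+        >>> with _make_meta_policy(policy):
+        ...     pipe.send(policy)  # parameters are on "meta", so no tensor data is pickled
+        >>> # on exit, the original parameters are restored on the policy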
""" - def _cast(p, param_maybe_buffer): - if isinstance(param_maybe_buffer, Parameter): - # Create parameter without gradients to avoid serialization issues - return Parameter(p, requires_grad=False) - if isinstance(param_maybe_buffer, Buffer): - return Buffer(p) - return p - param_and_buf = TensorDict.from_module(policy, as_module=True) - with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): - meta_policy = deepcopy(policy) - return meta_policy + return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) @implement_for("torch", None, "2.5.0") -def _make_meta_policy(policy: nn.Module) -> nn.Module: # noqa: F811 - """Create policy structure with parameters on meta device. - - This is used with weight sync schemes to send policy structure without weights. - The actual weights are distributed by the schemes. - - Args: - policy: Policy module to extract structure from. - - Returns: - A copy of the policy with all parameters on meta device and requires_grad=False. - """ - - def _cast(p, param_maybe_buffer): - if isinstance(param_maybe_buffer, Parameter): - # Create parameter without gradients to avoid serialization issues - return Parameter(p, requires_grad=False) - return p - - param_and_buf = TensorDict.from_module(policy, as_module=True) - with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): - meta_policy = deepcopy(policy) - return meta_policy +def _cast(p, param_maybe_buffer): # noqa + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + return p def _map_to_cpu_if_needed(x): diff --git a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py index 4417e5c2cb3..ed128429d76 100644 --- a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py +++ b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py @@ -48,17 +48,12 @@ class SharedMemWeightSyncSchemeConfig(ConfigBase): Weight synchronization using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing. - - By default, uses lazy registration (auto_register=True) which makes it seamless - to use with Hydra configs - models are automatically registered on first weight send. """ _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme" _partial_: bool = False - policy_weights: Any = None # dict[str, TensorDictBase] | None strategy: str = "tensordict" # "tensordict" or "state_dict" - auto_register: bool = True # Enable lazy registration by default def __post_init__(self) -> None: """Post-initialization hook for shared memory weight sync scheme configurations.""" diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index ad84c855757..e9fc033294d 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -8,7 +8,6 @@ import weakref from collections.abc import Iterator -from queue import Empty from typing import Any, Literal, Protocol import torch @@ -49,7 +48,7 @@ class TransportBackend(Protocol): """Abstract interface for different communication mechanisms.""" - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the receiver.""" ... @@ -61,6 +60,30 @@ def check_connection(self) -> bool: """Check if the connection is still alive.""" ... 
+ def synchronize_weights_on_sender(self) -> None: + """Synchronize weights on sender side before collection starts. + + This is called once after workers are initialized to send the initial + weights. This can be a no-op (weights are sent via + send_weights). + """ + ... + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """Synchronize weights on worker side before collection starts. + + This is called once in each worker after initialization to receive + the initial weights. This is a no-op (weights are received via + receive_weights). + + Args: + worker_idx: The worker index. + + Returns: + The received weights (for SharedMemTransport) or None. + """ + ... + class MPTransport: """Multiprocessing transport using pipes. @@ -74,20 +97,20 @@ def __init__(self, pipe_connection, timeout: float = 10.0): self.timeout = timeout self.pipe = pipe_connection - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights through the pipe. Sends weights and waits for acknowledgment to ensure delivery. """ - self.send_weights_async(model_id, weights) + self.send_weights_async(weights) self.wait_ack() - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights through the pipe without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. """ - self.pipe.send(((model_id, weights), "update_weights")) + self.pipe.send((weights, "update_weights")) def wait_ack(self) -> None: """Wait for acknowledgment from worker.""" @@ -103,12 +126,16 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: Returns: Tuple of (model_id, weights) if weights were received, None if no data available or if a non-weight message was received. + + Note: + model_id is returned as "policy" for backward compatibility, but transports + are now bound to a single model during initialization. """ if self.pipe.poll(timeout): data_in, msg = self.pipe.recv() if msg == "update_weights": - model_id, weights = data_in - return model_id, weights + weights = data_in + return "policy", weights else: # Not a weight update message - put it back and return None # This allows the main worker loop to handle other messages @@ -133,172 +160,97 @@ def check_ack(self, message: str = "updated") -> None: def check_connection(self) -> bool: return not self.pipe.closed + def synchronize_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for MPTransport - weights are received via receive_weights().""" + return None + class SharedMemTransport: """Shared memory transport for in-place weight updates. - This transport updates shared memory tensors directly without message passing. + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. Workers automatically see weight updates without explicit communication. 
- The transport supports lazy registration with queue-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via queue - - Workers receive the buffer reference and update their local references + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models - Subsequent updates are pure in-place shared memory (zero-copy) Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration. - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to `False` to require explicit registration via - register_weights(). """ - def __init__( - self, - policy_weights: dict[str, TensorDictBase] | None = None, - auto_register: bool = True, - ): - self._policy_weights = policy_weights if policy_weights is not None else {} - self._auto_register = auto_register + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map self._weight_queues = ( None # Dict of per-worker queues for distributing shared weights ) - self._device_to_workers = {} # Maps device -> list of worker indices - # Track which model_ids have been sent to workers - self._registered_with_workers = set() - - def set_worker_info(self, device_to_workers: dict) -> None: - """Set worker device mapping for distributing weights. - - Args: - device_to_workers: Dict mapping device -> list of worker indices on that device. - Example: {torch.device('cuda:1'): [0, 2], torch.device('cuda:2'): [1, 3]} - """ - self._device_to_workers = device_to_workers - def register_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register a shared memory weights TensorDict for a model. + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if weights in self._unique_weights: + continue + self._unique_weights.append(weights) + + def synchronize_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. - Weights are stored and will be sent to workers during init_on_sender. """ - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") - - is_new_registration = model_id not in self._policy_weights - if is_new_registration: - self._policy_weights[model_id] = weights - else: - raise RuntimeError("Re-registering weights is not supported.") + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - def _infer_device(self, td: TensorDictBase): - """Infer the device from a TensorDict by checking its tensors. 
+ for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) - Returns: - torch.device or None if no tensors found or all on different devices. - """ - for value in td.values(True, True): - if isinstance(value, torch.Tensor): - return value.device - return None + def synchronize_weights_on_worker( + self, worker_idx: int, timeout: float = 10.0 + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. - def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None: - """Send shared memory buffer reference to workers via their per-worker queues. + Each worker reads from its own dedicated queue, to avoid race conditions. - Both CPU and CUDA tensors maintain shared references through queues. - Each worker reads from its own dedicated queue, eliminating race conditions. + Args: + worker_idx: The worker index. + timeout: Timeout for reading from queue. - Note: We send buffer.data to avoid gradient tracking issues when crossing - process boundaries. The .data attribute gives us the underlying tensors - without autograd metadata. + Returns: + The shared memory weights TensorDict. """ if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - # Validate device - device = buffer.device or self._infer_device(buffer) - if device is not None and device.type not in ("cpu", "cuda"): - raise NotImplementedError( - f"Device type '{device.type}' not supported for shared memory. " - f"Only 'cpu' and 'cuda' are supported." - ) + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") - # Send weights to each worker's dedicated queue - device = buffer.device or self._infer_device(buffer) - if device in self._device_to_workers: - worker_indices = self._device_to_workers[device] - for worker_idx in worker_indices: - # Each worker has its own queue - no race conditions - # Message format: (model_id, weights) - if worker_idx not in self._weight_queues: - raise RuntimeError( - f"Worker {worker_idx} queue not created. " - f"Available queues: {list(self._weight_queues.keys())}" - ) - self._weight_queues[worker_idx].put((model_id, buffer.data)) - else: - # Fallback: send to all workers (for CPU or unknown device) - # Calculate total workers from device_to_workers mapping - all_workers = set() - for workers in self._device_to_workers.values(): - all_workers.update(workers) - for worker_idx in sorted(all_workers): - if worker_idx in self._weight_queues: - self._weight_queues[worker_idx].put((model_id, buffer.data)) - - self._registered_with_workers.add(model_id) - - def send_weights(self, model_id: str, weights: Any) -> None: - """Update weights in-place in shared memory. + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + weights = worker_queue.get(timeout=timeout) + return weights - If the model is not registered and auto_register=True, it will be automatically - registered by creating a shared memory copy of the provided weights. The shared - buffer reference will be sent to workers via queue during the next init_on_sender call. + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. Args: - model_id: Identifier for the model whose weights to update. weights: New weights to send. Can be a TensorDictBase or dict. Raises: - KeyError: If model is not registered and auto_register=False. 
- ValueError: If weights type is unsupported for auto-registration. - """ - if model_id not in self._policy_weights: - if not self._auto_register: - raise KeyError( - f"Model '{model_id}' not registered in SharedMemTransport. " - f"Available models: {list(self._policy_weights.keys())}. " - f"Either register the model using register_weights() or enable auto_register." - ) - - # Auto-register on first send - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError( - f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. " - f"Supported types for auto-registration: TensorDictBase, dict. " - f"Please manually register shared weights using register_weights()." - ) - # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure) - # This is necessary for to_module() to work properly - weights_to_share = weights - # Check if keys are flattened by looking for dots in key names - if any("." in key for key in weights_to_share.keys()): - weights_to_share = weights_to_share.unflatten_keys(".") - shared_buffer = weights_to_share.share_memory_() - - self._policy_weights[model_id] = shared_buffer - - # Note: Buffer will be sent to workers during init_on_sender - # when the queue is available - - shared_weights = self._policy_weights[model_id] - + ValueError: If weights type is unsupported. + """ # Update shared memory in-place (workers see this automatically) if isinstance(weights, dict): weights = TensorDict(weights) @@ -308,7 +260,11 @@ def send_weights(self, model_id: str, weights: Any) -> None: weights_to_update = weights if any("." in key for key in weights.keys()): weights_to_update = weights.unflatten_keys(".") - shared_weights.data.update_(weights_to_update.data) + + for buffer in self._unique_weights: + buffer.update_(weights_to_update, non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """No-op for shared memory - weights are already visible.""" @@ -347,13 +303,8 @@ def __init__( self._remote_collector = remote_collector self._tensor_transport = tensor_transport - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via Ray. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via Ray.""" if self._remote_collector is None: return @@ -368,7 +319,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: ) self.ray.wait([future], num_returns=1) - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to remote collector without waiting for completion. Use wait_ack() to wait for completion after sending to all workers. @@ -397,6 +348,13 @@ def check_connection(self) -> bool: """Check if Ray is initialized.""" return self.ray.is_initialized() + def synchronize_weights_on_sender(self) -> None: + """No-op for RayTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayTransport - weights are received via remote method calls.""" + return None + class RayActorTransport: """Ray transport for communicating with Ray actors (not collectors). 
@@ -427,7 +385,7 @@ def set_actor(self, actor_ref): """Set the Ray actor reference to communicate with.""" self._actor_ref = actor_ref - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the Ray actor.""" if self._actor_ref is None: return @@ -447,7 +405,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: else: raise ValueError(f"Unknown update method: {self._update_method}") - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to Ray actor without waiting for completion. Use wait_ack() to wait for completion after sending to all actors. @@ -494,6 +452,13 @@ def check_connection(self) -> bool: return False return True + def synchronize_weights_on_sender(self) -> None: + """No-op for RayActorTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayActorTransport - weights are received via remote method calls.""" + return None + class RPCTransport: """RPC transport for communicating with a single RPC remote collector. @@ -508,13 +473,8 @@ def __init__(self, collector_info=None, collector_rref=None, collector_class=Non self._collector_rref = collector_rref self._collector_class = collector_class - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via RPC. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via RPC.""" if self._collector_info is None or self._collector_rref is None: return @@ -527,7 +487,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: args=(self._collector_rref, weights), ) - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to remote collector without waiting for completion. Use wait_ack() to wait for completion after sending to all workers. @@ -560,6 +520,13 @@ def check_connection(self) -> bool: return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + def synchronize_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via RPC calls.""" + return None + class DistributedTransport: """torch.distributed transport for communicating with a single distributed worker. @@ -582,13 +549,8 @@ def __init__(self, store=None, rank=None, sync=True): self._sync = sync self._weights_buffer = None # TensorDict buffer for receiving weights - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the distributed worker. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. 
- """ + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" if self._store is None or self._rank is None: return @@ -607,7 +569,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: raise RuntimeError(f"Expected 'updated' but got status {status}.") self._store.delete_key(f"NODE_{self._rank}_out") - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to distributed worker without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. @@ -700,6 +662,13 @@ def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" return torch.distributed.is_initialized() + def synchronize_weights_on_sender(self) -> None: + """No-op for DistributedTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for DistributedTransport - weights are received via receive_weights().""" + return None + # ============================================================================ # Weight Strategies @@ -790,7 +759,9 @@ def apply_weights(self, destination: Any, weights: Any) -> None: if any("." in key for key in weights.keys()): weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): - destination = TensorDict.from_module(destination) + # Do not update in-place + weights.to_module(destination) + return elif isinstance(destination, dict): destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): @@ -950,10 +921,10 @@ def send( # Send to all workers first (non-blocking if transport supports it) for transport in transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send - transport.send_weights(model_id, prepared_weights) + transport.send_weights(prepared_weights) # Wait for all acknowledgments for transport in transports: @@ -1000,7 +971,7 @@ def send_async( # Send to all workers (non-blocking) for transport in self._pending_transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + transport.send_weights_async(prepared_weights) else: raise RuntimeError( f"transport of type {type(transport)} does not support async send." @@ -1028,15 +999,30 @@ def wait_async(self) -> None: self._pending_async = False self._pending_transports = None - # Legacy method - kept for backward compatibility + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. + + This method is called once after workers are initialized to send + the initial weights. For most transports this is a no-op (weights + are sent via send()). For SharedMemTransport, this sends buffer + references via queues. + + This is different from send() which is called during training to + update weights. + """ + # Iterate over all transports and call synchronize_weights_on_sender + for transport in self._iterate_transports(): + if hasattr(transport, "synchronize_weights_on_sender"): + transport.synchronize_weights_on_sender() + def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model (legacy). + """Send weights to ALL workers for this model. Args: weights: Weights to send (can be None, nn.Module, TensorDict, etc.). 
Note: - This is the legacy method. Use send() instead. + Convenience method that calls send(weights=weights). """ self.send(weights=weights) @@ -1070,6 +1056,7 @@ def __init__(self, scheme: WeightSyncScheme): self._transport = None # lazy self._model_ref = None self._strategy = _get_strategy(scheme.strategy) + self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_worker() def _set_context(self, context: Any) -> None: """Set the context object (inner_collector) for resolving references (internal). @@ -1142,14 +1129,46 @@ def receive(self, timeout: float = 0.001) -> bool: return True + def synchronize_weights(self, worker_idx: int | None = None) -> None: + """Synchronize weights with sender before collection starts. + + This method is called once after the worker is initialized to receive + the initial weights. For most transports this is a no-op (weights are + received via receive()). For SharedMemTransport, this receives the + buffer reference via queue and applies it to the model. + + This is different from receive() which is called during collection + to check for weight updates. + + Args: + worker_idx: The worker index (required for SharedMemTransport). + If not provided, uses the worker_idx stored during init_on_worker(). + """ + if self._transport is None: + return + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = getattr(self, "_worker_idx", None) + + # Call transport's synchronize method if available + weights = self._transport.synchronize_weights_on_worker(worker_idx) + + # Apply weights to model if received (SharedMemTransport case) + if weights is not None and self._model_ref is not None: + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights) + else: + raise ValueError("Failed to synchronize weights") + def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model (legacy). + """Apply received weights to registered model. Args: weights: The weights to apply. Note: - This is the legacy method. Use receive() in the worker loop instead. + Convenience method. Normally weights are received and applied via receive() in the worker loop. """ if self._model_ref is None: raise ValueError("No model registered") @@ -1230,8 +1249,7 @@ def update_weights(self, weights: Any) -> None: self._initialize_transport() if self._single_transport is not None: - model_id = getattr(self, "_model_id", "policy") - self._single_transport.send_weights(model_id, weights) + self._single_transport.send_weights(weights) def _initialize_transport(self) -> None: """Lazily initialize the transport by resolving the actor reference.""" @@ -1397,7 +1415,6 @@ def __setstate__(self, state): """Restore the scheme from pickling.""" self.__dict__.update(state) - # Legacy methods - kept for backward compatibility @abc.abstractmethod def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create transport for communication. @@ -1407,22 +1424,31 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: Returns: A transport backend instance. + + Note: + This is used internally by init_on_sender/init_on_worker. """ ... def create_sender(self) -> WeightSender: - """Create a sender for this scheme (legacy). + """Create a sender for this scheme. Returns: WeightSender instance configured for this scheme. + + Note: + Typically you should use init_on_sender() followed by get_sender() instead. 
""" return WeightSender(self) def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme (legacy). + """Create a receiver for this scheme. Returns: WeightReceiver instance configured for this scheme. + + Note: + Typically you should use init_on_worker() followed by get_receiver() instead. """ return WeightReceiver(self) @@ -1562,141 +1588,148 @@ def init_on_worker( self._initialized_on_worker = True def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe (legacy).""" + """Create an MPTransport using the provided pipe. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ return MPTransport(pipe) class SharedMemWeightSyncScheme(WeightSyncScheme): """Weight synchronization using shared memory. - This scheme mimics the old WeightUpdater behavior by using shared memory - for in-place weight updates. Workers automatically see weight updates - without explicit message passing. - - By default, this scheme uses lazy registration: models are automatically - registered on the first weight send. This makes it seamless to use with - configuration systems like Hydra where schemes are created before models - are available. + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration (auto_register=True). strategy: The weight transmission strategy (default: "tensordict"). - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to False to require explicit registration via - register_shared_weights(). Example: - >>> # With auto-registration (default) - works with Hydra configs + >>> # Basic usage >>> scheme = SharedMemWeightSyncScheme() - >>> # Models are auto-registered on first weight send - - >>> # With explicit registration - >>> scheme = SharedMemWeightSyncScheme(auto_register=False) - >>> shared_weights = TensorDict.from_module(model).share_memory_() - >>> scheme.register_shared_weights("policy", shared_weights) + >>> # Weights are initialized via init_on_sender() """ def __init__( self, - policy_weights: dict[str, TensorDictBase] | None = None, strategy: str = "tensordict", - auto_register: bool = True, ): super().__init__(strategy) - self.policy_weights = policy_weights if policy_weights is not None else {} - self.auto_register = auto_register # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport( - self.policy_weights, auto_register=auto_register - ) + self._shared_transport = SharedMemTransport() # Create per-worker queues to avoid race conditions # Each worker gets its own queue for weight initialization self._weight_init_queues = {} # worker_idx -> Queue # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() - def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register shared memory weights for a model. - - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. - - Args: - model_id: Identifier for the model. - weights: Shared memory TensorDict containing the model's weights. 
- """ - # Don't set self.policy_weights[model_id] here - register_weights does that - # (self.policy_weights and transport._policy_weights are the same dict) - self._shared_transport.register_weights(model_id, weights) - def init_on_sender( self, - model_id: str, + model_id: str | None = None, context: Any = None, - **kwargs, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, ) -> None: """Initialize on the main process (sender side). - Creates per-worker queues and distributes any pre-registered weights. + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. Args: model_id: Identifier for the model being synchronized - context: Optional context object providing device_to_workers mapping, cached_weights - **kwargs: Alternative to context (device_to_workers, cached_weights, etc.) - """ - # Extract device_to_workers mapping from context - if context is not None: - # Build device_to_workers from policy_device list - if hasattr(context, "policy_device"): - device_to_workers = {} - for idx, device in enumerate(context.policy_device): - if device not in device_to_workers: - device_to_workers[device] = [] - device_to_workers[device].append(idx) - else: - device_to_workers = kwargs.get("device_to_workers", {}) - - # Try to get cached shared memory weights - if hasattr(context, "get_cached_weights"): - cached_weights = context.get_cached_weights(model_id) - else: - cached_weights = None - else: - device_to_workers = kwargs.get("device_to_workers", {}) - cached_weights = kwargs.get("cached_weights") - - if not device_to_workers: - raise ValueError( - "device_to_workers mapping must be provided via context or kwargs" - ) + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... 
) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) # Create per-worker queues if not already created # Collect all unique worker indices - all_workers = set() - for workers in device_to_workers.values(): - all_workers.update(workers) + all_workers = list(params_map.keys()) for worker_idx in all_workers: if worker_idx not in self._weight_init_queues: self._weight_init_queues[worker_idx] = mp.Queue() # Set worker info in transport - self._shared_transport.set_worker_info(device_to_workers) - self._shared_transport._weight_queues = self._weight_init_queues - - # If we have cached shared memory weights, pre-register them - if cached_weights is not None: - # Check if already registered to avoid re-registration error - if model_id not in self.policy_weights: - self.register_shared_weights(model_id, cached_weights) - - # Distribute any pre-registered weights to workers - if model_id in self.policy_weights: - if model_id not in self._shared_transport._registered_with_workers: - self._shared_transport._send_buffer_to_workers( - model_id, self.policy_weights[model_id] - ) + self._shared_transport.register_weights(params_map, self._weight_init_queues) # Create sender with the shared transport sender = WeightSender(self) @@ -1708,6 +1741,126 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). 
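+
+        Example:
+            >>> # Illustrative sketch; `collector` is a placeholder for any context passed to
+            >>> # init_on_sender() (e.g. a multiprocess collector), see init_on_sender() above.
+            >>> scheme.init_on_sender(model_id="policy", context=collector)
+            >>> # ... worker processes are started and each calls init_on_worker() ...
+            >>> scheme.synchronize_weights()  # workers pick up their shared-memory weight buffers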
+ """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + 
weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + def init_on_worker( self, model_id: str, @@ -1733,56 +1886,21 @@ def init_on_worker( if context is not None: if hasattr(context, "get_model"): model = context.get_model(model_id) - elif model is not None: - model = None + elif model is None: + model = _resolve_model(context, model_id) worker_idx = getattr(context, "worker_idx", worker_idx) - # Receive weights from this worker's dedicated queue if available - if self._weight_init_queues and worker_idx is not None: - # Each worker has its own queue - no race conditions! - if worker_idx in self._weight_init_queues: - worker_queue = self._weight_init_queues[worker_idx] - timeout = kwargs.get("timeout", 10.0) - try: - # Read from our dedicated queue - only messages for this worker are here - while True: - msg_model_id, shared_weights = worker_queue.get(timeout=timeout) - - # Register the shared weights in the transport - self._shared_transport._policy_weights[ - msg_model_id - ] = shared_weights - - # If this is the model we're initializing, apply weights - if msg_model_id == model_id and model is not None: - shared_weights.to_module(model) - self._shared_transport._registered_with_workers.add( - msg_model_id - ) - break - elif msg_model_id == model_id: - # Model will be applied later when it's available - self._shared_transport._registered_with_workers.add( - msg_model_id - ) - break - # If not the model we're looking for, still register it but keep looking - except Empty: - # No weights pre-registered for this model (will use auto-register or policy_factory) - pass - # Create receiver with the shared transport receiver = WeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) receiver._transport = self._shared_transport # Use shared transport - # Register the model - this will apply the shared weights to it - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) + # Register the model + receiver._register_model(model) + + # Store worker_idx for synchronize_weights + receiver._worker_idx = worker_idx self._receiver = receiver self._initialized_on_worker = True @@ -1809,13 +1927,13 @@ def get_message_queue(self): return self._message_queue def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport (legacy). + """Create shared memory transport. Returns the shared transport instance that all workers will use. Since this is shared memory, there's only one transport shared by all workers. - Note: This is a legacy method. The new init_on_sender/init_on_worker API - is the preferred way to set up the transport. + Note: + This is used internally by init_on_sender/init_on_worker. 
""" return self._shared_transport @@ -1903,10 +2021,14 @@ def init_on_worker( self._initialized_on_worker = True def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Returns None as no transport is needed (legacy).""" + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ # Return a dummy transport that does nothing class NoOpTransport: - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: pass def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: From 9d1cac25190d0ef3a0dd1fd86f8e67174e642436 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:41:01 +0000 Subject: [PATCH 04/17] use id(weight) --- torchrl/weight_update/weight_sync_schemes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index e9fc033294d..3dd817b0ef6 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -199,7 +199,7 @@ def register_weights( # Create set of the unique weights self._unique_weights = [] for weights in params_map.values(): - if weights in self._unique_weights: + if id(weights) in [id(w) for w in self._unique_weights]: continue self._unique_weights.append(weights) From fa5eff38a62700ebe24b88ad5443cf04c6c3dd79 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:43:44 +0000 Subject: [PATCH 05/17] clone the state_dict --- torchrl/collectors/_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 14ceb8f86d8..993ec4e6883 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -6,6 +6,7 @@ from multiprocessing import connection, queues from typing import Any +from torchrl.collectors.utils import _cast import numpy as np import torch from tensordict import TensorDictBase @@ -19,6 +20,7 @@ _TIMEOUT, DEFAULT_EXPLORATION_TYPE, ) +from tensordict import TensorDict from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool @@ -459,6 +461,8 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility # CPU and CUDA tensors are already shareable and don't need conversion state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + state_dict = TensorDict(state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict) pipe_child.send((state_dict, "state_dict")) has_timed_out = False continue From 6783e0e97e404faa8edfa0c9fca93aadee9c1892 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 09:34:09 +0000 Subject: [PATCH 06/17] address device mismatch --- test/test_collector.py | 8 ++++++-- torchrl/collectors/_multi_base.py | 2 +- torchrl/collectors/_runner.py | 9 +++------ torchrl/collectors/_single.py | 1 - torchrl/collectors/utils.py | 1 - torchrl/weight_update/weight_sync_schemes.py | 2 +- 6 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 8ce8a055091..865a234c849 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1900,7 +1900,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, it falls back to device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # same but more specific device = None @@ -1920,7 +1922,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, and env_device == policy_device, it falls back to env_device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # none has a device device = None diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 44efecc58ec..6fafbe40354 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -972,7 +972,7 @@ def _run_processes(self) -> None: # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: # Workers will call receiver.synchronize_weights() during init and may block waiting for data if self._weight_senders: - for model_id, sender in self._weight_senders.items(): + for sender in self._weight_senders.values(): sender.synchronize_weights() # Wait for workers to be ready diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 993ec4e6883..b92fcad8713 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -6,10 +6,9 @@ from multiprocessing import connection, queues from typing import Any -from torchrl.collectors.utils import _cast import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from torch import nn as nn from torchrl import logger as torchrl_logger @@ -20,10 +19,10 @@ _TIMEOUT, DEFAULT_EXPLORATION_TYPE, ) -from tensordict import TensorDict from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.base import DataCollectorBase -from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool + +from torchrl.collectors.utils import _cast, _map_to_cpu_if_needed, _TrajectoryPool from torchrl.data import ReplayBuffer from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType @@ -128,8 +127,6 @@ def _main_async_collector( no_cuda_sync=no_cuda_sync, weight_sync_schemes=weight_sync_schemes, ) - print("Inner collector created") - # Set 
up weight receivers for worker process # Note: For the "policy" model, initialization is done in _make_policy_factory # This section only handles additional models (not "policy") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 7a78cf41605..7beda2deb63 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -452,7 +452,6 @@ def _init_policy( if policy is None: if policy_factory is not None: policy = policy_factory() - print(f"Policy factory created: {policy}") else: policy = RandomPolicy(env.full_action_spec) elif policy_factory is not None: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 8492a52041e..799c0a5e692 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -286,7 +286,6 @@ def _make_meta_policy(policy: nn.Module): A context manager that temporarily replaces policy parameters with meta device versions. On exit, the original parameters are restored to the policy. """ - param_and_buf = TensorDict.from_module(policy, as_module=True) return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 3dd817b0ef6..265d344d401 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -7,7 +7,7 @@ import abc import weakref -from collections.abc import Iterator +from collections.abc import Callable, Iterator from typing import Any, Literal, Protocol import torch From 6d581f092be77140b489dcb419da7e973936db42 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 09:55:15 +0000 Subject: [PATCH 07/17] fix policy with device --- test/test_collector.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 865a234c849..ec704bf4773 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1828,8 +1828,14 @@ def forward(self, tensordict): class PolicyWithDevice(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] - # receives and sends data on gpu - default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def __init__(self, default_device=None): + super().__init__() + self.default_device = ( + default_device + if default_device is not None + else ("cuda:0" if torch.cuda.device_count() else "cpu") + ) def forward(self, tensordict): assert tensordict.device == _make_ordinal_device( @@ -1846,7 +1852,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = main_device env = self.DeviceLessEnv(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1887,7 +1893,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = None env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1909,7 +1915,7 @@ def test_output_device(self, main_device, storing_device): env_device = main_device policy_device = main_device env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, From e5f9a9c660c0fec88e23f3038cd0cf54670f7836 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:14:56 +0000 Subject: [PATCH 08/17] no TD 
state_dict --- torchrl/collectors/_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index b92fcad8713..e4448ba71d9 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -456,10 +456,10 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): state_dict = inner_collector.state_dict() # Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility - # CPU and CUDA tensors are already shareable and don't need conversion + # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth) state_dict = tree_map(_map_to_cpu_if_needed, state_dict) state_dict = TensorDict(state_dict) - state_dict = state_dict.clone().apply(_cast, state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict).to_dict() pipe_child.send((state_dict, "state_dict")) has_timed_out = False continue From 484eb9c0f9960b4097b3e5e2253c38c3e64623e8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:34:28 +0000 Subject: [PATCH 09/17] fix legacy code --- test/test_collector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py index ec704bf4773..5ad13c43d43 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1576,7 +1576,12 @@ def create_env(): # Create shared memory weight sync scheme weight_sync_scheme = SharedMemWeightSyncScheme() - weight_sync_scheme.register_shared_weights("policy", policy_weights) + # Use the new init_on_sender API with params_map + # All 3 workers share the same CPU weights in shared memory + weight_sync_scheme.init_on_sender( + model_id="policy", + params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, + ) collector_class = ( MultiSyncDataCollector if not use_async else MultiaSyncDataCollector From 92a4b1b1dfe7e61d0f017cc7727c89ad7aa87259 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:36:35 +0000 Subject: [PATCH 10/17] fix state dict device --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 5ad13c43d43..d367d9f3430 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1515,7 +1515,7 @@ def create_env(): ].keys() for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1533,7 +1533,7 @@ def create_env(): AssertionError ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1546,7 +1546,7 @@ def create_env(): for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) finally: From 8dff25a0bc384b048aa987d30a6e897ebe1f6fc6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:37:46 +0000 Subject: [PATCH 11/17] fix unwanted model_id --- test/test_collector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py 
index d367d9f3430..f53924784d9 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1579,7 +1579,6 @@ def create_env(): # Use the new init_on_sender API with params_map # All 3 workers share the same CPU weights in shared memory weight_sync_scheme.init_on_sender( - model_id="policy", params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, ) From 7ec411089931faf0ea8aebbeb2ee3506f69ca5df Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 18:35:03 +0000 Subject: [PATCH 12/17] final? --- .../reference/collectors_weightsync.rst | 6 + examples/collectors/multi_weight_updates.py | 2 +- test/test_collector.py | 21 +- test/test_weightsync.py | 6 +- torchrl/collectors/_multi_base.py | 49 +- torchrl/collectors/_runner.py | 9 +- torchrl/collectors/distributed/generic.py | 4 +- torchrl/collectors/distributed/ray.py | 2 +- torchrl/collectors/distributed/rpc.py | 2 +- torchrl/weight_update/__init__.py | 39 +- torchrl/weight_update/_distributed.py | 210 ++ torchrl/weight_update/_mp.py | 431 +++++ torchrl/weight_update/_noupdate.py | 76 + torchrl/weight_update/_ray.py | 543 ++++++ torchrl/weight_update/_rpc.py | 123 ++ torchrl/weight_update/_shared.py | 519 +++++ .../weight_update/llm/vllm_double_buffer.py | 5 +- torchrl/weight_update/llm/vllm_nccl.py | 6 +- torchrl/weight_update/utils.py | 43 + torchrl/weight_update/weight_sync_schemes.py | 1712 +---------------- 20 files changed, 2099 insertions(+), 1709 deletions(-) create mode 100644 torchrl/weight_update/_distributed.py create mode 100644 torchrl/weight_update/_mp.py create mode 100644 torchrl/weight_update/_noupdate.py create mode 100644 torchrl/weight_update/_ray.py create mode 100644 torchrl/weight_update/_rpc.py create mode 100644 torchrl/weight_update/_shared.py create mode 100644 torchrl/weight_update/utils.py diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index b6c2257e28f..6e73e2a91f6 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -198,6 +198,9 @@ Weight Senders :template: rl_template.rst WeightSender + MPWeightSender + RPCWeightSender + DistributedWeightSender RayModuleTransformSender Weight Receivers @@ -208,6 +211,9 @@ Weight Receivers :template: rl_template.rst WeightReceiver + MPWeightReceiver + RPCWeightReceiver + DistributedWeightReceiver RayModuleTransformReceiver Transports diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py index 7011e7f4879..6533eda3975 100644 --- a/examples/collectors/multi_weight_updates.py +++ b/examples/collectors/multi_weight_updates.py @@ -25,7 +25,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms.module import ModuleTransform -from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme +from torchrl.weight_update import MultiProcessWeightSyncScheme def make_module(): diff --git a/test/test_collector.py b/test/test_collector.py index f53924784d9..b0350ec025e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1558,8 +1558,6 @@ def create_env(): ) # MultiSync has known indexing issues with SharedMem def test_update_weights_shared_mem(self, use_async): """Test shared memory weight synchronization scheme.""" - from tensordict import TensorDict - from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme def create_env(): return 
ContinuousActionVecMockEnv() @@ -4117,16 +4115,17 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): frames_per_batch=16, **kwargs, ) - if not isinstance(collector, SyncDataCollector): - if weight_sync_scheme is not None: - assert isinstance( - collector._weight_sync_schemes["policy"], weight_sync_scheme - ) - else: - assert isinstance( - collector._weight_sync_schemes["policy"], SharedMemWeightSyncScheme - ) try: + if not isinstance(collector, SyncDataCollector): + if weight_sync_scheme is not None: + assert isinstance( + collector._weight_sync_schemes["policy"], weight_sync_scheme + ) + else: + assert isinstance( + collector._weight_sync_schemes["policy"], + SharedMemWeightSyncScheme, + ) collector.start() for _ in range(10): time.sleep(0.1) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 82992b14ca4..022055cd659 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -17,8 +17,7 @@ from tensordict.nn import TensorDictModule from torch import multiprocessing as mp from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, +from torchrl.weight_update import ( DistributedWeightSyncScheme, MPTransport, MultiProcessWeightSyncScheme, @@ -27,6 +26,9 @@ RayWeightSyncScheme, RPCWeightSyncScheme, SharedMemTransport, +) +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( SharedMemWeightSyncScheme, WeightStrategy, ) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 6fafbe40354..01633823242 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -334,7 +334,8 @@ def __init__( policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization - weight_sync_schemes = {} + if weight_sync_schemes is None: + weight_sync_schemes = {} if ( not any(policy_factory) and not weight_sync_schemes @@ -516,13 +517,13 @@ def _setup_multi_policy_and_weights( weight_sync_policy = weight_sync_schemes.get("policy") if weight_sync_policy is None: return - if weight_sync_policy._initialized_on_sender: - return if any(p is not None for p in policy_factory): - raise RuntimeError( - f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got {policy_factory=}" - ) - weight_sync_policy.init_on_sender(model=policy, devices=self.policy_device) + if not weight_sync_policy._initialized_on_sender: + raise RuntimeError( + f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" + ) + # Weight sync scheme initialization happens in _run_processes + # where pipes and workers are available else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( @@ -821,19 +822,20 @@ def _run_processes(self) -> None: torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] - self.pipes = [] self._traj_pool = _TrajectoryPool(lock=True) - # Initialize weight sync schemes early for SharedMemWeightSyncScheme - # (queue created in __init__ will be pickled with scheme to workers) - # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available + # Create all pipes upfront (needed for weight sync scheme initialization) + # Store as list of (parent, child) tuples for use in worker creation + pipe_pairs = [mp.Pipe() for _ in range(self.num_workers)] + # Extract parent pipes for external use (e.g., polling, receiving messages) + self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] + + # Initialize all weight sync schemes now that pipes are available + # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) + # can be initialized here since all required resources exist if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): - # Only initialize SharedMemWeightSyncScheme now (needs queue before workers) - # MultiProcessWeightSyncScheme will be initialized after workers are created - if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( - scheme, "init_on_sender" - ): + if hasattr(scheme, "init_on_sender"): scheme.init_on_sender(model_id=model_id, context=self) self._weight_senders[model_id] = scheme.get_sender() @@ -848,7 +850,7 @@ def _run_processes(self) -> None: for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs + pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( env_fun, EnvBase ): # to avoid circular imports @@ -966,7 +968,6 @@ def _run_processes(self) -> None: ) from err pipe_child.close() self.procs.append(proc) - self.pipes.append(pipe_parent) # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: @@ -1027,18 +1028,6 @@ def _run_processes(self) -> None: # Legacy string error message raise RuntimeError(msg) - # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available - # (SharedMemWeightSyncScheme was already initialized before workers) - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Only initialize non-SharedMem schemes here (need pipes) - if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( - scheme, "init_on_sender" - ): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - self.queue_out = queue_out self.closed = False diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index e4448ba71d9..091ab8c4c9d 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -30,7 +30,7 @@ def _make_policy_factory( - *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx + 
*, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None ): if policy is not None and policy_factory is not None: raise ValueError("policy cannot be used with policy_factory") @@ -40,7 +40,7 @@ def _make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side weight_sync_scheme.init_on_worker( - model=policy, model_id="policy", worker_idx=worker_idx + model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe ) # Get the receiver and synchronize initial weights receiver = weight_sync_scheme.get_receiver() @@ -92,8 +92,11 @@ def _main_async_collector( _make_policy_factory, policy=policy, policy_factory=policy_factory, - weight_sync_scheme=weight_sync_schemes.get("policy"), + weight_sync_scheme=weight_sync_schemes.get("policy") + if weight_sync_schemes + else None, worker_idx=worker_idx, + pipe=pipe_child, ) policy = None try: diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index ff15aa63d67..61180a3cb21 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -570,9 +570,7 @@ def __init__( # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors - from torchrl.weight_update.weight_sync_schemes import ( - DistributedWeightSyncScheme, - ) + from torchrl.weight_update import DistributedWeightSyncScheme weight_sync_schemes = { "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index a88e1aa7fcb..7547985e1ac 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -539,7 +539,7 @@ def check_list_length_consistency(*lists): # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Ray weight sync scheme for Ray collectors - from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme + from torchrl.weight_update import RayWeightSyncScheme weight_sync_schemes = {"policy": RayWeightSyncScheme()} diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index bdf28942e0f..dfbd8a7c5a2 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -417,7 +417,7 @@ def __init__( # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to RPC weight sync scheme for RPC collectors - from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + from torchrl.weight_update import RPCWeightSyncScheme weight_sync_schemes = {"policy": RPCWeightSyncScheme()} diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 556064a6113..6e2b66c9d51 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -3,22 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .weight_sync_schemes import ( +from ._distributed import ( DistributedTransport, + DistributedWeightReceiver, + DistributedWeightSender, DistributedWeightSyncScheme, +) +from ._mp import ( MPTransport, + MPWeightReceiver, + MPWeightSender, MultiProcessWeightSyncScheme, - NoWeightSyncScheme, +) +from ._noupdate import NoWeightSyncScheme +from ._ray import ( RayActorTransport, RayModuleTransformReceiver, RayModuleTransformScheme, RayModuleTransformSender, RayTransport, RayWeightSyncScheme, - RPCTransport, - RPCWeightSyncScheme, - SharedMemTransport, - SharedMemWeightSyncScheme, +) +from ._rpc import RPCTransport, RPCWeightReceiver, RPCWeightSender, RPCWeightSyncScheme +from ._shared import SharedMemTransport, SharedMemWeightSyncScheme +from .weight_sync_schemes import ( TransportBackend, WeightReceiver, WeightSender, @@ -27,19 +35,30 @@ ) __all__ = [ + # Base classes "TransportBackend", + "WeightStrategy", + "WeightSender", + "WeightReceiver", + "WeightSyncScheme", + # Transports "MPTransport", "SharedMemTransport", "RayTransport", "RayActorTransport", "RPCTransport", "DistributedTransport", - "WeightStrategy", - "WeightSender", - "WeightReceiver", + # Senders + "MPWeightSender", + "RPCWeightSender", + "DistributedWeightSender", "RayModuleTransformSender", + # Receivers + "MPWeightReceiver", + "RPCWeightReceiver", + "DistributedWeightReceiver", "RayModuleTransformReceiver", - "WeightSyncScheme", + # Schemes "MultiProcessWeightSyncScheme", "SharedMemWeightSyncScheme", "NoWeightSyncScheme", diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py new file mode 100644 index 00000000000..a742d922a12 --- /dev/null +++ b/torchrl/weight_update/_distributed.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from typing import Any + +import torch +from tensordict import TensorDict + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class DistributedWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed. + + This scheme uses torch.distributed primitives (send/recv) to synchronize + weights across distributed workers. Each worker gets its own transport, + following the same pattern as multiprocess collectors. + + Args: + backend (str): The distributed backend ("gloo", "nccl", etc.) + sync (bool): Whether to use synchronous weight updates + """ + + def __init__(self, backend: str = "gloo", sync: bool = True): + super().__init__() + self.backend = backend + self.sync = sync + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create distributed transport for a specific worker. + + Args: + pipe_or_context: A tuple of (store, rank) for the worker. + + Returns: + DistributedTransport configured for this specific worker. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: + store, rank = pipe_or_context + return DistributedTransport(store=store, rank=rank, sync=self.sync) + # Fallback - shouldn't normally happen + return DistributedTransport() + + +class DistributedTransport: + """torch.distributed transport for communicating with a single distributed worker. + + This transport handles weight updates for ONE specific distributed worker via + torch.distributed send/recv. Multiple transports are created for multiple workers, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, store=None, rank=None, sync=True): + """Initialize the DistributedTransport. 
+ + Args: + store: TCPStore for communication. + rank: Worker rank (1-indexed). + sync: Whether to use synchronous weight updates. + """ + self._store = store + self._rank = rank + self._sync = sync + self._weights_buffer = None # TensorDict buffer for receiving weights + + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + # Wait for acknowledgment + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to distributed worker without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def wait_ack(self) -> None: + """Wait for acknowledgment from distributed worker.""" + if self._store is None or self._rank is None: + return + + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights via torch.distributed, using TCPStore for signaling. + + This implements the RPC-like pattern: + 1. Check TCPStore for signal (non-blocking) + 2. If signal present, receive weights via torch.distributed + 3. Clean up signal and send acknowledgment + + Args: + timeout: Timeout for receiving (currently not used for TCPStore check) + + Returns: + Tuple of (model_id, weights) if weights were received, None otherwise. + """ + if self._store is None or self._rank is None: + return None + + try: + # Non-blocking check of TCPStore "mailbox" for signal + msg = self._store.get(f"NODE_{self._rank}_in") + + if msg == b"update_weights": + # Initialize weights buffer on first use + if self._weights_buffer is None: + self._weights_buffer = TensorDict() + + # Receive weights via torch.distributed + # recv() and irecv() update the TensorDict in place + if self._sync: + self._weights_buffer.recv(src=0) + else: + # irecv() blocks until weights are received + self._weights_buffer.irecv(src=0) + + # Clean up the signal + self._store.delete_key(f"NODE_{self._rank}_in") + + # Note: Acknowledgment is sent separately via send_ack() if transport supports it + # This matches the pattern in WeightReceiver.receive() + + # Return model_id and received weights + # For distributed transport, we use "policy" as default model_id + return ("policy", self._weights_buffer) + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") + except KeyError: + # No message in store - no weights available + return None + + return None + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender via TCPStore. 
+ + Args: + message: Acknowledgment message to send (default: "updated") + """ + if self._store is None or self._rank is None: + return + + self._store.set(f"NODE_{self._rank}_out", message.encode()) + + def check_connection(self) -> bool: + """Check if torch.distributed is initialized.""" + return torch.distributed.is_initialized() + + def synchronize_weights_on_sender(self) -> None: + """No-op for DistributedTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for DistributedTransport - weights are received via receive_weights().""" + return None + + +class DistributedWeightReceiver(WeightReceiver): + """Weight receiver for torch.distributed systems. + + Receives weight updates from the main process via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + +class DistributedWeightSender(WeightSender): + """Weight sender for torch.distributed systems. + + Sends weight updates to distributed workers via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py new file mode 100644 index 00000000000..12d9c7be3fb --- /dev/null +++ b/torchrl/weight_update/_mp.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +import weakref +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class MultiProcessWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for multiprocess operations using pipes. + + This scheme creates transports that communicate via multiprocessing pipes. + Similar to SharedMemWeightSyncScheme which uses queues for shared memory + buffer distribution, MultiProcessWeightSyncScheme uses pipes to send + weight copies to each worker. + + Synchronization flow: + - init_on_sender() creates a MPWeightSender and registers all worker pipes + - synchronize_weights() triggers the initial weight distribution via pipes + - init_on_worker() creates a MPWeightReceiver that receives from its pipe + - Subsequent updates use send() which extracts, sends, and waits for ACKs + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + + Example: + >>> # Basic usage with collector + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.synchronize_weights() is called automatically by collector + """ + + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). + The sender will extract weights from the context and send them to all workers via pipes. 
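+
+        Example:
+            >>> # Illustrative sketch; `collector` is a placeholder for any context exposing
+            >>> # `pipes`, `num_workers` and `policy` (e.g. a multiprocess collector).
+            >>> scheme.init_on_sender(model_id="policy", context=collector)
+            >>> # ... worker processes are started and each calls init_on_worker() ...
+            >>> scheme.synchronize_weights()  # extracts the policy weights and pushes them through every pipe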
+ """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on MultiProcessWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipes and num_workers + **kwargs: Alternative to context (pipes, num_workers, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipes = getattr(context, "pipes", None) + num_workers = getattr(context, "num_workers", None) + else: + pipes = kwargs.get("pipes") + num_workers = kwargs.get("num_workers") + + if pipes is None: + raise ValueError("pipes must be provided via context or kwargs") + if num_workers is None: + num_workers = len(pipes) if pipes else 0 + + # Create sender and register all workers + sender = MPWeightSender(self) + sender._model_id = model_id + if context is not None: + sender._context_ref = weakref.ref(context) + + for worker_idx, pipe in enumerate(pipes): + sender._register_worker(worker_idx, pipe) + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipe and model + **kwargs: Alternative to context (pipe, model, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipe = getattr(context, "pipe", None) + if hasattr(context, "get_model"): + model = context.get_model(model_id) + else: + model = None + else: + pipe = kwargs.get("pipe") + model = kwargs.get("model") + + if pipe is None: + raise ValueError("pipe must be provided via context or kwargs") + + # Create receiver and register model + receiver = MPWeightReceiver(self) + if context is not None: + receiver._context_ref = weakref.ref(context) + receiver._register_worker_transport(pipe) + if model is not None: + receiver._register_model(model) + else: + # Register by model_id for later resolution + receiver._register_model(model_id) + + self._receiver = receiver + self._initialized_on_worker = True + + def create_transport(self, pipe: Any) -> TransportBackend: + """Create an MPTransport using the provided pipe. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + return MPTransport(pipe) + + +class MPTransport: + """Multiprocessing transport using pipes. + + This transport uses pipes for weight distribution and synchronization. + Similar to SharedMemTransport's queue-based approach, MPTransport uses + pipes to send initial weights to workers during synchronization. + + Initialization flow: + - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via pipes + - Workers receive the initial weights via synchronize_weights_on_worker() + - Subsequent updates use send_weights_async() followed by acknowledgments + + Args: + pipe_connection (mp.Pipe): The pipe connection to use for communication. + timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. 
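+
+    Example:
+        >>> # Illustrative sketch: pairing two transports over a multiprocessing Pipe.
+        >>> # `weights` is a placeholder for the TensorDict/state-dict produced by the sender.
+        >>> from torch import multiprocessing as mp
+        >>> parent_end, child_end = mp.Pipe()
+        >>> sender_side, worker_side = MPTransport(parent_end), MPTransport(child_end)
+        >>> sender_side.send_weights_async(weights, model_id="policy")
+        >>> model_id, received = worker_side.receive_weights(timeout=1.0)
+        >>> worker_side.send_ack()   # unblocks the sender
+        >>> sender_side.wait_ack()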
+ """ + + def __init__(self, pipe_connection, timeout: float = 10.0): + self.timeout = timeout + self.pipe = pipe_connection + + def send_weights(self, weights: Any) -> None: + """Send weights through the pipe. + + Sends weights and waits for acknowledgment to ensure delivery. + """ + self.send_weights_async(weights) + self.wait_ack() + + def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: + """Send weights through the pipe without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + # Send in format expected by worker loop: ((model_id, weights), "update_weights") + self.pipe.send(((model_id, weights), "update_weights")) + + def wait_ack(self) -> None: + """Wait for acknowledgment from worker.""" + self.check_ack("updated") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from the pipe (used in worker process). + + This method only handles weight update messages. Other messages + (like "close", "continue", etc.) are ignored and should be handled + by the main worker loop. + + Returns: + Tuple of (model_id, weights) if weights were received, None if no data available + or if a non-weight message was received. + + Note: + model_id is returned as "policy" for backward compatibility, but transports + are now bound to a single model during initialization. + """ + if self.pipe.poll(timeout): + data_in, msg = self.pipe.recv() + if msg == "update_weights": + # data_in is now (model_id, weights) + return data_in + else: + # Not a weight update message - put it back and return None + # This allows the main worker loop to handle other messages + # Note: We can't actually "put it back", so we'll just return None + # and the message is lost. This is why receive() should only be called + # when we're expecting weight updates, not in the main message loop. + return None + # No data available - return None instead of raising TimeoutError + # This allows non-blocking checks in the worker loop + return None + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender.""" + self.pipe.send((None, message)) + + def check_ack(self, message: str = "updated") -> None: + """Check for acknowledgment.""" + _, msg = self.pipe.recv() + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + + def check_connection(self) -> bool: + return not self.pipe.closed + + def synchronize_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via MPWeightSender.synchronize_weights(). + + The actual sending happens in MPWeightSender.synchronize_weights(), which: + 1. Extracts weights from the context (e.g., collector.policy) + 2. Calls send_weights_async() on all worker transports + 3. Sends initial weights through pipes to all workers + + This is similar to SharedMemTransport.synchronize_weights_on_sender() which + sends shared memory buffer references via queues. + """ + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """Receive initial weights from sender during worker initialization. + + This method blocks waiting for the initial weights to be sent from the main process + via pipe. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + shared memory buffer references via queues, this receives the actual weights via pipes. + + The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). 
+ + Args: + worker_idx: The worker index (used for logging/debugging). + + Returns: + The received weights if available, None otherwise (weights will come later via receive()). + """ + # Wait for initial weights (blocking) + if self.pipe.poll(timeout=self.timeout): + data_in, msg = self.pipe.recv() + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, weights = data_in + return weights + # If we don't receive weights, return None (weights will come later) + return None + + +class MPWeightReceiver(WeightReceiver): + """Weight receiver for multiprocess systems using pipes. + + Receives weight updates from the main process via multiprocessing pipes. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + + +class MPWeightSender(WeightSender): + """Weight sender for multiprocess systems using pipes. + + Sends weight updates to worker processes via multiprocessing pipes. + Supports both synchronous and asynchronous sending patterns. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + _model_id: str + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self._model_id + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. 
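+
+        Example:
+            >>> # Illustrative sketch of the async protocol; `sender` comes from scheme.get_sender()
+            >>> # and `train_step()` is a hypothetical stand-in for work overlapped with the transfer.
+            >>> sender.send_async()   # extract current weights, push to all workers, return immediately
+            >>> train_step()
+            >>> sender.wait_async()   # block until every worker has acknowledged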
+ + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. + + Extracts weights from the collector's policy and sends them to all workers + via pipes. This is called once after workers are initialized but before they + start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + Raises: + RuntimeError: If no context is available or context has no policy. + """ + # Get context (collector) + context = self._context_ref() if self._context_ref is not None else None + if context is None or not hasattr(context, "policy"): + raise RuntimeError( + "MPWeightSender requires context with policy for synchronize_weights()" + ) + + # Extract and prepare weights from the policy + prepared_weights = self._scheme.prepare_weights( + weights=context.policy, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Send to all workers via pipes (no ACK - workers are still initializing) + for transport in self._iterate_transports(): + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) # type: ignore[attr-defined] + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py new file mode 100644 index 00000000000..697f56943e8 --- /dev/null +++ b/torchrl/weight_update/_noupdate.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class NoWeightSyncScheme(WeightSyncScheme): + """No-op weight synchronization scheme. + + This scheme disables weight synchronization entirely. + """ + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Create a no-op sender + sender = WeightSender(self) + sender._model_id = model_id + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). 
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Create a no-op receiver + receiver = WeightReceiver(self) + receiver._model_ref = model_id + + self._receiver = receiver + self._initialized_on_worker = True + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + # Return a dummy transport that does nothing + class NoOpTransport: + def send_weights(self, weights: Any) -> None: + pass + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + return None + + def check_connection(self) -> bool: + return True + + return NoOpTransport() diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py new file mode 100644 index 00000000000..3fb4e571224 --- /dev/null +++ b/torchrl/weight_update/_ray.py @@ -0,0 +1,543 @@ +from __future__ import annotations + +import weakref +from typing import Any, Literal + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class RayWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for Ray distributed computing. + + This scheme uses Ray's object store and remote calls to synchronize weights + across distributed workers (Ray actors). + + Each remote collector gets its own transport, following the same pattern + as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create Ray-based transport for a specific remote collector. + + Args: + pipe_or_context: The Ray actor handle for the remote collector. + + Returns: + RayTransport configured for this specific remote collector. + """ + return RayTransport(remote_collector=pipe_or_context) + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing remote_collectors + **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
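+
+        Example (illustrative sketch; ``actor_handles`` stands in for a list of
+        Ray collector actor handles that expose ``update_policy_weights_``):
+
+            >>> scheme = RayWeightSyncScheme(strategy="tensordict")
+            >>> scheme.init_on_sender(model_id="policy", remote_collectors=actor_handles)
+            >>> # the scheme now holds a sender with one RayTransport per actor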
+ """ + # Extract parameters from context or kwargs + if context is not None: + remote_collectors = getattr(context, "remote_collectors", None) + num_workers = getattr(context, "num_workers", None) or getattr( + context, "num_collectors", None + ) + else: + remote_collectors = kwargs.get("remote_collectors") + num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") + + if remote_collectors is None: + raise ValueError("remote_collectors must be provided via context or kwargs") + if num_workers is None: + num_workers = len(remote_collectors) if remote_collectors else 0 + + # Create sender and register all workers (Ray actors) + sender = WeightSender(self) + sender._model_id = model_id + + # Register each Ray actor - _register_worker will create the transport + for worker_idx, remote_collector in enumerate(remote_collectors): + sender._register_worker(worker_idx, remote_collector) + + # Set context with weak reference to avoid circular refs + if context is not None: + sender._set_context(weakref.ref(context), model_id) + + # Store source model reference if provided for automatic weight extraction + source_model = kwargs.get("source_model") + if source_model is not None: + sender._source_model = source_model + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + For Ray workers, weight updates are handled via remote method calls, + so this is typically a no-op. The receiver is created but doesn't + need special initialization. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the remote collector) + **kwargs: Optional parameters (pipe, model, etc.) + """ + # Create receiver + receiver = WeightReceiver(self) + + # Register model if provided + model = kwargs.get("model") or ( + getattr(context, "policy", None) if context else None + ) + if model is not None: + receiver._register_model(model) + + # Set context if provided + if context is not None: + receiver._set_context(weakref.ref(context)) + + self._receiver = receiver + self._initialized_on_worker = True + + +class RayModuleTransformScheme(WeightSyncScheme): + """Weight synchronization for RayModuleTransform actors. + + This scheme is designed specifically for updating models hosted within + Ray actors, such as RayModuleTransform instances. It creates a transport + that directly calls the actor's weight update methods. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + """ + + def __init__(self, strategy: str = "tensordict"): + super().__init__(strategy) + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RayActorTransport for the given actor. + + Args: + pipe_or_context: Either a Ray actor reference or a context object + from which to extract the actor reference. + + Returns: + RayActorTransport configured with the actor reference. + """ + actor_ref = self._extract_actor_ref(pipe_or_context) + return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + + def _extract_actor_ref(self, pipe_or_context: Any) -> Any: + """Extract the Ray actor reference from the context. + + Args: + pipe_or_context: Either a direct actor reference or an object + with an `_actor` attribute. + + Returns: + The Ray actor reference. 
+ """ + if hasattr(pipe_or_context, "_actor"): + return pipe_or_context._actor + return pipe_or_context + + def create_sender(self) -> RayModuleTransformSender: + """Create a specialized sender for Ray actor communication.""" + return RayModuleTransformSender(self) + + def create_receiver(self) -> RayModuleTransformReceiver: + """Create a specialized receiver for Ray actor communication.""" + return RayModuleTransformReceiver(self) + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing actor references + **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) + """ + # Extract actor references from context or kwargs + if context is not None: + # Could be actor_refs, actors, or remote_collectors + actor_refs = ( + getattr(context, "actor_refs", None) + or getattr(context, "actors", None) + or getattr(context, "remote_collectors", None) + ) + else: + actor_refs = ( + kwargs.get("actor_refs") + or kwargs.get("actors") + or kwargs.get("remote_collectors") + ) + + if actor_refs is None: + raise ValueError( + "actor_refs (or actors) must be provided via context or kwargs" + ) + + # Create specialized sender + sender = self.create_sender() + sender._model_id = model_id + + # Register all actors - _register_worker will create the transport + for worker_idx, actor_ref in enumerate(actor_refs): + sender._register_worker(worker_idx, actor_ref) + + # Set context with weak reference + if context is not None: + sender._set_context(weakref.ref(context), model_id) + + # Store source model if provided + source_model = kwargs.get("source_model") + if source_model is not None: + sender._source_model = source_model + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the actor itself) + **kwargs: Optional parameters (actor_ref, model, etc.) + """ + # Create specialized receiver + receiver = self.create_receiver() + + # Extract actor reference if needed + actor_ref = kwargs.get("actor_ref") or context + if actor_ref is not None: + # Register the transport for this actor + transport = self.create_transport(actor_ref) + receiver._register_worker_transport(transport) + + # Register model if provided + model = kwargs.get("model") or ( + getattr(context, "_actor_module", None) or getattr(context, "module", None) + if context + else None + ) + if model is not None: + receiver._register_model(model) + + # Set context if provided + if context is not None: + receiver._set_context(weakref.ref(context)) + + self._receiver = receiver + self._initialized_on_worker = True + + +class RayTransport: + """Ray transport for communicating with a single Ray collector actor. + + This transport handles weight updates for ONE specific remote collector. + Multiple transports are created for multiple collectors, following the + same pattern as multiprocess collectors. 
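+
+    For illustration, one synchronous update amounts to the following call
+    sequence (sketch of the internals, not a public API guarantee):
+
+        >>> transport = RayTransport(remote_collector=actor_handle)
+        >>> transport.send_weights(weights)
+        >>> # internally: ray.put(weights), then update_policy_weights_.remote(...), then ray.wait(...)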
+ """ + + def __init__( + self, + remote_collector=None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_collector = remote_collector + self._tensor_transport = tensor_transport + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via Ray.""" + if self._remote_collector is None: + return + + # Put weights in Ray's object store for efficient distribution + # Ray will automatically deduplicate if the same weights are sent to multiple actors + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + # Send to the remote collector and wait for completion + # This ensures weights are applied before we continue + future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + self.ray.wait([future], num_returns=1) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._remote_collector is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + self._pending_future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + + def wait_ack(self) -> None: + """Wait for the remote collector to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.wait([self._pending_future], num_returns=1) + del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if Ray is initialized.""" + return self.ray.is_initialized() + + def synchronize_weights_on_sender(self) -> None: + """No-op for RayTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayTransport - weights are received via remote method calls.""" + return None + + +class RayActorTransport: + """Ray transport for communicating with Ray actors (not collectors). + + This transport is designed for updating models hosted within Ray actors, + such as RayModuleTransform instances. It directly calls the actor's + update_weights method rather than going through collector update methods. 
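+
+    For illustration (sketch; ``module_actor`` is a placeholder for a
+    RayModuleTransform-style actor handle):
+
+        >>> transport = RayActorTransport(actor_ref=module_actor, update_method="tensordict")
+        >>> transport.send_weights(weights_td)
+        >>> # internally: ray.get(module_actor._update_weights_tensordict.remote(params=ray.put(weights_td)))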
+ """ + + def __init__( + self, + actor_ref=None, + update_method: str = "tensordict", + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayActorTransport") + + self._actor_ref = actor_ref + self._update_method = update_method + self._tensor_transport = tensor_transport + + def set_actor(self, actor_ref): + """Set the Ray actor reference to communicate with.""" + self._actor_ref = actor_ref + + def send_weights(self, weights: Any) -> None: + """Send weights to the Ray actor.""" + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + if self._update_method == "tensordict": + self.ray.get( + self._actor_ref._update_weights_tensordict.remote(params=weights_ref) + ) + elif self._update_method == "state_dict": + self.ray.get( + self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use wait_ack() to wait for completion after sending to all actors. + """ + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + if self._update_method == "tensordict": + self._pending_future = self._actor_ref._update_weights_tensordict.remote( + params=weights_ref + ) + elif self._update_method == "state_dict": + self._pending_future = self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.get(self._pending_future) + del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray actor workers receive weights through direct method calls.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_connection(self) -> bool: + """Check if Ray is initialized and actor exists.""" + if not self.ray.is_initialized(): + return False + if self._actor_ref is None: + return False + return True + + def synchronize_weights_on_sender(self) -> None: + """No-op for RayActorTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayActorTransport - weights are received via remote method calls.""" + return None + + +class RayModuleTransformReceiver(WeightReceiver): + """Specialized receiver for RayModuleTransform actors. + + This receiver handles weight updates within Ray actors. + Since Ray actors receive weights through direct method calls, + this receiver primarily validates and applies weights locally. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + + def _register_worker_transport(self, actor_or_context: Any) -> None: + """Register the Ray actor's transport (internal). + + This is now handled by init_on_worker(). Only kept for internal use. 
+ + Args: + actor_or_context: Either a Ray actor reference or a context object. + """ + self._transport = self._scheme.create_transport(actor_or_context) + + def apply_weights(self, weights: Any, inplace: bool = True) -> None: + """Apply received weights to registered model. + + For Ray actors, weights are applied directly to the module + within the actor's process space. + + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=inplace) + + +class RayModuleTransformSender(WeightSender): + """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. + + This sender handles weight updates for models hosted within Ray actors. + Unlike the base WeightSender which uses pipes for multiprocessing, + this sender directly communicates with Ray actors via their remote methods. + + For Ray actors, there is typically only one shared actor instance, so we + store a single transport rather than per-worker transports. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + self._actor_ref = None + self._single_transport = None + self._context_ref = None + self._model_id_str = None + + def _set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution (internal). + + This is now handled by init_on_sender(). Only kept for internal use. + + Args: + context: The collector instance. + model_id: String path to the Ray actor (e.g., "env.transform[0]"). + """ + self._context_ref = weakref.ref(context) + self._model_id_str = model_id + + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op (internal). + + Ray actors are shared across all workers, so we don't need per-worker + transports. The actor reference is resolved lazily on first use. + """ + + def update_weights(self, weights: Any) -> None: + """Send weights to the Ray actor. + + Args: + weights: Weights to send. + """ + if self._single_transport is None: + self._initialize_transport() + + if self._single_transport is not None: + self._single_transport.send_weights(weights) + + def _initialize_transport(self) -> None: + """Lazily initialize the transport by resolving the actor reference.""" + if self._context_ref is None or self._model_id_str is None: + return + + context = self._context_ref() + if context is None: + return + + model = _resolve_model(context, self._model_id_str) + if hasattr(model, "_actor"): + self._actor_ref = model._actor + self._single_transport = self._scheme.create_transport(model) + elif type(model).__name__ == "ActorHandle": + self._actor_ref = model + self._single_transport = self._scheme.create_transport(model) diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py new file mode 100644 index 00000000000..9290b23aa05 --- /dev/null +++ b/torchrl/weight_update/_rpc.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class RPCWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed.rpc. + + This scheme uses RPC calls to synchronize weights across distributed + workers. 
Each remote collector gets its own transport, following the + same pattern as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RPC-based transport for a specific remote collector. + + Args: + pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) + for the remote collector. + + Returns: + RPCTransport configured for this specific remote collector. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: + collector_info, collector_rref, collector_class = pipe_or_context + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + ) + # If just passed the info directly + return RPCTransport(collector_info=pipe_or_context) + + +class RPCTransport: + """RPC transport for communicating with a single RPC remote collector. + + This transport handles weight updates for ONE specific remote collector via + torch.distributed.rpc. Multiple transports are created for multiple collectors, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, collector_info=None, collector_rref=None, collector_class=None): + self._collector_info = collector_info + self._collector_rref = collector_rref + self._collector_class = collector_class + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via RPC.""" + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights to the remote collector and wait for completion + rpc.rpc_sync( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights asynchronously + self._pending_future = rpc.rpc_async( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def wait_ack(self) -> None: + """Wait for the RPC call to complete.""" + if hasattr(self, "_pending_future"): + self._pending_future.wait() + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """RPC workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if RPC is initialized.""" + from torch.distributed import rpc + + return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + + def synchronize_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via RPC calls.""" + return None + + +class RPCWeightReceiver(WeightReceiver): + """Weight receiver for RPC-based distributed systems. + + Receives weight updates from the main process via torch.distributed.rpc. + This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + +class RPCWeightSender(WeightSender): + """Weight sender for RPC-based distributed systems. + + Sends weight updates to remote collectors via torch.distributed.rpc calls. 
+ This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py new file mode 100644 index 00000000000..098c4fe6e49 --- /dev/null +++ b/torchrl/weight_update/_shared.py @@ -0,0 +1,519 @@ +from __future__ import annotations + +import abc + +import weakref +from collections.abc import Callable, Iterator +from typing import Any, Literal, Protocol + +import torch +import torch.distributed + +from tensordict import TensorDict, TensorDictBase + +from torch import multiprocessing as mp, nn + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class SharedMemTransport: + """Shared memory transport for in-place weight updates. + + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. + Workers automatically see weight updates without explicit communication. + + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models + - Subsequent updates are pure in-place shared memory (zero-copy) + + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. + + """ + + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if id(weights) in [id(w) for w in self._unique_weights]: + continue + self._unique_weights.append(weights) + + def synchronize_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. + + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. + + """ + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) + + def synchronize_weights_on_worker( + self, worker_idx: int, timeout: float = 10.0 + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. + + Each worker reads from its own dedicated queue, to avoid race conditions. + + Args: + worker_idx: The worker index. + timeout: Timeout for reading from queue. + + Returns: + The shared memory weights TensorDict. + """ + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") + + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + weights = worker_queue.get(timeout=timeout) + return weights + + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. 
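+
+        For illustration, both of the following are equivalent (sketch; ``w`` is a
+        tensor matching the shared buffer entry at ``("net", "weight")``):
+
+        >>> transport.send_weights(TensorDict({"net": {"weight": w}}))
+        >>> transport.send_weights({"net.weight": w})  # flat keys are unflattened first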
+ + Args: + weights: New weights to send. Can be a TensorDictBase or dict. + + Raises: + ValueError: If weights type is unsupported. + """ + # Update shared memory in-place (workers see this automatically) + if isinstance(weights, dict): + weights = TensorDict(weights) + if not isinstance(weights, TensorDictBase): + raise ValueError(f"Unsupported weights type: {type(weights)}") + # Unflatten if needed to match shared buffer structure + weights_to_update = weights + if any("." in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + + for buffer in self._unique_weights: + buffer.update_(weights_to_update, non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """No-op for shared memory - weights are already visible.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_connection(self) -> bool: + """Shared memory is always 'connected'.""" + return True + + +class SharedMemWeightSyncScheme(WeightSyncScheme): + """Weight synchronization using shared memory. + + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + + Example: + >>> # Basic usage + >>> scheme = SharedMemWeightSyncScheme() + >>> # Weights are initialized via init_on_sender() + """ + + def __init__( + self, + strategy: str = "tensordict", + ): + super().__init__(strategy) + # Create a single shared transport for all workers + self._shared_transport = SharedMemTransport() + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + # General message queue for coordination (if needed in future) + self._message_queue = mp.Queue() + + def init_on_sender( + self, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ) -> None: + """Initialize on the main process (sender side). + + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. 
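+
+        For example (sketch), each worker's copy can be placed on its own GPU with
+        ``device_map_fn=lambda idx, weights: weights.to(f"cuda:{idx}")`` together
+        with ``num_workers=2``.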
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. 
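+        # Illustrative outcome (example values only): with
+        # devices=[cuda:0, cuda:0, cpu], the resulting map is
+        # {0: weights_cuda0, 1: weights_cuda0, 2: weights_cpu}, i.e. workers that
+        # share a device share the very same TensorDict instance.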
+ params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = list(params_map.keys()) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Set worker info in transport + self._shared_transport.register_weights(params_map, self._weight_init_queues) + + # Create sender with the shared transport + sender = SharedMemWeightSender(self) + sender._model_id = model_id + sender._transport = self._shared_transport # Use shared transport + if context is not None: + sender._context_ref = weakref.ref(context) + + self._sender = sender + self._initialized_on_sender = True + + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). + """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts 
will be customized later
+            devices = context.policy_device
+            if num_workers is not None and num_workers != len(devices):
+                raise ValueError(
+                    f"num_workers ({num_workers}) does not match the number of "
+                    f"devices inferred from the context ({len(devices)})"
+                )
+            # Get the weights
+            model = _resolve_model(context, model_id)
+            weights = TensorDict.from_module(model)
+        elif model is not None:
+            if weights is not None:
+                raise ValueError("weights cannot be provided if model is provided")
+            weights = TensorDict.from_module(model)
+        # To make the map, we need the list of devices, or the map fn
+        if devices is not None:
+            # Import _cast locally to avoid circular imports
+            from torchrl.collectors.utils import _cast
+
+            # Get the unique devices
+            devices_set = set(devices)
+            weights_devices = {p.device for p in weights.values(True, True)}
+            if len(weights_devices) == 1:
+                weights_device = weights_devices.pop()
+            else:
+                weights_device = None
+
+            # Create device map with proper Parameter handling using _cast
+            # _cast ensures Parameters stay as Parameters (with requires_grad=False)
+            device_map = {}
+            for d in devices_set:
+                if d != weights_device:
+                    # Move to device and apply _cast to preserve Parameter/Buffer types
+                    weights_on_device = weights.to(d)
+                    weights_on_device = weights_on_device.apply(_cast, weights)
+                    device_map[d] = weights_on_device
+                else:
+                    # Already on correct device, just apply _cast
+                    device_map[d] = weights.apply(_cast, weights)
+
+            # Create the map
+            params_map = {
+                worker_idx: device_map[device]
+                for worker_idx, device in enumerate(devices)
+            }
+            return params_map
+        if device_map_fn is not None:
+            return {
+                worker_idx: device_map_fn(worker_idx, weights)
+                for worker_idx in range(num_workers)
+            }
+        raise ValueError(
+            "Either params_map, model_id + context, or model/weights + devices "
+            "(or device_map_fn together with num_workers) must be provided."
+        )
+
+    def init_on_worker(
+        self,
+        model_id: str,
+        context: Any = None,
+        model: Any = None,
+        worker_idx: int | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        Reads from the worker's dedicated queue to receive shared weights,
+        then registers them in the transport. The receiver then applies these weights
+        to the model.
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object providing model and worker_idx
+            model: Model being synchronized
+            worker_idx: Worker index
+            **kwargs: Alternative to context (model, worker_idx, timeout, etc.)
+        """
+        # Extract parameters from context or kwargs
+        if context is not None:
+            if hasattr(context, "get_model"):
+                model = context.get_model(model_id)
+            elif model is None:
+                model = _resolve_model(context, model_id)
+            worker_idx = getattr(context, "worker_idx", worker_idx)
+
+        # Create receiver with the shared transport
+        receiver = SharedMemWeightReceiver(self)
+        if context is not None:
+            receiver._context_ref = weakref.ref(context)
+        receiver._transport = self._shared_transport  # Use shared transport
+
+        # Register the model
+        receiver._register_model(model)
+
+        # Store worker_idx for synchronize_weights
+        receiver._worker_idx = worker_idx
+
+        self._receiver = receiver
+        self._initialized_on_worker = True
+
+    def get_weight_queues(self):
+        """Get the per-worker weight initialization queues.
+
+        Returns:
+            Dict mapping worker_idx to Queue for receiving shared weight references.
+
+        Raises:
+            RuntimeError: If init_on_sender() hasn't been called yet.
+        """
+        if not self._weight_init_queues:
+            raise RuntimeError("Queues not created. 
Call init_on_sender() first.") + return self._weight_init_queues + + def get_message_queue(self): + """Get the general message queue for coordination. + + Returns: + The message queue for general coordination messages. + """ + return self._message_queue + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create shared memory transport. + + Returns the shared transport instance that all workers will use. + Since this is shared memory, there's only one transport shared by all workers. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + return self._shared_transport + + def prepare_weights( + self, + weights: Any, + model_id: str, + strategy: WeightStrategy, + context: Any = None, + ) -> Any: + """Prepare weights for SharedMemWeightSyncScheme. + + For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights + from the context (collector) to avoid extracting fresh (non-shared) weights. + + Args: + weights: Raw weights input + model_id: The model identifier + strategy: WeightStrategy for extracting/converting weights + context: Optional context (e.g., collector) for cache lookup + + Returns: + Shared memory weights ready to send + """ + # If no weights provided, check for cached shared memory weights in collector + if weights is None and context is not None: + if model_id == "policy" and hasattr(context, "_policy_weights_dict"): + policy_device = ( + context.policy_device + if not isinstance(context.policy_device, (list, tuple)) + else context.policy_device[0] + ) + cached_weights = context._policy_weights_dict.get(policy_device) + if cached_weights is not None: + return cached_weights + + # Fall back to default behavior + return super().prepare_weights(weights, model_id, strategy, context) + +class SharedMemWeightReceiver(WeightReceiver): + _transport: SharedMemTransport | None + +class SharedMemWeightSender(WeightSender): + _transport: SharedMemTransport | None \ No newline at end of file diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 2482f250d0e..735c9e59804 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -301,7 +301,7 @@ def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine): f"Initialized double-buffer receiver reading from {self._scheme.local_addr}" ) - def apply_weights(self, weights: TensorDict) -> None: + def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None: """Apply weights to vLLM engine using RPC. This method uses RPC to tell all vLLM workers to load weights from @@ -310,7 +310,10 @@ def apply_weights(self, weights: TensorDict) -> None: Args: weights: TensorDict with flattened keys containing weights. + inplace: Whether to apply weights in place. Default is `True`. 
""" + if not inplace: + raise ValueError("Cannot apply weights out of place for vLLM double-buffer") logger.info("Applying weights to vLLM engine via RPC") # Convert TensorDict to list of (name, tensor) tuples diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 840a9883d14..f57883e5cd8 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -647,9 +647,13 @@ def init_all_workers_group( ) self._transport.init_all_workers_group(model_metadata) - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply weights to vLLM engine. + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + Note: For vLLM, weights are applied automatically during the collective broadcast operation. This method is a no-op but kept for API consistency. """ diff --git a/torchrl/weight_update/utils.py b/torchrl/weight_update/utils.py new file mode 100644 index 00000000000..250a1503dd0 --- /dev/null +++ b/torchrl/weight_update/utils.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Any + + +def _resolve_model(context: Any, model_id: str) -> Any: + """Resolve model_id like 'policy' or 'env.value_net' to actual object. + + Also processes getitem notation like 'env.transform[0]' to actual object. + + Args: + context: The context object (collector or inner_collector). + model_id: A string address like "policy" or "env.value_net". + + Returns: + The object at the specified address. + + Examples: + _resolve_model(collector, "policy") # -> collector.policy + _resolve_model(collector, "env.value_net") # -> collector.env.value_net + """ + parts = model_id.split(".") + obj = context + for i, part in enumerate(parts): + if "[" in part: + key, *indices = part.split("[") + indices = [int(index[:-1]) for index in indices] + try: + obj = getattr(obj, key) + except AttributeError: + raise AttributeError( + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + for index in indices: + obj = obj[index] + else: + try: + obj = getattr(obj, part) + except AttributeError: + raise AttributeError( + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + return obj diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 265d344d401..b3e3b1870ba 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -5,41 +5,26 @@ from __future__ import annotations import abc - +import warnings import weakref -from collections.abc import Callable, Iterator +from collections.abc import Iterator from typing import Any, Literal, Protocol -import torch -import torch.distributed - from tensordict import TensorDict, TensorDictBase -from torch import multiprocessing as mp, nn +from torch import nn __all__ = [ "TransportBackend", - "MPTransport", - "SharedMemTransport", - "RayTransport", - "RayActorTransport", - "RPCTransport", - "DistributedTransport", "WeightStrategy", "WeightSender", "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", "WeightSyncScheme", - "MultiProcessWeightSyncScheme", - "SharedMemWeightSyncScheme", - "NoWeightSyncScheme", - "RayWeightSyncScheme", - "RayModuleTransformScheme", - "RPCWeightSyncScheme", - "DistributedWeightSyncScheme", ] +from torchrl.weight_update.utils import _resolve_model + + # 
============================================================================ # Transport Layer Abstraction # ============================================================================ @@ -85,591 +70,6 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: ... -class MPTransport: - """Multiprocessing transport using pipes. - - Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. - timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. - """ - - def __init__(self, pipe_connection, timeout: float = 10.0): - self.timeout = timeout - self.pipe = pipe_connection - - def send_weights(self, weights: Any) -> None: - """Send weights through the pipe. - - Sends weights and waits for acknowledgment to ensure delivery. - """ - self.send_weights_async(weights) - self.wait_ack() - - def send_weights_async(self, weights: Any) -> None: - """Send weights through the pipe without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. - """ - self.pipe.send((weights, "update_weights")) - - def wait_ack(self) -> None: - """Wait for acknowledgment from worker.""" - self.check_ack("updated") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). - - This method only handles weight update messages. Other messages - (like "close", "continue", etc.) are ignored and should be handled - by the main worker loop. - - Returns: - Tuple of (model_id, weights) if weights were received, None if no data available - or if a non-weight message was received. - - Note: - model_id is returned as "policy" for backward compatibility, but transports - are now bound to a single model during initialization. - """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - weights = data_in - return "policy", weights - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. - return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) - - def check_ack(self, message: str = "updated") -> None: - """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") - - def check_connection(self) -> bool: - return not self.pipe.closed - - def synchronize_weights_on_sender(self) -> None: - """No-op for MPTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for MPTransport - weights are received via receive_weights().""" - return None - - -class SharedMemTransport: - """Shared memory transport for in-place weight updates. - - This transport uses queue-based buffer distribution for initialization, then - updates shared memory tensors directly for subsequent weight updates. - Workers automatically see weight updates without explicit communication. 
- - Initialization flow: - - Shared memory buffers are created and sent to workers via per-worker queues - - Workers receive the buffer reference and apply weights to their models - - Subsequent updates are pure in-place shared memory (zero-copy) - - Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. - - """ - - def __init__(self): - self._params_map = None # a dict[worker_idx, TensorDictBase] map - self._weight_queues = ( - None # Dict of per-worker queues for distributing shared weights - ) - - def register_weights( - self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] - ) -> None: - """Initialize per-worker queues for shared memory buffer distribution.""" - self._weight_queues = init_queues - self._params_map = params_map - # Create set of the unique weights - self._unique_weights = [] - for weights in params_map.values(): - if id(weights) in [id(w) for w in self._unique_weights]: - continue - self._unique_weights.append(weights) - - def synchronize_weights_on_sender(self) -> None: - """Send shared memory buffer reference to workers via their per-worker queues. - - Both CPU and CUDA tensors maintain shared references through queues. - Each worker reads from its own dedicated queue, to avoid race conditions. - - """ - if self._weight_queues is None: - raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - - for worker_idx, queue in self._weight_queues.items(): - weights = self._params_map[worker_idx] - queue.put(weights) - - def synchronize_weights_on_worker( - self, worker_idx: int, timeout: float = 10.0 - ) -> TensorDictBase: - """Receive shared memory buffer reference from sender via their per-worker queues. - - Each worker reads from its own dedicated queue, to avoid race conditions. - - Args: - worker_idx: The worker index. - timeout: Timeout for reading from queue. - - Returns: - The shared memory weights TensorDict. - """ - if self._weight_queues is None: - raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - - if worker_idx not in self._weight_queues: - raise RuntimeError(f"Worker {worker_idx} not registered in queues.") - - # Read from dedicated queue for this worker - worker_queue = self._weight_queues[worker_idx] - weights = worker_queue.get(timeout=timeout) - return weights - - def send_weights(self, weights: Any) -> None: - """Update weights in-place in shared memory. - - Args: - weights: New weights to send. Can be a TensorDictBase or dict. - - Raises: - ValueError: If weights type is unsupported. - """ - # Update shared memory in-place (workers see this automatically) - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Unsupported weights type: {type(weights)}") - # Unflatten if needed to match shared buffer structure - weights_to_update = weights - if any("." 
in key for key in weights.keys()): - weights_to_update = weights.unflatten_keys(".") - - for buffer in self._unique_weights: - buffer.update_(weights_to_update, non_blocking=True) - if torch.cuda.is_available(): - torch.cuda.synchronize() - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_connection(self) -> bool: - """Shared memory is always 'connected'.""" - return True - - -class RayTransport: - """Ray transport for communicating with a single Ray collector actor. - - This transport handles weight updates for ONE specific remote collector. - Multiple transports are created for multiple collectors, following the - same pattern as multiprocess collectors. - """ - - def __init__( - self, - remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayTransport") - self._remote_collector = remote_collector - self._tensor_transport = tensor_transport - - def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via Ray.""" - if self._remote_collector is None: - return - - # Put weights in Ray's object store for efficient distribution - # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - # Send to the remote collector and wait for completion - # This ensures weights are applied before we continue - future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - self.ray.wait([future], num_returns=1) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._remote_collector is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - self._pending_future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - - def wait_ack(self) -> None: - """Wait for the remote collector to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.wait([self._pending_future], num_returns=1) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if Ray is initialized.""" - return self.ray.is_initialized() - - def synchronize_weights_on_sender(self) -> None: - """No-op for RayTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RayTransport - weights are received via remote method calls.""" - return None - - -class RayActorTransport: - """Ray transport for communicating with Ray actors (not collectors). - - This transport is designed for updating models hosted within Ray actors, - such as RayModuleTransform instances. 
It directly calls the actor's - update_weights method rather than going through collector update methods. - """ - - def __init__( - self, - actor_ref=None, - update_method: str = "tensordict", - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayActorTransport") - - self._actor_ref = actor_ref - self._update_method = update_method - self._tensor_transport = tensor_transport - - def set_actor(self, actor_ref): - """Set the Ray actor reference to communicate with.""" - self._actor_ref = actor_ref - - def send_weights(self, weights: Any) -> None: - """Send weights to the Ray actor.""" - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self.ray.get( - self._actor_ref._update_weights_tensordict.remote(params=weights_ref) - ) - elif self._update_method == "state_dict": - self.ray.get( - self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def send_weights_async(self, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. - - Use wait_ack() to wait for completion after sending to all actors. - """ - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self._pending_future = self._actor_ref._update_weights_tensordict.remote( - params=weights_ref - ) - elif self._update_method == "state_dict": - self._pending_future = self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.get(self._pending_future) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray actor workers receive weights through direct method calls.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_connection(self) -> bool: - """Check if Ray is initialized and actor exists.""" - if not self.ray.is_initialized(): - return False - if self._actor_ref is None: - return False - return True - - def synchronize_weights_on_sender(self) -> None: - """No-op for RayActorTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RayActorTransport - weights are received via remote method calls.""" - return None - - -class RPCTransport: - """RPC transport for communicating with a single RPC remote collector. - - This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. 
- """ - - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): - self._collector_info = collector_info - self._collector_rref = collector_rref - self._collector_class = collector_class - - def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via RPC.""" - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights to the remote collector and wait for completion - rpc.rpc_sync( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights asynchronously - self._pending_future = rpc.rpc_async( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def wait_ack(self) -> None: - """Wait for the RPC call to complete.""" - if hasattr(self, "_pending_future"): - self._pending_future.wait() - del self._pending_future - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """RPC workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if RPC is initialized.""" - from torch.distributed import rpc - - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True - - def synchronize_weights_on_sender(self) -> None: - """No-op for RPCTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RPCTransport - weights are received via RPC calls.""" - return None - - -class DistributedTransport: - """torch.distributed transport for communicating with a single distributed worker. - - This transport handles weight updates for ONE specific distributed worker via - torch.distributed send/recv. Multiple transports are created for multiple workers, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, store=None, rank=None, sync=True): - """Initialize the DistributedTransport. - - Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. - """ - self._store = store - self._rank = rank - self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights - - def send_weights(self, weights: Any) -> None: - """Send weights to the distributed worker.""" - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - # Wait for acknowledgment - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def send_weights_async(self, weights: Any) -> None: - """Send weights to distributed worker without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. 
- """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - def wait_ack(self) -> None: - """Wait for acknowledgment from distributed worker.""" - if self._store is None or self._rank is None: - return - - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. - - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment - - Args: - timeout: Timeout for receiving (currently not used for TCPStore check) - - Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. - """ - if self._store is None or self._rank is None: - return None - - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None - - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender via TCPStore. - - Args: - message: Acknowledgment message to send (default: "updated") - """ - if self._store is None or self._rank is None: - return - - self._store.set(f"NODE_{self._rank}_out", message.encode()) - - def check_connection(self) -> bool: - """Check if torch.distributed is initialized.""" - return torch.distributed.is_initialized() - - def synchronize_weights_on_sender(self) -> None: - """No-op for DistributedTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for DistributedTransport - weights are received via receive_weights().""" - return None - - # ============================================================================ # Weight Strategies # ============================================================================ @@ -691,6 +91,11 @@ class WeightStrategy: """ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"): + if extract_as == "state_dict": + warnings.warn( + "state_dict strategy is experimental. 
Use tensordict strategy for safer weight updates.", + UserWarning, + ) if extract_as not in ("tensordict", "state_dict"): raise ValueError( f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}" @@ -722,7 +127,7 @@ def extract_weights(self, source: Any) -> Any: raise ValueError( f"Unsupported source type for TensorDict extraction: {type(source)}" ) - else: # state_dict + elif self.extract_as == "state_dict": # state_dict # Extract as state_dict if isinstance(source, nn.Module): return source.state_dict() @@ -730,13 +135,19 @@ def extract_weights(self, source: Any) -> Any: return source elif isinstance(source, TensorDictBase): # Convert TensorDict to state_dict - return source.to_dict() + return source.flatten_keys().to_dict() else: raise ValueError( f"Unsupported source type for state_dict extraction: {type(source)}" ) + else: + raise ValueError( + f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'." + ) - def apply_weights(self, destination: Any, weights: Any) -> None: + def apply_weights( + self, destination: Any, weights: Any, inplace: bool = True + ) -> None: """Apply weights to destination model. The format is automatically detected from the weights type: @@ -749,6 +160,7 @@ def apply_weights(self, destination: Any, weights: Any) -> None: - TensorDictBase: TensorDict - dict: State dictionary weights: The weights to apply (dict or TensorDictBase). + inplace: Whether to apply weights in place. """ if weights is None: return @@ -760,30 +172,35 @@ def apply_weights(self, destination: Any, weights: Any) -> None: weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): # Do not update in-place - weights.to_module(destination) - return + if not inplace: + weights.to_module(destination) + return + else: + destination = TensorDict.from_module(destination) elif isinstance(destination, dict): + if not inplace: + raise ValueError("Cannot update state_dict out of place") destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): destination = destination.unflatten_keys(".") - if isinstance(weights, TensorDictBase): - # Apply TensorDict format - if isinstance(destination, TensorDictBase): - try: - destination.data.update_(weights.data) - except Exception as e: - raise KeyError( - f"Error updating destination: {e}. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" - ) - else: - raise ValueError( - f"Unsupported destination type for TensorDict: {type(destination)}" - ) - else: + if not isinstance(weights, TensorDictBase) or not isinstance( + destination, TensorDictBase + ): raise ValueError( - f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase." + f"Unsupported weights or destination type: {type(weights)=} or {type(destination)=}. Expected TensorDictBase." ) + # Apply TensorDict format + try: + if not inplace: + destination.update(weights) + else: + destination.data.update_(weights.data) + except Exception as e: + raise KeyError( + f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" + ) from e + return def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy: @@ -905,13 +322,12 @@ def send( "Cannot call send() while an async send is pending. Call wait_async() first." 
) - model_id = getattr(self, "_model_id", "policy") context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights prepared_weights = self._scheme.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) @@ -954,13 +370,12 @@ def send_async( "Cannot call send_async() again while a previous send is pending. Call wait_async() first." ) - model_id = getattr(self, "_model_id", "policy") context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights prepared_weights = self._scheme.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) @@ -1003,17 +418,16 @@ def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. This method is called once after workers are initialized to send - the initial weights. For most transports this is a no-op (weights - are sent via send()). For SharedMemTransport, this sends buffer - references via queues. + the initial weights. For SharedMemTransport, this sends buffer + references via queues. For MultiProcessWeightSyncScheme (MPTransport), + this extracts and sends initial weights via pipes. This is different from send() which is called during training to update weights. """ - # Iterate over all transports and call synchronize_weights_on_sender + # For other schemes (SharedMemWeightSyncScheme, etc.), use transport's method for transport in self._iterate_transports(): - if hasattr(transport, "synchronize_weights_on_sender"): - transport.synchronize_weights_on_sender() + transport.synchronize_weights_on_sender() def update_weights(self, weights: Any) -> None: """Send weights to ALL workers for this model. @@ -1155,17 +569,21 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: weights = self._transport.synchronize_weights_on_worker(worker_idx) # Apply weights to model if received (SharedMemTransport case) - if weights is not None and self._model_ref is not None: - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - else: - raise ValueError("Failed to synchronize weights") + # For other transports (MPTransport, etc.), weights is None and synchronization + # happens later via receive(), so this is a no-op + if weights is not None: + if self._model_ref is not None: + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=False) + else: + raise ValueError("Received weights but no model registered") - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply received weights to registered model. Args: weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. Note: Convenience method. Normally weights are received and applied via receive() in the worker loop. 
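# Illustrative sketch (not part of the patch): the two tensordict update paths that
# WeightStrategy.apply_weights() selects between via the new `inplace` flag.
import torch.nn as nn
from tensordict import TensorDict

policy = nn.Linear(4, 2)
received = TensorDict.from_module(nn.Linear(4, 2))   # stand-in for weights coming off a transport

# inplace=True: copy the received values into the module's existing parameter storage
TensorDict.from_module(policy).data.update_(received.data)

# inplace=False: rebind the module to the received tensors instead of copying
received.to_module(policy)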
@@ -1174,7 +592,7 @@ def apply_weights(self, weights: Any) -> None: raise ValueError("No model registered") model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) + self._strategy.apply_weights(model, weights, inplace=inplace) # Send acknowledgment if transport supports it if hasattr(self._transport, "send_ack"): @@ -1202,123 +620,19 @@ def __setstate__(self, state): self.__dict__.update(state) -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. +# ============================================================================ +# Weight Synchronization Schemes +# ============================================================================ - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None - - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). - - This is now handled by init_on_sender(). Only kept for internal use. - - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). - """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id - - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). - - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. - """ - - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. - - Args: - weights: Weights to send. - """ - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - self._single_transport.send_weights(weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: - return - - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). - - This is now handled by init_on_worker(). Only kept for internal use. - - Args: - actor_or_context: Either a Ray actor reference or a context object. 
- """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model. - - For Ray actors, weights are applied directly to the module - within the actor's process space. - - Args: - weights: The weights to apply. - """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - - -# ============================================================================ -# Weight Synchronization Schemes -# ============================================================================ - - -class WeightSyncScheme(metaclass=abc.ABCMeta): - """Configuration for how to synchronize ONE model across workers. +class WeightSyncScheme(metaclass=abc.ABCMeta): + """Configuration for how to synchronize ONE model across workers. A scheme manages synchronization of ONE model across workers. The collector maintains a dict of {model_id: scheme} pairs. """ - def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): self.strategy = strategy self._sender = None self._receiver = None @@ -1500,895 +814,3 @@ def prepare_weights( else: # Already extracted weights (TensorDict, dict, etc.) return weights - - -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. - - This scheme creates transports that communicate via multiprocessing pipes. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Create sender and register all workers - sender = WeightSender(self) - sender._model_id = model_id - if context is not None: - sender._context_ref = weakref.ref(context) - - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) 
- """ - # Extract parameters from context or kwargs - if context is not None: - pipe = getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - pipe = kwargs.get("pipe") - model = kwargs.get("model") - - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") - - # Create receiver and register model - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - return MPTransport(pipe) - - -class SharedMemWeightSyncScheme(WeightSyncScheme): - """Weight synchronization using shared memory. - - This scheme uses shared memory for in-place weight updates. Workers - automatically see weight updates without explicit message passing. - - Args: - strategy: The weight transmission strategy (default: "tensordict"). - - Example: - >>> # Basic usage - >>> scheme = SharedMemWeightSyncScheme() - >>> # Weights are initialized via init_on_sender() - """ - - def __init__( - self, - strategy: str = "tensordict", - ): - super().__init__(strategy) - # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport() - # Create per-worker queues to avoid race conditions - # Each worker gets its own queue for weight initialization - self._weight_init_queues = {} # worker_idx -> Queue - # General message queue for coordination (if needed in future) - self._message_queue = mp.Queue() - - def init_on_sender( - self, - model_id: str | None = None, - context: Any = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - num_workers: int | None = None, - ) -> None: - """Initialize on the main process (sender side). - - We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers - share the same device, the entry in the dict will be the same. - To do this, we need to know the number of workers, their assigned device, and have access to the parameters. - If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided - explicitly. - - In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the - devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that - will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. 
- - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing device_to_workers mapping and model access - weights: Pre-extracted weights as TensorDict (for policy factory usage) - model: Model to extract weights from - params_map: Direct mapping of worker_idx to weights on device (most explicit) - devices: List of devices for each worker - device_map_fn: Custom function to map worker_idx and weights to device-specific weights - num_workers: Number of workers (required with device_map_fn) - - Examples: - Simple usage with collector context (stateful policy): - - >>> policy = make_stateful_policy() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], - ... policy=policy, - ... frames_per_batch=100, - ... total_frames=1000, - ... weight_sync_schemes={"policy": scheme}, - ... ) - >>> # scheme.init_on_sender() is called automatically by collector - - Pre-initialized usage (policy factory): - - >>> policy_on_main = make_stateful_policy() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> # Must initialize before collector creation when using policy_factory - >>> scheme.init_on_sender( - ... model_id="policy", - ... weights=TensorDict.from_module(policy_on_main), - ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], - ... num_workers=2, - ... ) - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], - ... policy_factory=[make_stateful_policy], - ... frames_per_batch=100, - ... total_frames=1000, - ... weight_sync_schemes={"policy": scheme}, - ... ) - - Direct params_map usage (advanced): - - >>> weights_cpu = TensorDict.from_module(policy).share_memory_() - >>> weights_cuda = weights_cpu.to("cuda").share_memory_() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> scheme.init_on_sender( - ... model_id="policy", - ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, - ... ) - """ - # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init - # the weights on the workers. - # Scenarios: - # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating - # the transport and registering the workers etc. - # - The user provides a model or its params and a device map. We need to create the map from the params - # explicitly. - # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need - # to collect the model from the context. 
- params_map = self._get_params_map( - context=context, - model_id=model_id, - weights=weights, - model=model, - params_map=params_map, - devices=devices, - device_map_fn=device_map_fn, - num_workers=num_workers, - ) - - # Create per-worker queues if not already created - # Collect all unique worker indices - all_workers = list(params_map.keys()) - - for worker_idx in all_workers: - if worker_idx not in self._weight_init_queues: - self._weight_init_queues[worker_idx] = mp.Queue() - - # Set worker info in transport - self._shared_transport.register_weights(params_map, self._weight_init_queues) - - # Create sender with the shared transport - sender = WeightSender(self) - sender._model_id = model_id - sender._transport = self._shared_transport # Use shared transport - if context is not None: - sender._context_ref = weakref.ref(context) - - self._sender = sender - self._initialized_on_sender = True - - def synchronize_weights(self): - """Method to be called once the workers have started. - - Triggers a rendez-vous for the workers to receive their copy of the weights. - - This is a convenience method that delegates to the sender's synchronize_weights(). - """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" - ) - self._sender.synchronize_weights() - - def _get_params_map( - self, - context: Any = None, - model_id: str | None = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - num_workers: int | None = None, - ): - """Get the params_map for init_on_sender().""" - if params_map is not None: - # Sanity check: params_map must be a dict[int, TensorDictBase] - # All other args must be None - if ( - not isinstance(params_map, dict) - or not all(isinstance(v, int) for v in params_map.keys()) - or not all(isinstance(v, TensorDictBase) for v in params_map.values()) - ): - raise ValueError("params_map must be a dict[int, TensorDictBase]") - if model_id is not None or weights is not None or model is not None: - raise ValueError( - "model_id, weights, and model cannot be provided if params_map is provided" - ) - if context is not None: - raise ValueError("context cannot be provided if params_map is provided") - if devices is not None: - raise ValueError("devices cannot be provided if params_map is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if params_map is provided" - ) - if num_workers is not None: - raise ValueError( - "num_workers cannot be provided if params_map is provided" - ) - return params_map - elif context is not None: - if devices is not None: - raise ValueError("devices cannot be provided if context is provided") - # Sanity check: model_id must be provided if context is provided - # All other args must be None - if model_id is None: - raise ValueError("model_id must be provided if context is provided") - if model is not None: - raise ValueError("model cannot be provided if context is provided") - if weights is not None: - raise ValueError("weights cannot be provided if context is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if context is provided" - ) - # Get device map: the devices are stored as policy_device in the collector -- other contexts will be 
customized later - devices = context.policy_device - if num_workers is not None and num_workers != len(devices): - raise ValueError( - "num_workers cannot be provided if context is provided" - ) - # Get the weights - model = _resolve_model(context, model_id) - weights = TensorDict.from_module(model) - elif model is not None: - if weights is not None: - raise ValueError("weights cannot be provided if model is provided") - weights = TensorDict.from_module(model) - # To make the map, we need the list of devices, or the map fn - if devices is not None: - # Import _cast locally to avoid circular imports - from torchrl.collectors.utils import _cast - - # Get the unique devices - devices_set = set(devices) - weights_devices = {p.device for p in weights.values(True, True)} - if len(weights_devices) == 1: - weights_device = weights_devices.pop() - else: - weights_device = None - - # Create device map with proper Parameter handling using _cast - # _cast ensures Parameters stay as Parameters (with requires_grad=False) - device_map = {} - for d in devices_set: - if d != weights_device: - # Move to device and apply _cast to preserve Parameter/Buffer types - weights_on_device = weights.to(d) - weights_on_device = weights_on_device.apply(_cast, weights) - device_map[d] = weights_on_device - else: - # Already on correct device, just apply _cast - device_map[d] = weights.apply(_cast, weights) - - # Create the map - params_map = { - worker_idx: device_map[device] - for worker_idx, device in enumerate(devices) - } - return params_map - if device_map_fn is not None: - return { - worker_idx: device_map_fn(worker_idx, weights) - for worker_idx in range(num_workers) - } - raise ValueError( - "Either params_map, model_id + context or model/weights + devices must be provided." - ) - - def init_on_worker( - self, - model_id: str, - context: Any = None, - model: Any = None, - worker_idx: int | None = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Reads from the worker's dedicated queue to receive shared weights, - then registers them in the transport. The receiver then applies these weights - to the model. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing model and worker_idx - model: Model being synchronized - worker_idx: Worker index - **kwargs: Alternative to context (model, worker_idx, timeout, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - if hasattr(context, "get_model"): - model = context.get_model(model_id) - elif model is None: - model = _resolve_model(context, model_id) - worker_idx = getattr(context, "worker_idx", worker_idx) - - # Create receiver with the shared transport - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._transport = self._shared_transport # Use shared transport - - # Register the model - receiver._register_model(model) - - # Store worker_idx for synchronize_weights - receiver._worker_idx = worker_idx - - self._receiver = receiver - self._initialized_on_worker = True - - def get_weight_queues(self): - """Get the per-worker weight initialization queues. - - Returns: - Dict mapping worker_idx to Queue for receiving shared weight references. - - Raises: - RuntimeError: If init_on_sender() hasn't been called yet. - """ - if not self._weight_init_queues: - raise RuntimeError("Queues not created. 
Call init_on_sender() first.") - return self._weight_init_queues - - def get_message_queue(self): - """Get the general message queue for coordination. - - Returns: - The message queue for general coordination messages. - """ - return self._message_queue - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport. - - Returns the shared transport instance that all workers will use. - Since this is shared memory, there's only one transport shared by all workers. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - return self._shared_transport - - def prepare_weights( - self, - weights: Any, - model_id: str, - strategy: WeightStrategy, - context: Any = None, - ) -> Any: - """Prepare weights for SharedMemWeightSyncScheme. - - For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the context (collector) to avoid extracting fresh (non-shared) weights. - - Args: - weights: Raw weights input - model_id: The model identifier - strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for cache lookup - - Returns: - Shared memory weights ready to send - """ - # If no weights provided, check for cached shared memory weights in collector - if weights is None and context is not None: - if model_id == "policy" and hasattr(context, "_policy_weights_dict"): - policy_device = ( - context.policy_device - if not isinstance(context.policy_device, (list, tuple)) - else context.policy_device[0] - ) - cached_weights = context._policy_weights_dict.get(policy_device) - if cached_weights is not None: - return cached_weights - - # Fall back to default behavior - return super().prepare_weights(weights, model_id, strategy, context) - - -class NoWeightSyncScheme(WeightSyncScheme): - """No-op weight synchronization scheme. - - This scheme disables weight synchronization entirely. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op sender - sender = WeightSender(self) - sender._model_id = model_id - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op receiver - receiver = WeightReceiver(self) - receiver._model_ref = model_id - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create a no-op transport. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - # Return a dummy transport that does nothing - class NoOpTransport: - def send_weights(self, weights: Any) -> None: - pass - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - return None - - def check_connection(self) -> bool: - return True - - return NoOpTransport() - - -class RayWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for Ray distributed computing. 
- - This scheme uses Ray's object store and remote calls to synchronize weights - across distributed workers (Ray actors). - - Each remote collector gets its own transport, following the same pattern - as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create Ray-based transport for a specific remote collector. - - Args: - pipe_or_context: The Ray actor handle for the remote collector. - - Returns: - RayTransport configured for this specific remote collector. - """ - return RayTransport(remote_collector=pipe_or_context) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing remote_collectors - **kwargs: Alternative to context (remote_collectors, source_model, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - remote_collectors = getattr(context, "remote_collectors", None) - num_workers = getattr(context, "num_workers", None) or getattr( - context, "num_collectors", None - ) - else: - remote_collectors = kwargs.get("remote_collectors") - num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") - - if remote_collectors is None: - raise ValueError("remote_collectors must be provided via context or kwargs") - if num_workers is None: - num_workers = len(remote_collectors) if remote_collectors else 0 - - # Create sender and register all workers (Ray actors) - sender = WeightSender(self) - sender._model_id = model_id - - # Register each Ray actor - _register_worker will create the transport - for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) - - # Set context with weak reference to avoid circular refs - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model reference if provided for automatic weight extraction - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - For Ray workers, weight updates are handled via remote method calls, - so this is typically a no-op. The receiver is created but doesn't - need special initialization. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the remote collector) - **kwargs: Optional parameters (pipe, model, etc.) - """ - # Create receiver - receiver = WeightReceiver(self) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "policy", None) if context else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RayModuleTransformScheme(WeightSyncScheme): - """Weight synchronization for RayModuleTransform actors. - - This scheme is designed specifically for updating models hosted within - Ray actors, such as RayModuleTransform instances. It creates a transport - that directly calls the actor's weight update methods. 
- - Args: - strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). - Default is "tensordict". - """ - - def __init__(self, strategy: str = "tensordict"): - super().__init__(strategy) - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RayActorTransport for the given actor. - - Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. - - Returns: - RayActorTransport configured with the actor reference. - """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) - - def _extract_actor_ref(self, pipe_or_context: Any) -> Any: - """Extract the Ray actor reference from the context. - - Args: - pipe_or_context: Either a direct actor reference or an object - with an `_actor` attribute. - - Returns: - The Ray actor reference. - """ - if hasattr(pipe_or_context, "_actor"): - return pipe_or_context._actor - return pipe_or_context - - def create_sender(self) -> RayModuleTransformSender: - """Create a specialized sender for Ray actor communication.""" - return RayModuleTransformSender(self) - - def create_receiver(self) -> RayModuleTransformReceiver: - """Create a specialized receiver for Ray actor communication.""" - return RayModuleTransformReceiver(self) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing actor references - **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) - """ - # Extract actor references from context or kwargs - if context is not None: - # Could be actor_refs, actors, or remote_collectors - actor_refs = ( - getattr(context, "actor_refs", None) - or getattr(context, "actors", None) - or getattr(context, "remote_collectors", None) - ) - else: - actor_refs = ( - kwargs.get("actor_refs") - or kwargs.get("actors") - or kwargs.get("remote_collectors") - ) - - if actor_refs is None: - raise ValueError( - "actor_refs (or actors) must be provided via context or kwargs" - ) - - # Create specialized sender - sender = self.create_sender() - sender._model_id = model_id - - # Register all actors - _register_worker will create the transport - for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) - - # Set context with weak reference - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model if provided - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the actor itself) - **kwargs: Optional parameters (actor_ref, model, etc.) 
- """ - # Create specialized receiver - receiver = self.create_receiver() - - # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: - # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "_actor_module", None) or getattr(context, "module", None) - if context - else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RPCWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed.rpc. - - This scheme uses RPC calls to synchronize weights across distributed - workers. Each remote collector gets its own transport, following the - same pattern as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RPC-based transport for a specific remote collector. - - Args: - pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) - for the remote collector. - - Returns: - RPCTransport configured for this specific remote collector. - """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: - collector_info, collector_rref, collector_class = pipe_or_context - return RPCTransport( - collector_info=collector_info, - collector_rref=collector_rref, - collector_class=collector_class, - ) - # If just passed the info directly - return RPCTransport(collector_info=pipe_or_context) - - -class DistributedWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed. - - This scheme uses torch.distributed primitives (send/recv) to synchronize - weights across distributed workers. Each worker gets its own transport, - following the same pattern as multiprocess collectors. - - Args: - backend (str): The distributed backend ("gloo", "nccl", etc.) - sync (bool): Whether to use synchronous weight updates - """ - - def __init__(self, backend: str = "gloo", sync: bool = True): - super().__init__() - self.backend = backend - self.sync = sync - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create distributed transport for a specific worker. - - Args: - pipe_or_context: A tuple of (store, rank) for the worker. - - Returns: - DistributedTransport configured for this specific worker. - """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: - store, rank = pipe_or_context - return DistributedTransport(store=store, rank=rank, sync=self.sync) - # Fallback - shouldn't normally happen - return DistributedTransport() - - -# ============================================================================ -# Helper Functions -# ============================================================================ - - -def _resolve_model(context: Any, model_id: str) -> Any: - """Resolve model_id like 'policy' or 'env.value_net' to actual object. - - Also processes getitem notation like 'env.transform[0]' to actual object. - - Args: - context: The context object (collector or inner_collector). - model_id: A string address like "policy" or "env.value_net". - - Returns: - The object at the specified address. 
- - Examples: - _resolve_model(collector, "policy") # -> collector.policy - _resolve_model(collector, "env.value_net") # -> collector.env.value_net - """ - parts = model_id.split(".") - obj = context - for i, part in enumerate(parts): - if "[" in part: - key, *indices = part.split("[") - indices = [int(index[:-1]) for index in indices] - try: - obj = getattr(obj, key) - except AttributeError: - raise AttributeError( - f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - for index in indices: - obj = obj[index] - else: - try: - obj = getattr(obj, part) - except AttributeError: - raise AttributeError( - f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - return obj From ba39d0a5ee74e03aef2157a8b54894e08424a27e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 14 Nov 2025 17:25:13 +0000 Subject: [PATCH 13/17] final! --- .../reference/collectors_weightsync.rst | 4 +- test/test_weightsync.py | 159 +++++++++--------- torchrl/collectors/_runner.py | 4 +- torchrl/weight_update/_mp.py | 50 +++++- torchrl/weight_update/_noupdate.py | 42 ++++- torchrl/weight_update/_ray.py | 94 ++++++++++- torchrl/weight_update/_shared.py | 141 +++++++++++++++- torchrl/weight_update/weight_sync_schemes.py | 20 +-- 8 files changed, 407 insertions(+), 107 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 6e73e2a91f6..e57b6e7dc38 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -49,7 +49,7 @@ Weight update schemes can be used outside of collectors for custom synchronizati The new simplified API provides four core methods for weight synchronization: - ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side +- ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side - ``get_sender()`` - Get the configured sender instance - ``get_receiver()`` - Get the configured receiver instance @@ -85,7 +85,7 @@ Here's a basic example: # or sender.send_async(weights); sender.wait_async() # Asynchronous send # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) + # scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy) # receiver = scheme.get_receiver() # # Non-blocking check for new weights # if receiver.receive(timeout=0.001): diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 022055cd659..b75186c4afe 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -6,7 +6,9 @@ import argparse import importlib.util + import pickle +import threading import time import pytest @@ -26,12 +28,10 @@ RayWeightSyncScheme, RPCWeightSyncScheme, SharedMemTransport, -) -from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( SharedMemWeightSyncScheme, WeightStrategy, ) +from torchrl.weight_update.utils import _resolve_model _has_ray = importlib.util.find_spec("ray") is not None @@ -43,7 +43,7 @@ def worker_update_policy(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): @@ -62,7 +62,7 @@ def 
worker_update_policy_tensordict(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): @@ -100,7 +100,7 @@ def test_mp_transport_basic(self): proc.start() test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", test_weights) + transport.send_weights(test_weights) proc.join(timeout=10.0) assert not proc.is_alive() @@ -113,7 +113,7 @@ def test_mp_transport_async(self): proc.start() test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights_async("policy", test_weights) + transport.send_weights_async(test_weights) transport.wait_ack() proc.join(timeout=10.0) @@ -124,13 +124,16 @@ def test_shared_mem_transport(self): {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] ).share_memory_() - transport = SharedMemTransport({"policy": shared_buffer}) + transport = SharedMemTransport() + transport.register_weights( + params_map={0: shared_buffer}, init_queues={0: mp.Queue()} + ) new_weights = TensorDict( {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] ) - transport.send_weights("policy", new_weights) + transport.send_weights(new_weights) assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) @@ -255,7 +258,10 @@ def test_shared_mem_scheme(self): {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] ) - transport.send_weights("policy", new_weights) + transport.register_weights( + params_map={0: shared_buffer}, init_queues={0: mp.Queue()} + ) + transport.send_weights(new_weights) assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) @@ -265,7 +271,7 @@ def test_no_weight_sync_scheme(self): transport = scheme.create_transport(None) weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", weights) + transport.send_weights(weights) @classmethod def _worker_with_receive(cls, pipe, scheme): @@ -274,7 +280,7 @@ def _worker_with_receive(cls, pipe, scheme): policy.weight.fill_(0.0) policy.bias.fill_(0.0) - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() # Non-blocking receive should return False when no data @@ -354,7 +360,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy): collector.shutdown() def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + scheme = MultiProcessWeightSyncScheme() collector = MultiSyncDataCollector( create_env_fn=[ @@ -660,73 +666,76 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): parent_pipe.close() child_pipe.close() - def test_shared_mem_scheme_serialize_before_init(self): - """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None + # Serialize and 
deserialize + @staticmethod + def _get_scheme_from_queue(q, scheme): + try: + restored = scheme + # Check that configuration is preserved but runtime state is cleared + assert restored.strategy == "tensordict" + assert restored._sender is None + assert not restored._initialized_on_sender + + q.put("success") + except Exception as err: + q.put(f"failure: {err}") + finally: + q.close() + @pytest.mark.timeout(10) def test_shared_mem_scheme_serialize_after_init(self): """Test that initialized SharedMemWeightSyncScheme can be pickled.""" parent_pipe, child_pipe = mp.Pipe() + q = mp.Queue() + try: + # Create shared buffer + shared_buffer = TensorDict( + {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] + ).share_memory_() + + scheme = SharedMemWeightSyncScheme() + + def init_on_sender(scheme, pipe): + scheme.init_on_sender(params_map={0: shared_buffer}) + scheme.synchronize_weights() + msg = pipe.recv() + assert msg == "registered" + + def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe): + scheme.init_on_receiver( + worker_idx=0, model=nn.Linear(4, 2, device="meta") + ) + scheme.synchronize_weights() + child_pipe.send("registered") + + future_sender = threading.Thread( + target=init_on_sender, + kwargs={"scheme": scheme, "pipe": parent_pipe}, + ) + future_receiver = threading.Thread( + target=init_on_receiver, + kwargs={"scheme": scheme, "child_pipe": child_pipe}, + ) + future_receiver.start() + future_sender.start() + future_receiver.join(timeout=10.0) + future_sender.join(timeout=10.0) - # Create shared buffer - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - strategy="tensordict", - ) - - def init_on_sender(scheme, child_pipe): - (model_id, data), msg = child_pipe.recv() - if msg == "register_shared_weights": - child_pipe.send((None, "registered")) - else: - raise ValueError(f"Expected 'register_shared_weights' but got {msg}") - - # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker - import threading - - future_sender = threading.Thread( - target=scheme.init_on_sender, - kwargs={"model_id": "policy", "pipes": [parent_pipe]}, - ) - future_receiver = threading.Thread( - target=init_on_sender, - kwargs={"scheme": scheme, "child_pipe": child_pipe}, - ) - future_receiver.start() - future_sender.start() - future_receiver.join() - future_sender.join() - - # Scheme now has _sender with non-serializable state - assert scheme._sender is not None - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "tensordict" - assert restored._sender is None - assert not restored._initialized_on_sender - - # Note: policy_weights dict is preserved (but may need re-sharing) - assert "policy" in restored.policy_weights + # Scheme now has _sender with non-serializable state + assert scheme._sender is not None - # Clean up - parent_pipe.close() - child_pipe.close() + proc = mp.Process(target=self._get_scheme_from_queue, args=(q, scheme)) + proc.start() + try: + msg = q.get(timeout=10.0) + assert msg == "success", msg + finally: + proc.join() + finally: + q.close() + # Clean up + parent_pipe.close() + child_pipe.close() def test_no_weight_sync_scheme_serialize(self): """Test that NoWeightSyncScheme can be pickled.""" @@ -809,7 +818,7 @@ def 
test_scheme_reinitialization_after_unpickle(self): """Test that a scheme can be re-initialized after unpickling. This is the expected workflow: pickle a scheme, unpickle it in a worker, - then call init_on_worker() to establish new runtime resources. + then call init_on_receiver() to establish new runtime resources. """ # Initialize and pickle a scheme parent_pipe, child_pipe = mp.Pipe() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 091ab8c4c9d..d6ab5ef4d76 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -39,7 +39,7 @@ def _make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side - weight_sync_scheme.init_on_worker( + weight_sync_scheme.init_on_receiver( model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe ) # Get the receiver and synchronize initial weights @@ -147,7 +147,7 @@ def _main_async_collector( inner_collector._weight_receivers[model_id] = receiver else: # Initialize receivers for other models - scheme.init_on_worker(model_id=model_id, context=inner_collector) + scheme.init_on_receiver(model_id=model_id, context=inner_collector) receiver = scheme.get_receiver() receiver.synchronize_weights(worker_idx=worker_idx) inner_collector._weight_receivers[model_id] = receiver diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 12d9c7be3fb..9da2795ba24 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any +from typing import Any, overload from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -22,7 +22,7 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): Synchronization flow: - init_on_sender() creates a MPWeightSender and registers all worker pipes - synchronize_weights() triggers the initial weight distribution via pipes - - init_on_worker() creates a MPWeightReceiver that receives from its pipe + - init_on_receiver() creates a MPWeightReceiver that receives from its pipe - Subsequent updates use send() which extracts, sends, and waits for ACKs Args: @@ -55,6 +55,27 @@ def synchronize_weights(self): ) self._sender.synchronize_weights() + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + pipes: list = ..., + num_workers: int | None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -93,7 +114,28 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + pipe: Any = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -138,7 +180,7 @@ def create_transport(self, pipe: Any) -> TransportBackend: """Create an MPTransport using the provided pipe. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. 
""" return MPTransport(pipe) diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 697f56943e8..1f3ff01ea30 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, overload from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -16,6 +16,24 @@ class NoWeightSyncScheme(WeightSyncScheme): This scheme disables weight synchronization entirely. """ + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -36,7 +54,25 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -60,7 +96,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create a no-op transport. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ # Return a dummy transport that does nothing class NoOpTransport: diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 3fb4e571224..b8d344a9df8 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any, Literal +from typing import Any, Literal, overload from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( @@ -33,6 +33,28 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: """ return RayTransport(remote_collector=pipe_or_context) + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + remote_collectors: list = ..., + num_workers: int | None = None, + source_model: Any | None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -81,7 +103,27 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -166,6 +208,29 @@ def create_receiver(self) -> RayModuleTransformReceiver: """Create a specialized receiver for Ray actor communication.""" return RayModuleTransformReceiver(self) + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + actor_refs: list | None = None, + actors: list | None = None, + remote_collectors: list | None = None, + source_model: Any | None = None, + **kwargs, + ) -> None: + ... 
+ def init_on_sender( self, model_id: str, @@ -219,7 +284,28 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + actor_ref: Any | None = None, + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -452,7 +538,7 @@ def __init__(self, scheme: RayModuleTransformScheme): def _register_worker_transport(self, actor_or_context: Any) -> None: """Register the Ray actor's transport (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: actor_or_context: Either a Ray actor reference or a context object. diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 098c4fe6e49..b8e7e815917 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -1,10 +1,8 @@ from __future__ import annotations -import abc - import weakref -from collections.abc import Callable, Iterator -from typing import Any, Literal, Protocol +from collections.abc import Callable +from typing import Any, overload import torch import torch.distributed @@ -18,6 +16,7 @@ TransportBackend, WeightReceiver, WeightSender, + WeightStrategy, WeightSyncScheme, ) @@ -43,6 +42,7 @@ def __init__(self): self._weight_queues = ( None # Dict of per-worker queues for distributing shared weights ) + self._unique_weights = None def register_weights( self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] @@ -115,6 +115,8 @@ def send_weights(self, weights: Any) -> None: if any("." in key for key in weights.keys()): weights_to_update = weights.unflatten_keys(".") + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. Call register_weights() first.") for buffer in self._unique_weights: buffer.update_(weights_to_update, non_blocking=True) if torch.cuda.is_available(): @@ -163,8 +165,94 @@ def __init__( # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() + @overload + def init_on_sender( + self, + *, + model_id: str, + context: Any, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload def init_on_sender( self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... 
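For the `device_map_fn` overloads above, the mapping function receives the worker index and the reference weights and returns that worker's copy. A rough, hypothetical example of such a function (the even/odd placement rule below is purely illustrative and not part of the library):

import torch
from torch import nn
from tensordict import TensorDict, TensorDictBase

def device_map_fn(worker_idx: int, weights: TensorDictBase) -> TensorDictBase:
    # Hypothetical placement rule: odd workers get a CUDA copy when one exists,
    # all other workers keep the CPU reference weights.
    if torch.cuda.is_available() and worker_idx % 2 == 1:
        return weights.to("cuda:0")
    return weights

reference = TensorDict.from_module(nn.Linear(4, 2)).data
per_worker = {idx: device_map_fn(idx, reference) for idx in range(4)}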
+ + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *, model_id: str | None = None, context: Any = None, weights: TensorDictBase | None = None, @@ -400,9 +488,28 @@ def _get_params_map( "Either params_map, model_id + context or model/weights + devices must be provided." ) - def init_on_worker( + @overload + def init_on_receiver( self, + *, model_id: str, + context: Any, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + *, + model: Any, + worker_idx: int, + ) -> None: + ... + + def init_on_receiver( + self, + *, + model_id: str | None = None, context: Any = None, model: Any = None, worker_idx: int | None = None, @@ -423,6 +530,8 @@ def init_on_worker( """ # Extract parameters from context or kwargs if context is not None: + if model_id is None: + raise ValueError("model_id is required when context is provided") if hasattr(context, "get_model"): model = context.get_model(model_id) elif model is None: @@ -472,7 +581,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: Since this is shared memory, there's only one transport shared by all workers. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ return self._shared_transport @@ -512,8 +621,26 @@ def prepare_weights( # Fall back to default behavior return super().prepare_weights(weights, model_id, strategy, context) + class SharedMemWeightReceiver(WeightReceiver): + """Weight receiver for shared memory systems. + + Receives weight updates via shared memory buffers. Workers automatically + see weight updates without explicit message passing, providing zero-copy + weight synchronization. This is typically instantiated and managed by + :class:`SharedMemWeightSyncScheme`. + """ + _transport: SharedMemTransport | None + class SharedMemWeightSender(WeightSender): - _transport: SharedMemTransport | None \ No newline at end of file + """Weight sender for shared memory systems. + + Sends weight updates by writing directly to shared memory buffers. + All workers automatically see updates without explicit communication, + providing zero-copy weight synchronization. This is typically instantiated + and managed by :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index b3e3b1870ba..09ebc333dee 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -470,12 +470,12 @@ def __init__(self, scheme: WeightSyncScheme): self._transport = None # lazy self._model_ref = None self._strategy = _get_strategy(scheme.strategy) - self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_worker() + self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_receiver() def _set_context(self, context: Any) -> None: """Set the context object (inner_collector) for resolving references (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: context: The inner collector instance in the worker process. 
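The zero-copy behaviour described for the shared-memory sender/receiver above can be pictured with a small standalone sketch. It relies only on the public tensordict API (from_module, share_memory_, to_module, update_) and is not the scheme's actual implementation; the toy nn.Linear policy is an assumption for illustration:

import torch
from torch import nn
from tensordict import TensorDict

# One shared-memory copy of the policy weights.
policy = nn.Linear(4, 2)
shared = TensorDict.from_module(policy).data.clone().share_memory_()

# Worker side (conceptually another process): bind the shared buffers once.
worker_policy = nn.Linear(4, 2)
shared.to_module(worker_policy)

# Trainer side: write new values in-place; no message passing is required.
with torch.no_grad():
    policy.weight.add_(1.0)
shared.update_(TensorDict.from_module(policy).data)

# The worker's module sees the update through the shared storage.
assert torch.allclose(worker_policy.weight, policy.weight)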
@@ -485,7 +485,7 @@ def _set_context(self, context: Any) -> None: def _register_model(self, model_ref: Any) -> None: """Register the model to apply weights to (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. @@ -495,7 +495,7 @@ def _register_model(self, model_ref: Any) -> None: def _register_worker_transport(self, pipe: Any) -> None: """Register this worker's communication pipe (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: pipe: The pipe connection for this worker. @@ -556,7 +556,7 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: Args: worker_idx: The worker index (required for SharedMemTransport). - If not provided, uses the worker_idx stored during init_on_worker(). + If not provided, uses the worker_idx stored during init_on_receiver(). """ if self._transport is None: return @@ -661,7 +661,7 @@ def init_on_sender( """ raise NotImplementedError - def init_on_worker( + def init_on_receiver( self, model_id: str, context: Any = None, @@ -702,11 +702,11 @@ def get_receiver(self) -> WeightReceiver: Receiver instance for receiving weights in this worker Raises: - RuntimeError: If init_on_worker() hasn't been called yet + RuntimeError: If init_on_receiver() hasn't been called yet """ if not self._initialized_on_worker or self._receiver is None: raise RuntimeError( - f"Must call init_on_worker() before get_receiver() on {type(self).__name__}" + f"Must call init_on_receiver() before get_receiver() on {type(self).__name__}" ) return self._receiver @@ -740,7 +740,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: A transport backend instance. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ ... @@ -762,7 +762,7 @@ def create_receiver(self) -> WeightReceiver: WeightReceiver instance configured for this scheme. Note: - Typically you should use init_on_worker() followed by get_receiver() instead. + Typically you should use init_on_receiver() followed by get_receiver() instead. """ return WeightReceiver(self) From 452f09571b8899b200e7896784b9de8297b45f49 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 17:56:53 +0000 Subject: [PATCH 14/17] fixes --- torchrl/weight_update/_mp.py | 425 +++++++++++++++++-- torchrl/weight_update/weight_sync_schemes.py | 6 +- 2 files changed, 387 insertions(+), 44 deletions(-) diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 9da2795ba24..91bb4261233 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -1,8 +1,14 @@ from __future__ import annotations import weakref +from collections.abc import Callable from typing import Any, overload +import torch +from tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -15,39 +21,69 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): """Weight synchronization for multiprocess operations using pipes. This scheme creates transports that communicate via multiprocessing pipes. 
- Similar to SharedMemWeightSyncScheme which uses queues for shared memory - buffer distribution, MultiProcessWeightSyncScheme uses pipes to send - weight copies to each worker. + It follows a memory-efficient two-phase pattern similar to SharedMemWeightSyncScheme: + + 1. **init_on_sender()**: Stores the recipe for creating device-specific weights + (model reference, devices, mapping functions) without creating actual copies + 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, + sends them sequentially to workers via pipes, allowing garbage collection + between workers to minimize memory usage + + This approach avoids holding multiple weight copies in memory simultaneously, + which is especially beneficial for large models with many workers. Synchronization flow: - - init_on_sender() creates a MPWeightSender and registers all worker pipes - - synchronize_weights() triggers the initial weight distribution via pipes - - init_on_receiver() creates a MPWeightReceiver that receives from its pipe - - Subsequent updates use send() which extracts, sends, and waits for ACKs + - **init_on_sender()**: Store configuration and register worker pipes + - **synchronize_weights()**: Create and send initial weights on-demand + - **init_on_receiver()**: Create receiver that reads from pipe + - **send()**: Extract and send weight updates, wait for acknowledgments Args: strategy: The weight transmission strategy (default: "tensordict"). + Can be "tensordict" or "state_dict". Example: >>> # Basic usage with collector >>> scheme = MultiProcessWeightSyncScheme() >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, ... policy=policy, ... frames_per_batch=100, ... total_frames=1000, ... weight_sync_schemes={"policy": scheme}, ... ) >>> # scheme.synchronize_weights() is called automatically by collector + >>> # Weights are created on-demand and sent to workers efficiently + + Note: + The on-demand weight creation means that synchronize_weights() will be + slower than if weights were pre-computed, but memory usage is significantly + reduced, especially when workers use different devices or when the model + is large. """ def synchronize_weights(self): - """Method to be called once the workers have started. + """Send initial weights to all workers before collection starts. + + This method triggers the on-demand creation and distribution of device-specific + weight copies to workers. Unlike pre-computing all weights during init_on_sender(), + this approach creates each worker's weights sequentially, sends them via pipes, + and allows garbage collection before creating the next worker's weights. + + This is a convenience method that delegates to the sender's synchronize_weights(), + which handles the actual weight creation and distribution. + + Memory efficiency note: + If all workers share the same device, only one weight copy is created and + reused. If workers use different devices, weights are created and sent + sequentially to minimize peak memory usage. - Triggers a rendez-vous for the workers to receive their copy of the weights. + Called automatically by: + - MultiSyncDataCollector during initialization + - MultiaSyncDataCollector during initialization - This is a convenience method that delegates to the sender's synchronize_weights(). - The sender will extract weights from the context and send them to all workers via pipes. 
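A rough standalone sketch of the pipe hand-off this flow describes, independent of the collector classes. It reuses the ((model_id, weights), "update_weights") and (None, "updated") message format that MPTransport uses elsewhere in this patch; the toy policy and the single worker are assumptions for illustration:

from torch import multiprocessing as mp, nn
from tensordict import TensorDict

def _worker(pipe):
    policy = nn.Linear(4, 2)
    (model_id, weights), msg = pipe.recv()   # blocks until the sender writes
    assert msg == "update_weights"
    weights.to_module(policy)                # apply the received copy locally
    pipe.send((None, "updated"))             # acknowledge the update

if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    proc = mp.Process(target=_worker, args=(child_pipe,))
    proc.start()
    weights = TensorDict.from_module(nn.Linear(4, 2)).data
    parent_pipe.send((("policy", weights), "update_weights"))
    _, ack = parent_pipe.recv()
    assert ack == "updated"
    proc.join(timeout=10.0)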
+ Raises: + RuntimeError: If init_on_sender() was not called first """ if not self._initialized_on_sender or self._sender is None: raise RuntimeError( @@ -58,51 +94,196 @@ def synchronize_weights(self): @overload def init_on_sender( self, + *, model_id: str, context: Any, - **kwargs, ) -> None: ... @overload def init_on_sender( self, - model_id: str, - context: None = None, *, - pipes: list = ..., - num_workers: int | None = None, - **kwargs, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, ) -> None: ... + @overload def init_on_sender( self, - model_id: str, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *, + model_id: str | None = None, context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + pipes: list[Any] | None = None, **kwargs, ) -> None: """Initialize on the main process (sender side). + This method stores the configuration needed to create device-specific weight + copies during synchronization. Weight copies are created on-demand during + `synchronize_weights()` to reduce memory usage. + + Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern: + 1. `init_on_sender()`: Store the recipe for creating weights + 2. `synchronize_weights()`: Create and send weights on-demand + Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) + model_id: Identifier for the model being synchronized (e.g., "policy"). + Required when using context. + context: Optional context object (e.g., collector) providing: + - pipes: List of multiprocessing pipes for worker communication + - num_workers: Number of worker processes + - policy_device: List of devices for each worker + When provided, model_id is used to resolve the model from context. + weights: Pre-extracted weights as TensorDict. Mutually exclusive with + model and context. Used when weights are already available. + model: Model to extract weights from. Mutually exclusive with weights + and context. + params_map: Pre-computed mapping of worker_idx to device-specific weights. + Most explicit option. When provided, all other parameters except pipes + must be None. + devices: List of devices for each worker. 
Used with weights or model to + automatically create device-specific copies. Length must equal num_workers. + device_map_fn: Custom function (worker_idx, weights) -> device_weights. + Allows full control over device mapping. Requires num_workers. + num_workers: Number of workers. Required with device_map_fn, inferred + from devices length or pipes otherwise. + pipes: List of multiprocessing pipes. Required unless provided via context. + **kwargs: Alternative way to provide pipes (for backward compatibility). + + Examples: + Simple usage with collector context (most common): + + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Direct initialization with explicit devices: + + >>> scheme = MultiProcessWeightSyncScheme() + >>> weights = TensorDict.from_module(policy) + >>> scheme.init_on_sender( + ... weights=weights, + ... devices=[torch.device("cpu"), torch.device("cuda:0")], + ... pipes=[pipe1, pipe2], + ... ) + + Advanced: Pre-computed params_map: + + >>> weights_cpu = TensorDict.from_module(policy) + >>> weights_cuda = weights_cpu.to("cuda") + >>> scheme.init_on_sender( + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... pipes=[pipe1, pipe2, pipe3], + ... ) """ - # Extract parameters from context or kwargs + # Extract parameters from context or parameters/kwargs if context is not None: pipes = getattr(context, "pipes", None) num_workers = getattr(context, "num_workers", None) else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") + # Use the pipes parameter if provided, otherwise check kwargs + if pipes is None: + pipes = kwargs.get("pipes") if pipes is None: raise ValueError("pipes must be provided via context or kwargs") if num_workers is None: num_workers = len(pipes) if pipes else 0 - # Create sender and register all workers + # Store the mapping recipe for later use in synchronize_weights + # Don't compute params_map yet to save memory + # Note: We don't store context directly to avoid pickle issues - + # it's available via sender._context_ref + self._device_mapping_info = { + "model_id": model_id, + "weights": weights, + "model": model, + "params_map": params_map, + "devices": devices, + "device_map_fn": device_map_fn, + "num_workers": num_workers, + } + + # Create sender with the shared transport sender = MPWeightSender(self) sender._model_id = model_id if context is not None: @@ -114,6 +295,140 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Compute the params_map (dict[worker_idx, device_weights]) on-demand. + + This method creates device-specific weight copies based on the provided + configuration. It's called during synchronize_weights() rather than + init_on_sender() to reduce memory usage. + + The method supports several input patterns: + 1. Direct params_map: Returned as-is (already computed) + 2. 
Context + model_id: Extract model and devices from context + 3. Model/weights + devices: Create copies on specified devices + 4. Model/weights + device_map_fn: Apply custom mapping function + + Args: + context: Context object (e.g., collector) to extract model and devices from + model_id: Model identifier to resolve within context + weights: Pre-extracted weights as TensorDict + model: Model to extract weights from + params_map: Pre-computed mapping (returned as-is if provided) + devices: List of devices, one per worker + device_map_fn: Custom mapping function (worker_idx, weights) -> device_weights + num_workers: Number of workers (required with device_map_fn) + + Returns: + dict[int, TensorDictBase]: Mapping from worker_idx to device-specific weights + + Raises: + ValueError: If parameter combinations are invalid or mutually exclusive + """ + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures 
Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + @overload def init_on_receiver( self, @@ -328,6 +643,7 @@ class MPWeightSender(WeightSender): _transport: MPTransport | None _model_id: str + _scheme: MultiProcessWeightSyncScheme def send( self, @@ -438,36 +754,61 @@ def send_async( def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. - Extracts weights from the collector's policy and sends them to all workers - via pipes. This is called once after workers are initialized but before they - start collecting data. + Computes device-specific weight copies on-demand and sends them to workers + sequentially via pipes. This is called once after workers are initialized + but before they start collecting data. Unlike send(), this does not wait for acknowledgments since workers are still in their initialization phase. + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + Raises: - RuntimeError: If no context is available or context has no policy. + RuntimeError: If init_on_sender() was not called first. 
""" - # Get context (collector) - context = self._context_ref() if self._context_ref is not None else None - if context is None or not hasattr(context, "policy"): + # Get the device mapping info stored during init_on_sender + if not hasattr(self._scheme, "_device_mapping_info"): raise RuntimeError( - "MPWeightSender requires context with policy for synchronize_weights()" + "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" ) - # Extract and prepare weights from the policy - prepared_weights = self._scheme.prepare_weights( - weights=context.policy, - model_id=self._model_id, - strategy=self._strategy, + mapping_info = self._scheme._device_mapping_info + + # Get context from sender's weakref + context = self._context_ref() if self._context_ref is not None else None + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._scheme._get_params_map( context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, ) - # Send to all workers via pipes (no ACK - workers are still initializing) - for transport in self._iterate_transports(): + # Send to workers sequentially via pipes (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + worker_weights = params_map[i] if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights, model_id=self._model_id) # type: ignore[attr-defined] + transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] else: raise RuntimeError( f"Transport {type(transport)} does not support async send for synchronization" ) + + # Clean up the mapping info after synchronization + delattr(self._scheme, "_device_mapping_info") diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 09ebc333dee..52806416aa8 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -279,13 +279,15 @@ def _iterate_transports( if not self._transports: yield self._transport else: - yield from self._transports.values() + # Make sure transports are sorted + for k in sorted(self._transports.keys()): + yield self._transports[k] else: # Specific workers if isinstance(worker_ids, int): worker_ids = [worker_ids] for worker_id in worker_ids: - if worker_id in self._transports: + if worker_id in sorted(self._transports.keys()): yield self._transports[worker_id] else: raise ValueError(f"Worker {worker_id} not registered") From 8b9508fbeec18fe881c40e3fc2aaf33ce386eee6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 20:36:37 +0000 Subject: [PATCH 15/17] amend --- test/test_collector.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index b0350ec025e..35f6a99ad63 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3953,14 +3953,22 @@ def test_weight_update(self, weight_updater): policy_weights = 
TensorDict.from_module(policy) kwargs = {} if weight_updater == "scheme_shared": - kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} + scheme = SharedMemWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "scheme_pipe": - kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}} + scheme = MultiProcessWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "weight_updater": + scheme = None kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)} else: raise NotImplementedError + if scheme is not None: + scheme.init_on_sender( + model=policy_factory(), devices=[device] * 2, model_id="policy" + ) + collector = MultiSyncDataCollector( create_env_fn=[env_maker, env_maker], policy_factory=policy_factory, @@ -3973,6 +3981,8 @@ def test_weight_update(self, weight_updater): storing_device="cpu", **kwargs, ) + if weight_updater == "weight_updater": + assert collector._legacy_weight_updater # When using policy_factory, must pass weights explicitly collector.update_policy_weights_(policy_weights) From e973d9318766748f52af3784103acfe027f78393 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 17 Nov 2025 17:26:24 -1000 Subject: [PATCH 16/17] amend --- test/test_collector.py | 3 +- torchrl/collectors/_multi_base.py | 4 +- torchrl/collectors/_runner.py | 4 +- torchrl/weight_update/_mp.py | 473 +++++-------------- torchrl/weight_update/_noupdate.py | 2 +- torchrl/weight_update/_ray.py | 2 +- torchrl/weight_update/_shared.py | 87 +--- torchrl/weight_update/weight_sync_schemes.py | 111 ++++- 8 files changed, 236 insertions(+), 450 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 35f6a99ad63..04f2b27a24b 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3938,13 +3938,12 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: def all_worker_ids(self) -> list[int] | list[torch.device]: return list(range(self.num_workers)) - @pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.") @pytest.mark.skipif(not _has_gym, reason="requires gym") @pytest.mark.parametrize( "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] ) def test_weight_update(self, weight_updater): - device = "cuda:0" + device = "cuda:0" if torch.cuda.is_available() else "cpu" env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu") policy_factory = lambda: TensorDictModule( nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"] diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 01633823242..912ecfd3e6f 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -835,9 +835,9 @@ def _run_processes(self) -> None: # can be initialized here since all required resources exist if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): - if hasattr(scheme, "init_on_sender"): + if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) - self._weight_senders[model_id] = scheme.get_sender() + self._weight_senders[model_id] = scheme.get_sender() # Create a policy on the right device policy_factory = self.policy_factory diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index d6ab5ef4d76..63d1d0c2cd1 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -40,7 +40,9 @@ def 
_make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side weight_sync_scheme.init_on_receiver( - model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe + model=policy, + model_id="policy", + worker_idx=worker_idx, ) # Get the receiver and synchronize initial weights receiver = weight_sync_scheme.get_receiver() diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 91bb4261233..fc845fcdf64 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -5,37 +5,39 @@ from typing import Any, overload import torch -from tensordict import TensorDict, TensorDictBase -from torch import nn +from tensordict import TensorDictBase +from torch import multiprocessing as mp, nn +from torchrl.weight_update._shared import SharedMemWeightSyncScheme -from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, WeightSender, - WeightSyncScheme, ) -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. +class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): + """Weight synchronization for multiprocess operations using queues. - This scheme creates transports that communicate via multiprocessing pipes. - It follows a memory-efficient two-phase pattern similar to SharedMemWeightSyncScheme: + This scheme creates transports that communicate via multiprocessing queues. + Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place + updates, this scheme sends actual weight copies through queues to workers. + + It follows the same two-phase pattern as SharedMemWeightSyncScheme: 1. **init_on_sender()**: Stores the recipe for creating device-specific weights (model reference, devices, mapping functions) without creating actual copies 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, - sends them sequentially to workers via pipes, allowing garbage collection + sends them sequentially to workers via queues, allowing garbage collection between workers to minimize memory usage This approach avoids holding multiple weight copies in memory simultaneously, which is especially beneficial for large models with many workers. Synchronization flow: - - **init_on_sender()**: Store configuration and register worker pipes + - **init_on_sender()**: Store configuration and register worker queues - **synchronize_weights()**: Create and send initial weights on-demand - - **init_on_receiver()**: Create receiver that reads from pipe + - **init_on_receiver()**: Create receiver that reads from queue - **send()**: Extract and send weight updates, wait for acknowledgments Args: @@ -62,121 +64,17 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): is large. """ - def synchronize_weights(self): - """Send initial weights to all workers before collection starts. - - This method triggers the on-demand creation and distribution of device-specific - weight copies to workers. Unlike pre-computing all weights during init_on_sender(), - this approach creates each worker's weights sequentially, sends them via pipes, - and allows garbage collection before creating the next worker's weights. - - This is a convenience method that delegates to the sender's synchronize_weights(), - which handles the actual weight creation and distribution. - - Memory efficiency note: - If all workers share the same device, only one weight copy is created and - reused. 
If workers use different devices, weights are created and sent - sequentially to minimize peak memory usage. + def __init__(self, strategy: str = "tensordict"): + """Initialize the MultiProcessWeightSyncScheme. - Called automatically by: - - MultiSyncDataCollector during initialization - - MultiaSyncDataCollector during initialization - - Raises: - RuntimeError: If init_on_sender() was not called first + Args: + strategy: The weight transmission strategy (default: "tensordict"). """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on MultiProcessWeightSyncScheme" - ) - self._sender.synchronize_weights() - - @overload - def init_on_sender( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... + super().__init__(strategy) + # Override parent's shared transport - we don't use shared memory + self._shared_transport = None - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - model_id: str | None = None, - ) -> None: - ... - - def init_on_sender( + def _init_on_sender_impl( self, *, model_id: str | None = None, @@ -187,7 +85,6 @@ def init_on_sender( devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, - pipes: list[Any] | None = None, **kwargs, ) -> None: """Initialize on the main process (sender side). @@ -204,7 +101,6 @@ def init_on_sender( model_id: Identifier for the model being synchronized (e.g., "policy"). Required when using context. context: Optional context object (e.g., collector) providing: - - pipes: List of multiprocessing pipes for worker communication - num_workers: Number of worker processes - policy_device: List of devices for each worker When provided, model_id is used to resolve the model from context. @@ -213,16 +109,14 @@ def init_on_sender( model: Model to extract weights from. Mutually exclusive with weights and context. params_map: Pre-computed mapping of worker_idx to device-specific weights. - Most explicit option. When provided, all other parameters except pipes - must be None. + Most explicit option. When provided, all other parameters must be None. devices: List of devices for each worker. Used with weights or model to automatically create device-specific copies. Length must equal num_workers. device_map_fn: Custom function (worker_idx, weights) -> device_weights. 
Allows full control over device mapping. Requires num_workers. num_workers: Number of workers. Required with device_map_fn, inferred - from devices length or pipes otherwise. - pipes: List of multiprocessing pipes. Required unless provided via context. - **kwargs: Alternative way to provide pipes (for backward compatibility). + from devices length otherwise. + **kwargs: Reserved for future use. Examples: Simple usage with collector context (most common): @@ -243,7 +137,7 @@ def init_on_sender( >>> scheme.init_on_sender( ... weights=weights, ... devices=[torch.device("cpu"), torch.device("cuda:0")], - ... pipes=[pipe1, pipe2], + ... num_workers=2, ... ) Advanced: Pre-computed params_map: @@ -252,25 +146,23 @@ def init_on_sender( >>> weights_cuda = weights_cpu.to("cuda") >>> scheme.init_on_sender( ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, - ... pipes=[pipe1, pipe2, pipe3], + ... num_workers=3, ... ) """ - # Extract parameters from context or parameters/kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - # Use the pipes parameter if provided, otherwise check kwargs - if pipes is None: - pipes = kwargs.get("pipes") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 + # Get params_map from parent class logic + params_map_result = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) # Store the mapping recipe for later use in synchronize_weights - # Don't compute params_map yet to save memory + # Don't store params_map directly to save memory - we'll recompute on demand # Note: We don't store context directly to avoid pickle issues - # it's available via sender._context_ref self._device_mapping_info = { @@ -280,155 +172,37 @@ def init_on_sender( "params_map": params_map, "devices": devices, "device_map_fn": device_map_fn, - "num_workers": num_workers, + "num_workers": num_workers + if num_workers is not None + else len(params_map_result), } - # Create sender with the shared transport + # Create per-worker queues for weight distribution + # Each worker gets its own queue for receiving weights + all_workers = list(params_map_result.keys()) + if not hasattr(self, "_weight_init_queues"): + self._weight_init_queues = {} + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Create sender sender = MPWeightSender(self) sender._model_id = model_id if context is not None: sender._context_ref = weakref.ref(context) - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) + # Register workers with their queues + for worker_idx in all_workers: + queue = self._weight_init_queues[worker_idx] + # Create MPTransport for this worker + transport = MPTransport(weight_queue=queue, ack_queue=None) + sender._register_worker(worker_idx, transport) self._sender = sender self._initialized_on_sender = True - def _get_params_map( - self, - context: Any = None, - model_id: str | None = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - 
num_workers: int | None = None, - ): - """Compute the params_map (dict[worker_idx, device_weights]) on-demand. - - This method creates device-specific weight copies based on the provided - configuration. It's called during synchronize_weights() rather than - init_on_sender() to reduce memory usage. - - The method supports several input patterns: - 1. Direct params_map: Returned as-is (already computed) - 2. Context + model_id: Extract model and devices from context - 3. Model/weights + devices: Create copies on specified devices - 4. Model/weights + device_map_fn: Apply custom mapping function - - Args: - context: Context object (e.g., collector) to extract model and devices from - model_id: Model identifier to resolve within context - weights: Pre-extracted weights as TensorDict - model: Model to extract weights from - params_map: Pre-computed mapping (returned as-is if provided) - devices: List of devices, one per worker - device_map_fn: Custom mapping function (worker_idx, weights) -> device_weights - num_workers: Number of workers (required with device_map_fn) - - Returns: - dict[int, TensorDictBase]: Mapping from worker_idx to device-specific weights - - Raises: - ValueError: If parameter combinations are invalid or mutually exclusive - """ - if params_map is not None: - # Sanity check: params_map must be a dict[int, TensorDictBase] - # All other args must be None - if ( - not isinstance(params_map, dict) - or not all(isinstance(v, int) for v in params_map.keys()) - or not all(isinstance(v, TensorDictBase) for v in params_map.values()) - ): - raise ValueError("params_map must be a dict[int, TensorDictBase]") - if model_id is not None or weights is not None or model is not None: - raise ValueError( - "model_id, weights, and model cannot be provided if params_map is provided" - ) - if context is not None: - raise ValueError("context cannot be provided if params_map is provided") - if devices is not None: - raise ValueError("devices cannot be provided if params_map is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if params_map is provided" - ) - if num_workers is not None: - raise ValueError( - "num_workers cannot be provided if params_map is provided" - ) - return params_map - elif context is not None: - if devices is not None: - raise ValueError("devices cannot be provided if context is provided") - # Sanity check: model_id must be provided if context is provided - # All other args must be None - if model_id is None: - raise ValueError("model_id must be provided if context is provided") - if model is not None: - raise ValueError("model cannot be provided if context is provided") - if weights is not None: - raise ValueError("weights cannot be provided if context is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if context is provided" - ) - # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later - devices = context.policy_device - if num_workers is not None and num_workers != len(devices): - raise ValueError( - "num_workers cannot be provided if context is provided" - ) - # Get the weights - model = _resolve_model(context, model_id) - weights = TensorDict.from_module(model) - elif model is not None: - if weights is not None: - raise ValueError("weights cannot be provided if model is provided") - weights = TensorDict.from_module(model) - # To make the map, we need the list of devices, or the map fn - if devices is not None: 
- # Import _cast locally to avoid circular imports - from torchrl.collectors.utils import _cast - - # Get the unique devices - devices_set = set(devices) - weights_devices = {p.device for p in weights.values(True, True)} - if len(weights_devices) == 1: - weights_device = weights_devices.pop() - else: - weights_device = None - - # Create device map with proper Parameter handling using _cast - # _cast ensures Parameters stay as Parameters (with requires_grad=False) - device_map = {} - for d in devices_set: - if d != weights_device: - # Move to device and apply _cast to preserve Parameter/Buffer types - weights_on_device = weights.to(d) - weights_on_device = weights_on_device.apply(_cast, weights) - device_map[d] = weights_on_device - else: - # Already on correct device, just apply _cast - device_map[d] = weights.apply(_cast, weights) - - # Create the map - params_map = { - worker_idx: device_map[device] - for worker_idx, device in enumerate(devices) - } - return params_map - if device_map_fn is not None: - return { - worker_idx: device_map_fn(worker_idx, weights) - for worker_idx in range(num_workers) - } - raise ValueError( - "Either params_map, model_id + context or model/weights + devices must be provided." - ) - @overload def init_on_receiver( self, @@ -444,7 +218,7 @@ def init_on_receiver( model_id: str, context: None = None, *, - pipe: Any = ..., + worker_idx: int = ..., model: Any | None = None, **kwargs, ) -> None: @@ -460,69 +234,86 @@ def init_on_receiver( Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + context: Optional context object providing worker_idx and model + **kwargs: Alternative to context (worker_idx, model, etc.) """ # Extract parameters from context or kwargs if context is not None: - pipe = getattr(context, "pipe", None) + worker_idx = getattr(context, "worker_idx", None) if hasattr(context, "get_model"): model = context.get_model(model_id) else: model = None else: - pipe = kwargs.get("pipe") + worker_idx = kwargs.get("worker_idx") model = kwargs.get("model") - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") + if worker_idx is None: + raise ValueError("worker_idx must be provided via context or kwargs") + + # Get the queue for this worker + if worker_idx not in self._weight_init_queues: + raise ValueError( + f"Worker {worker_idx} not registered. init_on_sender() must be called first." + ) + + queue = self._weight_init_queues[worker_idx] # Create receiver and register model receiver = MPWeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) + + # Create transport with the worker's queue + transport = MPTransport(weight_queue=queue, ack_queue=None) + receiver._register_worker_transport(transport) + if model is not None: receiver._register_model(model) else: # Register by model_id for later resolution receiver._register_model(model_id) + # Store worker_idx for synchronize_weights + receiver._worker_idx = worker_idx + self._receiver = receiver self._initialized_on_worker = True - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe. + def create_transport(self, queue: Any) -> TransportBackend: + """Create an MPTransport using the provided queue. Note: This is used internally by init_on_sender/init_on_receiver. 
""" - return MPTransport(pipe) + return MPTransport(weight_queue=queue, ack_queue=None) class MPTransport: - """Multiprocessing transport using pipes. + """Multiprocessing transport using queues. - This transport uses pipes for weight distribution and synchronization. + This transport uses queues for weight distribution and synchronization. Similar to SharedMemTransport's queue-based approach, MPTransport uses - pipes to send initial weights to workers during synchronization. + queues to send initial weights to workers during synchronization. Initialization flow: - - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via pipes + - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via queues - Workers receive the initial weights via synchronize_weights_on_worker() - Subsequent updates use send_weights_async() followed by acknowledgments Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. + weight_queue (mp.Queue): The queue to use for sending weights. + ack_queue (mp.Queue): The queue to use for receiving acknowledgments. timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. """ - def __init__(self, pipe_connection, timeout: float = 10.0): + def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0): self.timeout = timeout - self.pipe = pipe_connection + self.weight_queue = weight_queue + self.ack_queue = ack_queue def send_weights(self, weights: Any) -> None: - """Send weights through the pipe. + """Send weights through the queue. Sends weights and waits for acknowledgment to ensure delivery. """ @@ -530,19 +321,20 @@ def send_weights(self, weights: Any) -> None: self.wait_ack() def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: - """Send weights through the pipe without waiting for acknowledgment. + """Send weights through the queue without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. """ # Send in format expected by worker loop: ((model_id, weights), "update_weights") - self.pipe.send(((model_id, weights), "update_weights")) + self.weight_queue.put(((model_id, weights), "update_weights")) def wait_ack(self) -> None: """Wait for acknowledgment from worker.""" - self.check_ack("updated") + if self.ack_queue is not None: + self.check_ack("updated") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). + """Receive weights from the queue (used in worker process). This method only handles weight update messages. Other messages (like "close", "continue", etc.) are ignored and should be handled @@ -556,34 +348,28 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: model_id is returned as "policy" for backward compatibility, but transports are now bound to a single model during initialization. """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - # data_in is now (model_id, weights) - return data_in - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. 
- return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None + data_in, msg = self.weight_queue.get(timeout=timeout) + if msg == "update_weights": + # data_in is now (model_id, weights) + return data_in + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) + if self.ack_queue is not None: + self.ack_queue.put((None, message)) def check_ack(self, message: str = "updated") -> None: """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + if self.ack_queue is not None: + _, msg = self.ack_queue.get(timeout=self.timeout) + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") def check_connection(self) -> bool: - return not self.pipe.closed + # Queues don't have a 'closed' attribute, so we assume they're always open + return True def synchronize_weights_on_sender(self) -> None: """No-op for MPTransport - weights are sent via MPWeightSender.synchronize_weights(). @@ -591,7 +377,7 @@ def synchronize_weights_on_sender(self) -> None: The actual sending happens in MPWeightSender.synchronize_weights(), which: 1. Extracts weights from the context (e.g., collector.policy) 2. Calls send_weights_async() on all worker transports - 3. Sends initial weights through pipes to all workers + 3. Sends initial weights through queues to all workers This is similar to SharedMemTransport.synchronize_weights_on_sender() which sends shared memory buffer references via queues. @@ -601,8 +387,8 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process - via pipe. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives - shared memory buffer references via queues, this receives the actual weights via pipes. + via queue. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + shared memory buffer references via queues, this receives the actual weights via queues. The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). @@ -613,20 +399,19 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: The received weights if available, None otherwise (weights will come later via receive()). """ # Wait for initial weights (blocking) - if self.pipe.poll(timeout=self.timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - # data_in is (model_id, weights), extract just the weights - _, weights = data_in - return weights - # If we don't receive weights, return None (weights will come later) - return None + data_in, msg = self.weight_queue.get(timeout=self.timeout) + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, weights = data_in + return weights + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") class MPWeightReceiver(WeightReceiver): - """Weight receiver for multiprocess systems using pipes. + """Weight receiver for multiprocess systems using queues. - Receives weight updates from the main process via multiprocessing pipes. + Receives weight updates from the main process via multiprocessing queues. 
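
    A rough worker-side sketch (illustrative only; ``policy`` and ``worker_idx``
    are assumed to exist, and ``init_on_sender()`` must already have been called
    in the main process so that the worker's queue is registered):

        >>> scheme.init_on_receiver(model_id="policy", model=policy, worker_idx=worker_idx)
        >>> scheme.synchronize_weights(worker_idx=worker_idx)  # blocks until the initial weights arrive
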
This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. """ @@ -634,9 +419,9 @@ class MPWeightReceiver(WeightReceiver): class MPWeightSender(WeightSender): - """Weight sender for multiprocess systems using pipes. + """Weight sender for multiprocess systems using queues. - Sends weight updates to worker processes via multiprocessing pipes. + Sends weight updates to worker processes via multiprocessing queues. Supports both synchronous and asynchronous sending patterns. This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. """ @@ -755,7 +540,7 @@ def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. Computes device-specific weight copies on-demand and sends them to workers - sequentially via pipes. This is called once after workers are initialized + sequentially via queues. This is called once after workers are initialized but before they start collecting data. Unlike send(), this does not wait for acknowledgments since workers are still @@ -799,7 +584,7 @@ def synchronize_weights(self) -> None: num_workers=num_workers, ) - # Send to workers sequentially via pipes (no ACK - workers are still initializing) + # Send to workers sequentially via queues (no ACK - workers are still initializing) # This allows GC to clean up each worker's weights before creating the next for i, transport in enumerate(self._iterate_transports()): worker_weights = params_map[i] diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 1f3ff01ea30..fbb90f8ff34 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -34,7 +34,7 @@ def init_on_sender( ) -> None: ... - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index b8d344a9df8..0dff3db7417 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -231,7 +231,7 @@ def init_on_sender( ) -> None: ... - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index b8e7e815917..d12292c95ba 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -165,92 +165,7 @@ def __init__( # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() - @overload - def init_on_sender( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... 
- - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - model_id: str | None = None, - ) -> None: - ... - - def init_on_sender( + def _init_on_sender_impl( self, *, model_id: str | None = None, diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 52806416aa8..13a11b7b24b 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -7,11 +7,12 @@ import abc import warnings import weakref -from collections.abc import Iterator -from typing import Any, Literal, Protocol +from collections.abc import Callable, Iterator +from typing import Any, Literal, overload, Protocol -from tensordict import TensorDict, TensorDictBase +import torch +from tensordict import TensorDict, TensorDictBase from torch import nn __all__ = [ @@ -641,28 +642,112 @@ def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict") self._initialized_on_sender = False self._initialized_on_worker = False + @overload def init_on_sender( self, + *, model_id: str, - context: Any = None, + context: Any, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *args, **kwargs, ) -> None: """Initialize on the main process (sender side). This method is called once in the collector's _run_processes() method, after workers have been started and are ready to receive messages. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., collector) providing: - - .pipes: list[mp.Connection] - - .get_model(model_id: str) -> nn.Module - - .get_cached_weights(model_id: str) -> TensorDict | None - - .num_workers: int - **kwargs: Alternative to context (pipes, num_workers, model, cached_weights, etc.) 
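
        A few illustrative call patterns matching the overloads above (sketches
        only; ``collector``, ``model``, ``my_map_fn`` and ``num_workers`` are
        assumed to exist):

            >>> scheme.init_on_sender(model_id="policy", context=collector)
            >>> scheme.init_on_sender(model=model, devices=[torch.device("cpu")] * num_workers)
            >>> scheme.init_on_sender(weights=TensorDict.from_module(model).data,
            ...                       device_map_fn=my_map_fn, num_workers=num_workers)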
""" + result = self._init_on_sender_impl(*args, **kwargs) + self._initialized_on_sender = True + return result + + def _init_on_sender_impl(self, *args, **kwargs): raise NotImplementedError + @property + def initialized_on_sender(self): + return getattr(self, "_initialized_on_sender", False) + def init_on_receiver( self, model_id: str, From b60d39fb31b85d957a340cc48168073849fad55c Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 25 Nov 2025 16:54:50 +0000 Subject: [PATCH 17/17] intermediate-fix --- .../benchmark_sample_latency_over_rpc.py | 2 +- .../distributed_replay_buffer.py | 2 +- test/test_distributed.py | 67 +-- test/test_weightsync.py | 4 +- torchrl/_utils.py | 6 +- torchrl/collectors/__init__.py | 3 +- torchrl/collectors/{base.py => _base.py} | 118 ++++- torchrl/collectors/_multi_async.py | 3 + torchrl/collectors/_multi_base.py | 38 +- torchrl/collectors/_multi_sync.py | 3 + torchrl/collectors/_runner.py | 69 +-- torchrl/collectors/_single.py | 19 +- torchrl/collectors/collectors.py | 3 +- torchrl/collectors/distributed/generic.py | 390 ++++++++-------- torchrl/collectors/distributed/ray.py | 8 +- torchrl/collectors/distributed/rpc.py | 241 +++++----- torchrl/collectors/distributed/sync.py | 85 ++-- torchrl/collectors/distributed/utils.py | 6 +- torchrl/collectors/utils.py | 45 +- torchrl/weight_update/_distributed.py | 244 ++++++---- torchrl/weight_update/_mp.py | 417 +++++++++--------- torchrl/weight_update/_noupdate.py | 43 +- torchrl/weight_update/_ray.py | 388 ++++++++-------- torchrl/weight_update/_rpc.py | 265 ++++++++--- torchrl/weight_update/_shared.py | 90 ++-- .../weight_update/llm/vllm_double_buffer.py | 6 +- torchrl/weight_update/llm/vllm_nccl.py | 4 +- torchrl/weight_update/weight_sync_schemes.py | 186 +++++++- 28 files changed, 1572 insertions(+), 1183 deletions(-) rename torchrl/collectors/{base.py => _base.py} (80%) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 4af76440290..bf92deb1284 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -144,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}") + torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index f92f78de7e1..df522443c06 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -172,7 +172,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - torchrl_logger.info(f"Rank: {rank}") + torchrl_logger.debug(f"RANK: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/test/test_distributed.py b/test/test_distributed.py index 6183132394e..761a7652d79 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -10,35 +10,21 @@ import abc import argparse +import importlib import os +import socket import sys import time from functools import partial import pytest -from tensordict import TensorDict -from tensordict.nn import TensorDictModuleBase -from torchrl._utils import logger as torchrl_logger -from torchrl.data 
import ( - LazyTensorStorage, - RandomSampler, - RayReplayBuffer, - RoundRobinWriter, - SamplerWithoutReplacement, -) - -try: - import ray - - _has_ray = True - RAY_ERR = None -except ModuleNotFoundError as err: - _has_ray = False - RAY_ERR = err import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModuleBase from torch import multiprocessing as mp, nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import ( MultiaSyncDataCollector, @@ -52,8 +38,17 @@ RPCDataCollector, ) from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG +from torchrl.data import ( + LazyTensorStorage, + RandomSampler, + RayReplayBuffer, + RoundRobinWriter, + SamplerWithoutReplacement, +) from torchrl.envs.utils import RandomPolicy +_has_ray = importlib.util.find_spec("ray") is not None + if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv else: @@ -115,7 +110,6 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - torchrl_logger.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch @@ -289,7 +283,9 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): n_collectors = 1 else: n_collectors = 2 - collector = cls.distributed_class()( + dcls = cls.distributed_class() + torchrl_logger.info(f"Using distributed collector {dcls}") + collector = dcls( [env] * n_collectors, policy, collector_class=collector_class, @@ -307,6 +303,7 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): if i == 0: first_batch = data policy.weight.data += 1 + torchrl_logger.info("TEST -- Calling update_policy_weights_()") collector.update_policy_weights_() elif total == total_frames - frames_per_batch: last_batch = data @@ -338,7 +335,8 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): proc.start() try: out = queue.get(timeout=TIMEOUT) - assert out == "passed" + if out != "passed": + raise AssertionError(out) finally: proc.join(10) if proc.is_alive(): @@ -353,7 +351,13 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + # Pick an ephemeral free TCP port on localhost for each test process to + # avoid address-in-use errors when tests are run repeatedly or in quick + # succession. 
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -367,7 +371,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -381,7 +388,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -459,7 +469,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv queue.close() -@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})") +@pytest.mark.skipif( + not _has_ray, reason="Ray not found. Ray may be badly configured or not installed." +) class TestRayCollector(DistributedCollectorBase): """A testing distributed data collector class that runs tests without using a Queue, to avoid potential deadlocks when combining Ray and multiprocessing. @@ -467,6 +479,7 @@ class TestRayCollector(DistributedCollectorBase): @pytest.fixture(autouse=True, scope="class") def start_ray(self): + import ray from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG ray.init(**DEFAULT_RAY_INIT_CONFIG) @@ -480,6 +493,8 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: + import ray + ray.shutdown() # make sure ray is not running ray_init_config = DEFAULT_RAY_INIT_CONFIG ray_init_config["runtime_env"] = { diff --git a/test/test_weightsync.py b/test/test_weightsync.py index b75186c4afe..2e0a8fc0dfc 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -638,7 +638,7 @@ def test_multiprocess_scheme_serialize_before_init(self): assert restored._sender is None assert restored._receiver is None assert not restored._initialized_on_sender - assert not restored._initialized_on_worker + assert not restored._initialized_on_receiver def test_multiprocess_scheme_serialize_after_sender_init(self): """Test that initialized sender can be pickled (excluding runtime state).""" @@ -660,7 +660,7 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): assert restored._sender is None # Runtime state excluded assert restored._receiver is None assert not restored._initialized_on_sender # Reset - assert not restored._initialized_on_worker + assert not restored._initialized_on_receiver # Clean up parent_pipe.close() diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 50cfa8af7d3..2b2fe5f622e 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -52,7 +52,7 @@ def strtobool(val: Any) -> bool: LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") -logger.setLevel(getattr(logging, LOGGING_LEVEL)) +logger.setLevel(LOGGING_LEVEL) logger.propagate = False # Clear existing handlers while logger.hasHandlers(): @@ -85,7 +85,9 @@ def format(self, record): console_handler = logging.StreamHandler(stream=stream_handler) console_handler.setFormatter(_CustomFormatter()) logger.addHandler(console_handler) -console_handler.setLevel(logging.INFO) + 
+console_handler.setLevel(LOGGING_LEVEL) +logger.debug(f"Logging level: {logger.getEffectiveLevel()}") VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 5e2ef63fb69..98b44cc39ec 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -5,12 +5,13 @@ from torchrl.envs.utils import RandomPolicy +from ._base import DataCollectorBase + from ._multi_async import MultiaSyncDataCollector from ._multi_sync import MultiSyncDataCollector from ._single import SyncDataCollector from ._single_async import aSyncDataCollector -from .base import DataCollectorBase from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, diff --git a/torchrl/collectors/base.py b/torchrl/collectors/_base.py similarity index 80% rename from torchrl/collectors/base.py rename to torchrl/collectors/_base.py index 1ad97d4056f..d94d5ac4bca 100644 --- a/torchrl/collectors/base.py +++ b/torchrl/collectors/_base.py @@ -16,10 +16,11 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase from torch import nn as nn from torch.utils.data import IterableDataset +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.utils import _map_weight from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.weight_update import WeightReceiver, WeightSender, WeightSyncScheme +from torchrl.weight_update import WeightSyncScheme class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): @@ -35,8 +36,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): cudagraphed_policy: bool _weight_updater: WeightUpdaterBase | None = None _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None verbose: bool = False @property @@ -320,40 +319,81 @@ def _weight_update_impl( if policy_or_weights is not None: weights_dict = {"policy": policy_or_weights} - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: + if self._weight_sync_schemes: + if model_id is None: + model_id = "policy" + if weights_dict is None: # Compose weight_dict weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: + if target_model_id not in self._weight_sync_schemes: raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. " - f"Available models: {list(self._weight_senders.keys())}" + f"Model '{target_model_id}' not found in registered weight sync schemes. 
" + f"Available models: {list(self._weight_sync_schemes.keys())}" ) processed_weights = self._extract_weights_if_needed( weights, target_model_id ) # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids + torchrl_logger.debug("weight update -- getting scheme") + scheme = self._weight_sync_schemes.get(target_model_id) + if not isinstance(scheme, WeightSyncScheme): + raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}") + torchrl_logger.debug( + f"calling send() on scheme {type(scheme).__name__}" ) + scheme.send(weights=processed_weights, worker_ids=worker_ids) elif self._weight_updater is not None: # unreachable raise RuntimeError else: return self.receive_weights(policy_or_weights) + def _receive_weights_scheme(self): + """Receive weights via registered receiver schemes and cascade to nested collectors. + + This method enables cascading weight updates across multiple collector layers: + - RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector + - DistributedDataCollector -> MultiSyncDataCollector -> SyncDataCollector + + Process: + 1. Receive weights for all registered receiver schemes (_receiver_schemes) + 2. If this collector has nested collectors (_weight_sync_schemes), propagate + the updates by calling update_policy_weights_() + + """ + # Receive weights for all registered schemes + updates = {} + if not hasattr(self, "_receiver_schemes"): + raise RuntimeError("No receiver schemes registered.") + + for model_id, scheme in self._receiver_schemes.items(): + # scheme.receive() pulls weights from the transport and applies them locally + # For RPC/Ray: weights are already passed as argument, receive() is a no-op + # For Distributed: receive() pulls from TCPStore + # For MultiProcess: receive() checks the pipe + received_weights = scheme.receive() + if received_weights is not None: + updates[model_id] = received_weights + + # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers) + # AND we actually received updates, propagate them down via their senders + if ( + updates + and hasattr(self, "_weight_sync_schemes") + and self._weight_sync_schemes + ): + # Build weights_dict for all models that need propagation to nested collectors + weights_dict = {} + for model_id in updates: + if model_id in self._weight_sync_schemes: + # This model has a sender scheme - propagate to nested workers + weights_dict[model_id] = updates[model_id] + + if weights_dict: + # Propagate to nested collectors via their sender schemes + self.update_policy_weights_(weights_dict=weights_dict) + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): # No weight updater configured # For single-process collectors, apply weights locally if explicitly provided @@ -389,6 +429,42 @@ def receive_weights(self, policy_or_weights: TensorDictBase | None = None): strategy.apply_weights(self.policy, weights) # Otherwise, no action needed - policy is local and changes are immediately visible + def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]): + """Set up receiver schemes for this collector. + + This method initializes receiver schemes and stores them in _receiver_schemes + for later use by _receive_weights_scheme() and receive_weights(). 
+ + Args: + weight_sync_schemes: Dictionary of {model_id: WeightSyncScheme} to set up as receivers + """ + # Initialize _receiver_schemes if not already present + if not hasattr(self, "_receiver_schemes"): + self._receiver_schemes = {} + + # Initialize each scheme on the receiver side + for model_id, scheme in weight_sync_schemes.items(): + if not scheme.initialized_on_receiver: + if scheme.initialized_on_sender: + raise RuntimeError( + "Weight sync scheme cannot be initialized on both sender and receiver." + ) + scheme.init_on_receiver( + model_id=model_id, + context=self, + worker_idx=getattr(self, "_worker_idx", None), + ) + + # Store the scheme for later use in receive_weights() + self._receiver_schemes[model_id] = scheme + + # Perform initial synchronization + for scheme in weight_sync_schemes.values(): + if not scheme.synchronized_on_receiver: + scheme.synchronize_weights( + worker_idx=getattr(self, "_worker_idx", None) + ) + def __iter__(self) -> Iterator[TensorDictBase]: try: yield from self.iterator() diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py index 6e9b3a55f7b..fb6126c6c5f 100644 --- a/torchrl/collectors/_multi_async.py +++ b/torchrl/collectors/_multi_async.py @@ -293,3 +293,6 @@ def reset(self, reset_idx: Sequence[bool] | None = None) -> None: self.pipes[idx].send((idx, "continue_random")) else: self.pipes[idx].send((idx, "continue")) + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 912ecfd3e6f..244f8b41e46 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -16,6 +16,7 @@ from torch import multiprocessing as mp, nn from torchrl import logger as torchrl_logger from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( _InterruptorManager, _is_osx, @@ -25,7 +26,6 @@ ) from torchrl.collectors._runner import _main_async_collector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -37,6 +37,7 @@ SharedMemWeightSyncScheme, WeightSyncScheme, ) +from torchrl.weight_update.utils import _resolve_model class _MultiDataCollector(DataCollectorBase): @@ -357,8 +358,8 @@ def __init__( self.policy = policy self.policy_factory = policy_factory - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + # # Set up fallback policy for weight extraction + # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) # Set up total frames and other parameters self._setup_multi_total_frames( @@ -518,7 +519,7 @@ def _setup_multi_policy_and_weights( if weight_sync_policy is None: return if any(p is not None for p in policy_factory): - if not weight_sync_policy._initialized_on_sender: + if not weight_sync_policy.initialized_on_sender: raise RuntimeError( f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" ) @@ -574,6 +575,7 @@ def _setup_multi_policy_and_weights_legacy( # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default if weight_sync_schemes is None: weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + self._weight_sync_schemes = weight_sync_schemes elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " @@ -593,14 +595,12 @@ def _setup_multi_weight_sync( if weight_sync_schemes is not None: # Use weight sync schemes for weight distribution self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} # Senders will be created in _run_processes self.weight_updater = None else: # Use weight updater for weight distribution self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} def _setup_multi_policy_version_tracking( self, track_policy_version: bool | PolicyVersion @@ -621,6 +621,7 @@ def _setup_multi_policy_version_tracking( ) self.policy_version_tracker = None + # TODO: Remove this def _setup_fallback_policy( self, policy: TensorDictModule | Callable | None, @@ -837,7 +838,6 @@ def _run_processes(self) -> None: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) - self._weight_senders[model_id] = scheme.get_sender() # Create a policy on the right device policy_factory = self.policy_factory @@ -972,9 +972,15 @@ def _run_processes(self) -> None: # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: # Workers will call receiver.synchronize_weights() during init and may block waiting for data - if self._weight_senders: - for sender in self._weight_senders.values(): - sender.synchronize_weights() + if self._weight_sync_schemes: + # start with policy + policy_scheme = self._weight_sync_schemes.get("policy") + if policy_scheme is not None: + policy_scheme.synchronize_weights() + for key, scheme in self._weight_sync_schemes.items(): + if key == "policy": + continue + scheme.synchronize_weights() # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): @@ -1414,18 +1420,15 @@ def get_model(self, model_id: str): """ if model_id == "policy": # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy + if (fallback_policy := getattr(self, "_fallback_policy", None)) is not None: + return fallback_policy elif hasattr(self, "policy") and self.policy is not None: return self.policy else: raise ValueError(f"No policy found for model_id '{model_id}'") else: # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") + return _resolve_model(self, model_id) def get_cached_weights(self, model_id: str): """Get cached shared memory weights if available (for weight sync schemes). 
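
A usage-level sketch of how the pieces above fit together (illustrative only; the
import path, ``make_env`` and ``policy`` are assumptions, not part of this patch):

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    collector = MultiSyncDataCollector(
        [make_env, make_env],
        policy,
        frames_per_batch=64,
        total_frames=256,
        weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
    )
    for data in collector:
        ...  # optimization step on `policy`
        collector.update_policy_weights_()  # routed through scheme.send() to all workers
    collector.shutdown()
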
@@ -1445,3 +1448,6 @@ def get_cached_weights(self, model_id: str): # Return cached weights for this device return self._policy_weights_dict.get(policy_device) return None + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py index 3f475673a30..9fd5d24c1f2 100644 --- a/torchrl/collectors/_multi_sync.py +++ b/torchrl/collectors/_multi_sync.py @@ -428,3 +428,6 @@ def iterator(self) -> Iterator[TensorDictBase]: self.out_buffer = None # We shall not call shutdown just yet as user may want to retrieve state_dict # self._shutdown_main() + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 63d1d0c2cd1..eec9b6dba87 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -13,6 +13,7 @@ from torchrl import logger as torchrl_logger from torchrl._utils import VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( _MAX_IDLE_COUNT, _MIN_TIMEOUT, @@ -20,36 +21,19 @@ DEFAULT_EXPLORATION_TYPE, ) from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase -from torchrl.collectors.utils import _cast, _map_to_cpu_if_needed, _TrajectoryPool +from torchrl.collectors.utils import ( + _cast, + _make_policy_factory, + _map_to_cpu_if_needed, + _TrajectoryPool, +) from torchrl.data import ReplayBuffer from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType from torchrl.weight_update import WeightSyncScheme -def _make_policy_factory( - *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None -): - if policy is not None and policy_factory is not None: - raise ValueError("policy cannot be used with policy_factory") - elif policy_factory is not None: - policy = policy_factory() - - if weight_sync_scheme is not None: - # Initialize the receiver on the worker side - weight_sync_scheme.init_on_receiver( - model=policy, - model_id="policy", - worker_idx=worker_idx, - ) - # Get the receiver and synchronize initial weights - receiver = weight_sync_scheme.get_receiver() - receiver.synchronize_weights(worker_idx=worker_idx) - return policy - - def _main_async_collector( pipe_parent: connection.Connection, pipe_child: connection.Connection, @@ -130,31 +114,18 @@ def _main_async_collector( compile_policy=compile_policy, cudagraph_policy=cudagraph_policy, no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, + # We don't pass the weight sync scheme as only the sender has the weight sync scheme within. 
+ # weight_sync_schemes=weight_sync_schemes, + worker_idx=worker_idx, ) # Set up weight receivers for worker process # Note: For the "policy" model, initialization is done in _make_policy_factory # This section only handles additional models (not "policy") if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - inner_collector.worker_idx = ( - worker_idx # Add worker index for queue-based schemes - ) - for model_id, scheme in weight_sync_schemes.items(): - if model_id == "policy": - # Policy receiver was already initialized in _make_policy_factory - receiver = scheme.get_receiver() - inner_collector._weight_receivers[model_id] = receiver - else: - # Initialize receivers for other models + if not scheme.initialized_on_receiver: scheme.init_on_receiver(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - receiver.synchronize_weights(worker_idx=worker_idx) - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} + scheme.synchronize_weights() use_buffers = inner_collector._use_buffers if verbose: @@ -256,6 +227,7 @@ def _main_async_collector( # to allow falling through from update_weights to continue if msg == "update": + # Legacy - weight updater torchrl_logger.info(f"worker {idx} updating the params...") inner_collector.update_policy_weights_(policy_weights=data_in) pipe_child.send((j, "updated")) @@ -317,7 +289,7 @@ def _main_async_collector( continue if msg == "update_weights": - # New weight update protocol for simplified weight sync system + # weight update protocol with schemes if verbose: torchrl_logger.info( f"worker {idx} received weight update via new protocol" @@ -325,15 +297,10 @@ def _main_async_collector( model_id, weights = data_in # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) + scheme = inner_collector._weight_sync_schemes.get(model_id) + if scheme is None: + raise KeyError(f"Model '{model_id}' not registered") + scheme.apply_weights(weights) # After applying weights, we continue collecting immediately as if we received # a "continue" message. 
This ensures the worker keeps collecting data without diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 7beda2deb63..13cbd544537 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -22,12 +22,12 @@ prod, RL_WARNINGS, ) +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( cudagraph_mark_step_begin, DEFAULT_EXPLORATION_TYPE, ExplorationType, ) -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _TrajectoryPool, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -41,6 +41,7 @@ set_exploration_type, ) from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.utils import _resolve_model @accept_remote_rref_udf_invocation @@ -311,9 +312,11 @@ def __init__( | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, + worker_idx: int | None = None, **kwargs, ): self.closed = True + self._worker_idx = worker_idx # Initialize environment env = self._init_env(create_env_fn, create_env_kwargs) @@ -791,7 +794,6 @@ def _setup_weight_sync( if weight_sync_schemes is not None: # Use new simplified weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} # For single-process collectors, we don't need senders/receivers # The policy is local and changes are immediately visible # Senders will be set up in multiprocess collectors during _run_processes @@ -813,12 +815,10 @@ def _setup_weight_sync( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} else: # No weight sync needed for single-process collectors self.weight_updater = None self._weight_sync_schemes = None - self._weight_senders = {} @property def _traj_pool(self): @@ -1545,7 +1545,7 @@ def rollout(self) -> TensorDictBase: break else: if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") + torchrl_logger.debug("Returning final rollout within buffer.") result = self._final_rollout try: result = torch.stack( @@ -1792,8 +1792,7 @@ def get_model(self, model_id: str): else: raise ValueError(f"No policy found for model_id '{model_id}'") else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") + return _resolve_model(self, model_id) + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d0f1c1f765a..5af173a40c4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -5,6 +5,8 @@ """Re-exports of collector classes for backward compatibility.""" from __future__ import annotations +from torchrl.collectors._base import DataCollectorBase + # Re-export constants for backward compatibility from torchrl.collectors._constants import ( _Interruptor, @@ -24,7 +26,6 @@ from torchrl.collectors._runner import _main_async_collector from torchrl.collectors._single import SyncDataCollector from torchrl.collectors._single_async import aSyncDataCollector -from torchrl.collectors.base import DataCollectorBase __all__ = [ "MultiSyncDataCollector", diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 61180a3cb21..58359b8de95 100644 --- 
a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,21 +20,22 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, TCP_PORT, ) -from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.utils import _cast, _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update import DistributedWeightSyncScheme from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None @@ -52,11 +53,11 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." ) - torchrl_logger.info( - f"node with rank {rank} with world_size {world_size} -- launching distributed" + torchrl_logger.debug( + f"RANK {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( backend, @@ -66,7 +67,7 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"Connected!\nNode with rank {rank} -- creating store") + torchrl_logger.debug(f"Connected!\nRANK {rank} -- creating store") # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, @@ -106,19 +107,20 @@ def _distributed_init_delayed( frames_per_batch = output["frames_per_batch"] collector_kwargs = output["collector_kwargs"] _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + frames_per_batch=frames_per_batch, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -132,24 +134,27 @@ def _distributed_init_collection_node( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes, verbose=True, ): _store = _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose) _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - policy_factory, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + policy_factory=policy_factory, + frames_per_batch=frames_per_batch, + weight_sync_schemes=weight_sync_schemes, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _run_collector( + *, _store, sync, collector_class, @@ -159,12 +164,13 @@ def _run_collector( policy_factory, 
frames_per_batch, collector_kwargs, + weight_sync_schemes: dict[str, DistributedWeightSyncScheme], verbose=True, ): rank = torch.distributed.get_rank() if verbose: - torchrl_logger.info( - f"node with rank {rank} -- creating collector of type {collector_class}" + torchrl_logger.debug( + f"RANK {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers @@ -177,7 +183,7 @@ def _run_collector( if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() + policy_weights = policy_weights.data.apply(_cast, policy_weights).lock_() else: if collector_kwargs.get("weight_updater") is None and ( policy_factory is None @@ -186,50 +192,113 @@ def _run_collector( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + torchrl_logger.debug(f"RANK {rank} -- init collector") collector = collector_class( env_make, - policy, + policy=policy, policy_factory=policy_factory, frames_per_batch=frames_per_batch, total_frames=-1, split_trajs=False, **collector_kwargs, ) + + if weight_sync_schemes is not None: + for model_id, scheme in weight_sync_schemes.items(): + torchrl_logger.debug(f"RANK {rank} -- init receiver for model '{model_id}'") + # Provide both collector context and distributed store / rank so the + # scheme can wire its transport correctly. + scheme.init_on_receiver( + model_id=model_id, + context=collector, + store=_store, + rank=rank, + ) + torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") + scheme.synchronize_weights() + torchrl_logger.debug( + f"RANK {rank} -- initial weight sync for '{model_id}' completed" + ) + else: + torchrl_logger.debug( + f"RANK {rank} -- {collector_class.__name__} without weight_sync_schemes \n\n" + ) + total_frames = 0 - if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") while True: + if verbose: + torchrl_logger.debug(f"RANK {rank} -- waiting for instructions") instruction = _store.get(f"NODE_{rank}_in") if verbose: - torchrl_logger.info( - f"node with rank {rank} -- new instruction: {instruction}" - ) + torchrl_logger.debug(f"RANK {rank} -- new instruction: {instruction}") _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - torchrl_logger.info(f"node with rank {rank} -- new data") + torchrl_logger.debug(f"RANK {rank} -- collecting new data") data = collector.next() total_frames += data.numel() if verbose: - torchrl_logger.info(f"got data, total frames = {total_frames}") - torchrl_logger.info(f"node with rank {rank} -- sending {data}") + torchrl_logger.debug( + f"RANK {rank} -- got data, total frames = {total_frames}" + ) + torchrl_logger.debug( + f"RANK {rank} -- data batch_size={data.batch_size}, " + f"keys={list(data.keys(False, True))}" + ) + torchrl_logger.debug( + f"RANK {rank} -- sending TensorDict payload to rank 0" + ) + torchrl_logger.debug(f"RANK {rank} -- {data=}") + if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - torchrl_logger.info(f"node with rank {rank} -- setting to 'done'") + torchrl_logger.debug(f"RANK {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") + if verbose: + torchrl_logger.debug(f"RANK {rank} -- set to 'done'") + elif instruction == b"shutdown": if verbose: - torchrl_logger.info(f"node with rank {rank} -- shutting down") + torchrl_logger.debug(f"RANK {rank} -- shutting down") try: 
collector.shutdown() except Exception: pass _store.set(f"NODE_{rank}_out", b"down") break + elif instruction == b"update_weights": + if verbose: + torchrl_logger.debug(f"RANK {rank} -- updating weights") + + if weight_sync_schemes is not None: + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- using weight sync schemes for update" + ) + # Receive fresh weights from the main process for each model + for model_id, scheme in weight_sync_schemes.items(): + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- receiving weights for model '{model_id}'" + ) + scheme.receive() + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- received weights for model '{model_id}'" + ) + + # Propagate updated weights to inner workers via the nested + # collector's own weight sync schemes. + collector.update_policy_weights_() + + # Acknowledgment is handled by the transport (send_ack in the + # WeightReceiver), so we can continue without touching the + # TCPStore here. + continue if sync: policy_weights.recv(0) else: @@ -463,6 +532,9 @@ def __init__( weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -562,11 +634,6 @@ def __init__( self.backend = backend - # os.environ['TP_SOCKET_IFNAME'] = 'lo' - - self._init_workers() - self._make_container() - # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors @@ -577,37 +644,12 @@ def __init__( } if weight_sync_schemes is not None: + torchrl_logger.debug("RANK 0 -- Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - rank = i + 1 # Workers are 1-indexed in distributed - transport = scheme.create_transport((self._store, rank)) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: + torchrl_logger.debug("RANK 0 -- Using weight updater") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = DistributedWeightUpdater( @@ -618,7 +660,17 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init_workers() + if self._weight_sync_schemes is not None: + # Initialize schemes on the sender (main process) side now that + # worker processes and the store have been created. 
+ for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + num_workers=self.num_workers, context=self, model_id=model_id + ) + + self._make_container() @property def device(self) -> list[torch.device]: @@ -685,11 +737,10 @@ def _init_master_dist( world_size, backend, ): - if self._VERBOSE: - torchrl_logger.info( - f"launching main node with tcp port '{self.tcp_port}' and " - f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." - ) + torchrl_logger.debug( + f"RANK 0 -- launching main node with tcp port '{self.tcp_port}' and " + f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." + ) os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -701,8 +752,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - if self._VERBOSE: - torchrl_logger.info("main initiated! Launching store...") + torchrl_logger.debug("RANK 0 -- main initiated! Launching store...") self._store = torch.distributed.TCPStore( host_name=self.IPAddr, port=int(TCP_PORT) + 1, @@ -710,15 +760,20 @@ def _init_master_dist( is_master=True, timeout=timedelta(10), ) - if self._VERBOSE: - torchrl_logger.info("done. Setting status to 'alive'") + torchrl_logger.debug("RANK 0 -- done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): - if self._VERBOSE: - torchrl_logger.info("making container") + torchrl_logger.debug("RANK 0 -- making container") env_constructor = self.env_constructors[0] - kwargs = self.collector_kwargs[0] + kwargs = self.collector_kwargs[ + 0 + ].copy() # Create a copy to avoid modifying the original + # Mirror the SyncDataCollector configuration used on the workers so + # that the dummy batch structure matches what remote ranks will send. + # _run_collector always sets return_same_td=True for SyncDataCollector, + # so we must do the same here to ensure structural consistency. 
+ kwargs["return_same_td"] = True pseudo_collector = SyncDataCollector( env_constructor, policy=self.policy, @@ -730,12 +785,15 @@ def _make_container(self): ) for _data in pseudo_collector: break - if self._VERBOSE: - torchrl_logger.info(f"got data {_data}") - torchrl_logger.info("expanding...") - self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) - if self._VERBOSE: - torchrl_logger.info("locking") + torchrl_logger.debug(f"RANK 0 -- got dummy batch: {_data}") + torchrl_logger.debug("RANK 0 -- expanding...") + self._tensordict_out = ( + _data.expand((self.num_workers, *_data.shape)).clone().to_lazystack(0) + ) + torchrl_logger.debug( + f"RANK 0 -- expanded recv buffer spec: {self._tensordict_out}" + ) + torchrl_logger.debug("RANK 0 -- locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -745,12 +803,10 @@ def _make_container(self): self._tensordict_out = self._tensordict_out.unbind(0) for td in self._tensordict_out: td.lock_() - if self._VERBOSE: - torchrl_logger.info("storage created:") - torchrl_logger.info("shutting down...") + torchrl_logger.debug("RANK 0 -- storage created:") + torchrl_logger.debug("RANK 0 -- shutting down...") pseudo_collector.shutdown() - if self._VERBOSE: - torchrl_logger.info("dummy collector shut down!") + torchrl_logger.debug("RANK 0 -- dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -760,20 +816,21 @@ def _init_worker_dist_submitit(self, executor, i): TCP_PORT = self.tcp_port job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + weight_sync_schemes=self._weight_sync_schemes, + collector_kwargs=self.collector_kwargs[i], + verbose=self._VERBOSE, ) return job @@ -808,21 +865,22 @@ def _init_worker_dist_mp(self, i): TCP_PORT = self.tcp_port job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + weight_sync_schemes=self._weight_sync_schemes, + verbose=self._VERBOSE, ), ) job.start() @@ -835,8 +893,7 @@ def _init_workers(self): IPAddr = socket.gethostbyname(hostname) else: IPAddr = "localhost" - if self._VERBOSE: - torchrl_logger.info(f"Server IP address: 
{IPAddr}") + torchrl_logger.debug(f"RANK 0 -- Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -851,21 +908,20 @@ def _init_workers(self): self._init_worker_dist_submitit_delayed() else: for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info("Submitting job") + torchrl_logger.debug("RANK 0 -- Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug( + f"RANK 0 -- job id {job.job_id}" + ) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - if self._VERBOSE: - torchrl_logger.info("job launched") + torchrl_logger.debug("RANK 0 -- job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -873,21 +929,21 @@ def iterator(self): yield from self._iterator_dist() def _iterator_dist(self): - if self._VERBOSE: - torchrl_logger.info("iterating...") + torchrl_logger.debug("RANK 0 -- iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=}") trackers.append( self._tensordict_out[i].irecv(src=rank, return_premature=True) ) + torchrl_logger.debug(f"RANK 0 -- trackers: {trackers}") while total_frames < self.total_frames: if self._sync: @@ -908,19 +964,22 @@ def _iterator_dist(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): + torchrl_logger.debug(f"RANK 0 -- updating weights for {rank=}") self.update_policy_weights_( policy_weights=None, worker_ids=rank ) for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down rank {rank}.") + torchrl_logger.debug(f"RANK 0 -- shutting down rank {rank}.") self._store.set(f"NODE_{rank}_in", b"shutdown") def _next_sync(self, total_frames): # in the 'sync' case we should update before collecting the data if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {total_frames=} in _next_sync." 
+ ) self.update_policy_weights_() else: for j in range(self.num_workers): @@ -928,12 +987,12 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_sync.") trackers.append( self._tensordict_out_unbind[i].irecv(src=rank, return_premature=True) ) @@ -954,16 +1013,21 @@ def _next_async(self, total_frames, trackers): while data is None: for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- checking {rank=} in _next_async.") if self._store.get(f"NODE_{rank}_status") == b"done": + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_async.") for _tracker in trackers[i]: _tracker.wait() + torchrl_logger.debug(f"RANK 0 -- received {rank=} in _next_async.") data = self._tensordict_out[i].clone() if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {rank=} in _next_async." + ) self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() if total_frames < self.total_frames: - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -973,34 +1037,6 @@ def _next_async(self, total_frames, trackers): break return data, total_frames - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For distributed collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - # For distributed collectors, we need TensorDict format for isend/irecv - from tensordict import TensorDict - - return TensorDict.from_module(sender._source_model).data.lock_() - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 @@ -1024,13 +1060,11 @@ def shutdown(self, timeout: float | None = None) -> None: self._store.set("TRAINER_status", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down node with rank={rank}") + torchrl_logger.debug(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"getting status of node {rank}") + torchrl_logger.debug(f"getting status of node {rank}") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -1044,13 +1078,16 @@ def shutdown(self, timeout: float | None = None) -> None: self.jobs[i].result() elif self.launcher == "submitit_delayed": pass - if self._VERBOSE: - torchrl_logger.info("collector shut down") + torchrl_logger.debug("collector shut down") class DistributedWeightUpdater(WeightUpdaterBase): """A remote weight updater for synchronizing policy weights across distributed workers. + .. warning:: + This class has been deprecated in favor of the :class:`~torchrl.weight_update.DistributedWeightSyncScheme` + API. + The `DistributedWeightUpdater` class provides a mechanism for updating the weights of a policy across distributed inference workers. It is designed to work with the :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights. 
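+
+    A rough migration sketch (names taken from this patch; adapt to your setup)::
+
+        from torchrl.weight_update import DistributedWeightSyncScheme
+
+        collector = DistributedDataCollector(
+            ...,  # usual constructor arguments
+            weight_sync_schemes={
+                "policy": DistributedWeightSyncScheme(backend="gloo", sync=True)
+            },
+        )
+        collector.update_policy_weights_()  # pushes weights through the scheme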
@@ -1086,7 +1123,7 @@ class DistributedWeightUpdater(WeightUpdaterBase): """ - _VERBOSE = True + _VERBOSE = False def __init__( self, @@ -1131,8 +1168,7 @@ def _push_weights( ) for i in workers: rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"updating weights of {rank}") + torchrl_logger.debug(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: weights.send(rank) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 7547985e1ac..1cdaca40072 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,11 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -72,7 +72,7 @@ def print_remote_collector_info(self): f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) # torchrl_logger.warning(s) - torchrl_logger.info(s) + torchrl_logger.debug(s) class RayCollector(DataCollectorBase): @@ -755,7 +755,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames < self.total_frames and not self._stop_event.is_set() ): if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info("Updating weights on all workers") + torchrl_logger.debug("Updating weights on all workers") self.update_policy_weights_() # Ask for batches to all remote workers. 
@@ -872,7 +872,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info(f"Updating weights on worker {collector_index}") + torchrl_logger.debug(f"Updating weights on worker {collector_index}") self.update_policy_weights_(worker_ids=collector_index + 1) # Schedule a new collection task diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index dfbd8a7c5a2..b7705dae72d 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -23,12 +23,12 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -59,11 +59,23 @@ def _rpc_init_collection_node( world_size, visible_device, tensorpipe_options, + backend="gloo", verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) os.environ["MASTER_PORT"] = str(tcp_port) + # Initialize torch.distributed process group for efficient weight transfer + if verbose: + torchrl_logger.debug( + f"init distributed with rank={rank}, world_size={world_size}, backend={backend}" + ) + torch.distributed.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + if isinstance(visible_device, list): pass elif isinstance(visible_device, (str, int, torch.device)): @@ -78,7 +90,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -89,6 +101,7 @@ def _rpc_init_collection_node( world_size=world_size, ) rpc.shutdown() + torch.distributed.destroy_process_group() class RPCDataCollector(DataCollectorBase): @@ -258,6 +271,9 @@ class RPCDataCollector(DataCollectorBase): https://github.com/facebookincubator/submitit Defaults to "submitit". tcp_port (int, optional): the TCP port to be used. Defaults to 10003. + backend (str, optional): the torch.distributed backend to use for weight synchronization. + Must be one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed + documentation for more information. Defaults to ``"gloo"``. visible_devices (list of Union[int, torch.device, str], optional): a list of the same length as the number of nodes containing the device used to pass data to main. 
@@ -302,6 +318,7 @@ def __init__( max_weight_update_interval: int = -1, launcher: str = "submitit", tcp_port: str | None = None, + backend: str = "gloo", visible_devices: list[torch.device] | None = None, tensorpipe_options: dict[str, Any] | None = None, weight_updater: WeightUpdaterBase @@ -309,6 +326,10 @@ def __init__( | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): + + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -405,6 +426,7 @@ def __init__( self.postproc = postproc self.split_trajs = split_trajs + self.backend = backend if tensorpipe_options is None: self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS) @@ -412,7 +434,6 @@ def __init__( self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS).update( tensorpipe_options ) - self._init() # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: @@ -424,38 +445,6 @@ def __init__( if weight_sync_schemes is not None: # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - transport = scheme.create_transport( - ( - self.collector_infos[i], - self.collector_rrefs[i], - self.collector_class, - ) - ) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: # Fall back to legacy weight updater system @@ -469,7 +458,20 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init() + + if weight_sync_schemes is not None: + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + model_id=model_id, + num_workers=self.num_workers, + collector_infos=self.collector_infos, + collector_class=self.collector_class, + collector_rrefs=self.collector_rrefs, + context=self, + ) @property def device(self) -> list[torch.device]: @@ -535,7 +537,18 @@ def _init_master_rpc( self, world_size, ): - """Init RPC on main node.""" + """Init torch.distributed and RPC on main node.""" + # Initialize torch.distributed process group for efficient weight transfer + torchrl_logger.debug( + f"init distributed with rank=0, world_size={world_size}, backend={self.backend}" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, + world_size=world_size, + ) + + # Initialize RPC for control/signaling options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options) if torch.cuda.is_available(): if self.visible_devices: @@ -544,8 +557,7 @@ def _init_master_rpc( options.set_device_map( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) - if self._VERBOSE: - torchrl_logger.info("init rpc") + torchrl_logger.debug("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -576,10 +588,7 @@ 
def _start_workers( counter += 1 time.sleep(time_interval) try: - if self._VERBOSE: - torchrl_logger.info( - f"trying to connect to collector node {i + 1}" - ) + torchrl_logger.debug(f"trying to connect to collector node {i + 1}") collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -593,8 +602,7 @@ def _start_workers( env_make = env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - if self._VERBOSE: - torchrl_logger.info("Making collector in remote node") + torchrl_logger.debug("Making collector in remote node") collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -614,17 +622,26 @@ def _start_workers( ) collector_rrefs.append(collector_rref) + # Set up receiver schemes on remote collectors (if using new weight sync system) + # This enables cascading: RPC -> MultiSync -> Sync + if getattr(self, "_weight_sync_schemes", None) is not None: + for i in range(num_workers): + torchrl_logger.debug( + f"Setting up receiver schemes on remote collector {i}" + ) + # Call _set_scheme_receiver on the remote collector using rref.rpc_sync() + # This properly dereferences the rref and calls the instance method + collector_rrefs[i].rpc_sync()._set_scheme_receiver( + self._weight_sync_schemes + ) + futures = collections.deque(maxlen=self.num_workers) if not self._sync: for i in range(num_workers): - if self._VERBOSE: - torchrl_logger.info("Asking for the first batch") - future = rpc.rpc_async( - collector_infos[i], - collector_class.next, - args=(collector_rrefs[i],), - ) + torchrl_logger.debug("Asking for the first batch") + # Use rref.rpc_async() to properly call instance method + future = collector_rrefs[i].rpc_async().next() futures.append((future, i)) self.futures = futures self.collector_rrefs = collector_rrefs @@ -646,10 +663,10 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -661,6 +678,7 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ), ) @@ -692,8 +710,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"Submitting job {i}") + torchrl_logger.debug(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -735,10 +752,9 @@ def iterator(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of worker {j} with wait=False" - ) + torchrl_logger.debug( + f"Updating policy of worker {j} with wait=False" + ) self.update_policy_weights_(worker_ids=[j], wait=False) elif self.max_weight_update_interval > -1: ranks = [ @@ -747,15 +763,13 @@ def iterator(self): if self._batches_since_weight_update[j] > self.max_weight_update_interval ] - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of workers {ranks} with wait=True" - ) + torchrl_logger.debug( + f"Updating policy of workers {ranks} with wait=True" + ) self.update_policy_weights_(worker_ids=ranks, wait=True) def _next_async_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next async") + torchrl_logger.debug("next async") if not len(self.futures): raise StopIteration( f"The 
queue is empty, the collector has ran out of data after {self._collected_frames} collected frames." @@ -765,31 +779,23 @@ def _next_async_rpc(self): if future.done(): if self.update_after_each_batch: self.update_policy_weights_(worker_ids=(i,), wait=False) - if self._VERBOSE: - torchrl_logger.info(f"future {i} is done") + torchrl_logger.debug(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) return data self.futures.append((future, i)) def _next_sync_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next sync: futures") + torchrl_logger.debug("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) data = [] while len(self.futures): @@ -797,10 +803,9 @@ def _next_sync_rpc(self): # the order is NOT guaranteed: should we change that? if future.done(): data += [future.value()] - if self._VERBOSE: - torchrl_logger.info( - f"got data from {i} // data has len {len(data)} / {self.num_workers}" - ) + torchrl_logger.debug( + f"got data from {i} // data has len {len(data)} / {self.num_workers}" + ) else: self.futures.append((future, i)) data = torch.cat(data) @@ -812,34 +817,6 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For RPC collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,)) @@ -856,25 +833,23 @@ def shutdown(self, timeout: float | None = None) -> None: return if self._shutdown: return - if self._VERBOSE: - torchrl_logger.info("shutting down") + torchrl_logger.debug("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - torchrl_logger.info(f"waiting for proc {i} to clear") + torchrl_logger.debug(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"shutting down {i}") - rpc.rpc_sync( - self.collector_infos[i], - self.collector_class.shutdown, - args=(self.collector_rrefs[i],), - timeout=int(IDLE_TIMEOUT), - ) - if self._VERBOSE: - torchrl_logger.info("rpc shutdown") + torchrl_logger.debug(f"shutting down {i}") + # Use rref.rpc_sync() to properly call instance method + self.collector_rrefs[i].rpc_sync(timeout=int(IDLE_TIMEOUT)).shutdown() + torchrl_logger.debug("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + if self.launcher == "mp": for job in self.jobs: job.join(int(IDLE_TIMEOUT)) @@ -969,19 +944,13 @@ def push_weights( futures = [] weights = self.policy_weights if weights is None else weights for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"calling update on worker {i}") + torchrl_logger.debug(f"calling update on worker {i}") + # Use rref.rpc_async() to properly call instance method futures.append( - rpc.rpc_async( - self.collector_infos[i], - self.collector_class.update_policy_weights_, - args=(self.collector_rrefs[i], weights), - ) + self.collector_rrefs[i].rpc_async().update_policy_weights_(weights) ) if kwargs.get("wait", True): for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"waiting for worker {i}") + torchrl_logger.debug(f"waiting for worker {i}") futures[i].wait() - if self._VERBOSE: - torchrl_logger.info("got it!") + torchrl_logger.debug("got it!") diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index f81a5efce0a..fd36e47cd7b 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,11 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from 
torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, @@ -44,6 +44,7 @@ def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -64,7 +65,7 @@ def _distributed_init_collection_node( os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -97,9 +98,9 @@ def _distributed_init_collection_node( **collector_kwargs, ) - torchrl_logger.info(f"IP address: {rank0_ip} \ttcp port: {tcpport}") + torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") if verbose: - torchrl_logger.info(f"node with rank {rank} -- launching distributed") + torchrl_logger.debug(f"node with rank {rank} -- launching distributed") torch.distributed.init_process_group( backend, rank=rank, @@ -108,9 +109,9 @@ def _distributed_init_collection_node( # init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"node with rank {rank} -- creating store") + torchrl_logger.debug(f"node with rank {rank} -- creating store") if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") + torchrl_logger.debug(f"node with rank {rank} -- loop") policy_weights.irecv(0) frames = 0 for i, data in enumerate(collector): @@ -471,7 +472,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - torchrl_logger.info("init master...") + torchrl_logger.debug("init master...") torch.distributed.init_process_group( backend, rank=0, @@ -479,7 +480,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - torchrl_logger.info("done") + torchrl_logger.debug("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -505,20 +506,21 @@ def _init_worker_dist_submitit(self, executor, i): env_make = CloudpickleWrapper(env_make) job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ) return job @@ -529,21 +531,22 @@ def _init_worker_dist_mp(self, i): env_make = CloudpickleWrapper(env_make) job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + 
collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ), ) job.start() @@ -553,7 +556,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - torchrl_logger.info(f"Server IP address: {IPAddr}") + torchrl_logger.debug(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -565,18 +568,18 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - torchrl_logger.info("Submitting job") + torchrl_logger.debug("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - torchrl_logger.info("job launched") + torchrl_logger.debug("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index bc72bda6a4a..3a7258c367a 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -103,7 +103,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - torchrl_logger.info(f"job id: {main_job.job_id}") + torchrl_logger.debug(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -114,11 +114,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - torchrl_logger.info(f"node: {node}") + torchrl_logger.debug(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - torchrl_logger.info(f"IP: {rank0_ip}") + torchrl_logger.debug(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 799c0a5e692..9c5b9c06117 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -5,7 +5,7 @@ from __future__ import annotations import contextlib -from collections.abc import Callable +from collections.abc import Callable, Sequence import torch from pyvers import implement_for @@ -264,7 +264,13 @@ def nest(*x): @implement_for("torch", "2.5.0") -def _cast(p, param_maybe_buffer): +def _cast( + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data if isinstance(param_maybe_buffer, Parameter): # Create parameter without gradients to avoid serialization issues return Parameter(p, requires_grad=False) @@ -291,7 +297,13 @@ def _make_meta_policy(policy: nn.Module): @implement_for("torch", None, "2.5.0") -def _cast(p, param_maybe_buffer): # noqa +def _cast( # noqa + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: 
nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data if isinstance(param_maybe_buffer, Parameter): # Create parameter without gradients to avoid serialization issues return Parameter(p, requires_grad=False) @@ -357,3 +369,30 @@ def _map_weight( elif is_buffer: weight = Buffer(weight) return weight + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None +): + has_policy_factory = policy_factory is not None and ( + (isinstance(policy_factory, Sequence) and any(policy_factory)) + or not isinstance(policy_factory, Sequence) + ) + if policy is not None and has_policy_factory: + raise ValueError("policy cannot be used with policy_factory") + elif has_policy_factory: + if isinstance(policy_factory, Sequence): + return policy_factory + else: + policy = policy_factory() + + if weight_sync_scheme is not None: + # Initialize the receiver on the worker side + weight_sync_scheme.init_on_receiver( + model=policy, + model_id="policy", + worker_idx=worker_idx, + ) + # Synchronize initial weights + weight_sync_scheme.synchronize_weights(worker_idx=worker_idx) + return policy diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index a742d922a12..1228692b552 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -1,10 +1,15 @@ from __future__ import annotations +import weakref + from typing import Any import torch -from tensordict import TensorDict +from tensordict import TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -13,6 +18,28 @@ ) +class DistributedWeightReceiver(WeightReceiver): + """Weight receiver for torch.distributed systems. + + Receives weight updates from the main process via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + +class DistributedWeightSender(WeightSender): + """Weight sender for torch.distributed systems. + + Sends weight updates to distributed workers via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + class DistributedWeightSyncScheme(WeightSyncScheme): """Weight synchronization for torch.distributed. @@ -25,25 +52,108 @@ class DistributedWeightSyncScheme(WeightSyncScheme): sync (bool): Whether to use synchronous weight updates """ + _receiver_cls = DistributedWeightReceiver + _sender_cls = DistributedWeightSender + def __init__(self, backend: str = "gloo", sync: bool = True): super().__init__() self.backend = backend self.sync = sync - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create distributed transport for a specific worker. - - Args: - pipe_or_context: A tuple of (store, rank) for the worker. - - Returns: - DistributedTransport configured for this specific worker. 
+ def _init_on_sender_impl( + self, + *args, + **kwargs, + ) -> None: + num_workers = kwargs.pop("num_workers") + context = kwargs.pop("context") + model_id = kwargs.pop("model_id") + + # Create and configure sender for this model + sender = self.create_sender() + sender._model_id = model_id + + # Attach context so the sender can resolve the model and prepare + # weights on demand via scheme.prepare_weights(). + if context is not None: + sender._set_context(context, model_id) + + # Store reference to source model for automatic extraction + try: + sender._source_model = _resolve_model(context, model_id) + except (AttributeError, IndexError): + pass + + # Create transports for each remote collector + weights_buffer = self._get_weights_buffer_from_model(sender._source_model) + for i in range(num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = self.create_transport( + store=context._store, rank=rank, weights_buffer=weights_buffer + ) + sender._transports[i] = transport + + # Expose sender through the base API + self._sender = sender + + def _init_on_receiver_impl(self, *args, **kwargs) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - store: TCPStore | None # distributed TCP store + - rank: int | None # worker rank (1-indexed) """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: - store, rank = pipe_or_context - return DistributedTransport(store=store, rank=rank, sync=self.sync) - # Fallback - shouldn't normally happen - return DistributedTransport() + context = kwargs.pop("context", None) + model_id = kwargs.pop("model_id") + store = kwargs.pop("store", None) + rank = kwargs.pop("rank", None) + + if context is None: + raise ValueError( + "DistributedWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Create receiver instance + receiver = self._receiver_cls(self) + receiver._model_id = model_id + + # Attach context so we can resolve string model refs like "policy" + receiver._context_ref = weakref.ref(context) + + # Resolve the target model on this worker + model = None + # Prefer a collector-specific get_model if available, but fall back + # gracefully to attribute resolution when no mapping exists. + if hasattr(context, "get_model"): + try: + model = context.get_model(model_id) + except (ValueError, AttributeError): + model = None + if model is None: + model = _resolve_model(context, model_id) + receiver._register_model(model) + + weights_buffer = self._get_weights_buffer_from_model(model) + receiver._transport = self.create_transport( + store=store, rank=rank, weights_buffer=weights_buffer + ) + + # Store receiver on scheme so get_receiver() works as expected + self._receiver = receiver + + def create_transport(self, **kwargs) -> TransportBackend: + """Create distributed transport for a specific worker.""" + if self._initialized_on_receiver: + return DistributedTransport(**kwargs) + elif self._initialized_on_sender: + return DistributedTransport(**kwargs) + else: + raise RuntimeError( + "DistributedWeightSyncScheme.create_transport must be called after initialization has been marked." + ) class DistributedTransport: @@ -54,18 +164,26 @@ class DistributedTransport: following the same pattern as multiprocess collectors. 
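+
+    A minimal construction sketch on the trainer side (assuming ``store`` is the
+    collector's TCPStore and ``policy`` is an ``nn.Module``)::
+
+        weights = TensorDict.from_module(policy).data
+        transport = DistributedTransport(
+            weights_buffer=weights, store=store, rank=1, sync=True
+        )
+        transport.send_weights(weights)  # signals the worker, then send()/isend()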
""" - def __init__(self, store=None, rank=None, sync=True): + def __init__( + self, + *, + weights_buffer: TensorDictBase, + store: torch.distributed.Store = None, + rank: int = None, + sync: bool = True, + ): """Initialize the DistributedTransport. Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. + weights_buffer (TensorDictBase): a tensor buffer of weights. + store (torch.distributed.Store): A (TCP)Store for communication. + rank (int): Worker rank (1-indexed). + sync (bool): Whether to use synchronous weight updates. """ self._store = store self._rank = rank self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights + self._weights_buffer = weights_buffer def send_weights(self, weights: Any) -> None: """Send weights to the distributed worker.""" @@ -73,15 +191,18 @@ def send_weights(self, weights: Any) -> None: return # Instruct worker to expect weight update + torchrl_logger.debug("RANK 0 -- Setting weight sync instructions to store") self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed + torchrl_logger.debug(f"RANK 0 -- Send {weights=} to rank {self._rank}") if self._sync: weights.send(self._rank) else: weights.isend(self._rank) # Wait for acknowledgment + torchrl_logger.debug("RANK 0 -- Receiving acknowledgement from store") status = self._store.get(f"NODE_{self._rank}_out") if status != b"updated": raise RuntimeError(f"Expected 'updated' but got status {status}.") @@ -96,13 +217,20 @@ def send_weights_async(self, weights: Any) -> None: return # Instruct worker to expect weight update + torchrl_logger.info( + f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" + ) self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed + torchrl_logger.info( + f"RANK 0 -- Send {weights=} to rank {self._rank} with sync={self._sync}" + ) if self._sync: weights.send(self._rank) else: weights.isend(self._rank) + torchrl_logger.debug(f"RANK 0 -- Weights successfully sent to {self._rank}") def wait_ack(self) -> None: """Wait for acknowledgment from distributed worker.""" @@ -115,55 +243,31 @@ def wait_ack(self) -> None: self._store.delete_key(f"NODE_{self._rank}_out") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. + r"""Receive weights via torch.distributed. - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment + The surrounding collector loop is responsible for checking the TCPStore + for the \"update_weights\" instruction. When this method is called we + assume that a weight update has been requested and the sender has + already performed the corresponding ``send()``. Args: - timeout: Timeout for receiving (currently not used for TCPStore check) + timeout: Unused for now (kept for TransportBackend compatibility). Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. + Tuple of (model_id, weights) where model_id is currently always + \"policy\". 
""" if self._store is None or self._rank is None: return None - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None + # Receive weights via torch.distributed into the buffer + if self._sync: + self._weights_buffer.recv(src=0) + else: + # irecv() blocks until weights have been received + self._weights_buffer.irecv(src=0) - return None + return ("policy", self._weights_buffer) def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender via TCPStore. @@ -183,28 +287,6 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for DistributedTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for DistributedTransport - weights are received via receive_weights().""" return None - - -class DistributedWeightReceiver(WeightReceiver): - """Weight receiver for torch.distributed systems. - - Receives weight updates from the main process via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None - - -class DistributedWeightSender(WeightSender): - """Weight sender for torch.distributed systems. - - Sends weight updates to distributed workers via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index fc845fcdf64..4e7bf760845 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -2,7 +2,7 @@ import weakref from collections.abc import Callable -from typing import Any, overload +from typing import Any import torch from tensordict import TensorDictBase @@ -16,6 +16,197 @@ ) +class MPWeightReceiver(WeightReceiver): + """Weight receiver for multiprocess systems using queues. + + Receives weight updates from the main process via multiprocessing queues. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + + +class MPWeightSender(WeightSender): + """Weight sender for multiprocess systems using queues. + + Sends weight updates to worker processes via multiprocessing queues. + Supports both synchronous and asynchronous sending patterns. 
+ This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + _model_id: str + _scheme: MultiProcessWeightSyncScheme + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self._model_id + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. + + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. 
+ + Computes device-specific weight copies on-demand and sends them to workers + sequentially via queues. This is called once after workers are initialized + but before they start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + + Raises: + RuntimeError: If init_on_sender() was not called first. + """ + # Get the device mapping info stored during init_on_sender + if not hasattr(self._scheme, "_device_mapping_info"): + raise RuntimeError( + "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" + ) + + mapping_info = self._scheme._device_mapping_info + + # Get context from sender's weakref + context = self._context_ref() if self._context_ref is not None else None + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._scheme._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Send to workers sequentially via queues (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + worker_weights = params_map[i] + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) + + # Clean up the mapping info after synchronization + delattr(self._scheme, "_device_mapping_info") + + class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): """Weight synchronization for multiprocess operations using queues. @@ -64,6 +255,9 @@ class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): is large. """ + _sender_cls = MPWeightSender + _receiver_cls = MPWeightReceiver + def __init__(self, strategy: str = "tensordict"): """Initialize the MultiProcessWeightSyncScheme. @@ -203,29 +397,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( + def _init_on_receiver_impl( self, - model_id: str, - context: None = None, *, - worker_idx: int = ..., - model: Any | None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( - self, model_id: str, context: Any = None, **kwargs, @@ -278,7 +452,7 @@ def init_on_receiver( receiver._worker_idx = worker_idx self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def create_transport(self, queue: Any) -> TransportBackend: """Create an MPTransport using the provided queue. 
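+
+        A minimal sketch of how the sender uses the resulting transport (the
+        queue itself is created and owned by the collector)::
+
+            transport = scheme.create_transport(queue)
+            transport.send_weights_async(weights, model_id="policy")
+            if hasattr(transport, "wait_ack"):
+                transport.wait_ack()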
@@ -298,7 +472,7 @@ class MPTransport: Initialization flow: - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via queues - - Workers receive the initial weights via synchronize_weights_on_worker() + - Workers receive the initial weights via synchronize_weights_on_receiver() - Subsequent updates use send_weights_async() followed by acknowledgments Args: @@ -383,11 +557,11 @@ def synchronize_weights_on_sender(self) -> None: sends shared memory buffer references via queues. """ - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process - via queue. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + via queue. Similar to SharedMemTransport.synchronize_weights_on_receiver() which receives shared memory buffer references via queues, this receives the actual weights via queues. The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). @@ -406,194 +580,3 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: return weights else: raise ValueError(f"Expected 'update_weights' but got {msg}") - - -class MPWeightReceiver(WeightReceiver): - """Weight receiver for multiprocess systems using queues. - - Receives weight updates from the main process via multiprocessing queues. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - - -class MPWeightSender(WeightSender): - """Weight sender for multiprocess systems using queues. - - Sends weight updates to worker processes via multiprocessing queues. - Supports both synchronous and asynchronous sending patterns. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - _model_id: str - _scheme: MultiProcessWeightSyncScheme - - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights synchronously to workers. - - This method: - 1. Prepares weights (extracts from model if weights=None) - 2. Sends to specified workers (or all if worker_ids=None) - 3. Waits for acknowledgments from those workers - 4. Returns when workers have applied the weights - - Args: - weights: Weights to send. Can be: - - None: Extract from model via context.get_model(model_id) - - nn.Module: Extract weights from module - - TensorDict: Use directly - - dict: Convert to TensorDict - worker_ids: Which workers to send to: - - None: Send to all workers (default) - - int: Send to single worker - - list[int]: Send to specific workers - - Note: This is a blocking call that ensures specified workers are updated - before returning. - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send() while an async send is pending. Call wait_async() first." 
- ) - - model_id = self._model_id - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=model_id, - strategy=self._strategy, - context=context, - ) - - transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers first (non-blocking if transport supports it) - for transport in transports: - if hasattr(transport, "send_weights_async"): - # For MPTransport, pass model_id; other transports don't need it - transport.send_weights_async(prepared_weights, model_id=model_id) - else: - # Fallback for transports that don't support async send - transport.send_weights(prepared_weights) - - # Wait for all acknowledgments - for transport in transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() - - def send_async( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights asynchronously to workers (non-blocking). - - This initiates the send but returns immediately without waiting - for workers to acknowledge. You must call wait_async() before - the next send_async() or send() call. - - Args: - weights: Same as send() - worker_ids: Same as send() - - Raises: - RuntimeError: If a previous send_async() is still pending - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send_async() again while a previous send is pending. Call wait_async() first." - ) - - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=self._model_id, - strategy=self._strategy, - context=context, - ) - - # Store transports for wait_async - self._pending_transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers (non-blocking) - for transport in self._pending_transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights, model_id=self._model_id) - else: - raise RuntimeError( - f"transport of type {type(transport)} does not support async send." - ) - - self._pending_async = True - - def synchronize_weights(self) -> None: - """Synchronize weights with workers before collection starts. - - Computes device-specific weight copies on-demand and sends them to workers - sequentially via queues. This is called once after workers are initialized - but before they start collecting data. - - Unlike send(), this does not wait for acknowledgments since workers are still - in their initialization phase. - - This approach creates weight copies on-demand and sends them sequentially, - allowing garbage collection between workers to reduce memory usage. - - Raises: - RuntimeError: If init_on_sender() was not called first. 
- """ - # Get the device mapping info stored during init_on_sender - if not hasattr(self._scheme, "_device_mapping_info"): - raise RuntimeError( - "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" - ) - - mapping_info = self._scheme._device_mapping_info - - # Get context from sender's weakref - context = self._context_ref() if self._context_ref is not None else None - - # Compute params_map on-demand - # Extract with explicit type casting for type checker - model_id = mapping_info["model_id"] - weights = mapping_info["weights"] - model = mapping_info["model"] - params_map_arg = mapping_info["params_map"] - devices = mapping_info["devices"] - device_map_fn = mapping_info["device_map_fn"] - num_workers = mapping_info["num_workers"] - - params_map = self._scheme._get_params_map( - context=context, - model_id=model_id, - weights=weights, - model=model, - params_map=params_map_arg, - devices=devices, - device_map_fn=device_map_fn, - num_workers=num_workers, - ) - - # Send to workers sequentially via queues (no ACK - workers are still initializing) - # This allows GC to clean up each worker's weights before creating the next - for i, transport in enumerate(self._iterate_transports()): - worker_weights = params_map[i] - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] - else: - raise RuntimeError( - f"Transport {type(transport)} does not support async send for synchronization" - ) - - # Clean up the mapping info after synchronization - delattr(self._scheme, "_device_mapping_info") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index fbb90f8ff34..0751261a4ce 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, overload +from typing import Any from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -16,24 +16,6 @@ class NoWeightSyncScheme(WeightSyncScheme): This scheme disables weight synchronization entirely. """ - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - **kwargs, - ) -> None: - ... - def _init_on_sender_impl( self, model_id: str, @@ -54,26 +36,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - model_id: str, - context: None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, + *, model_id: str, context: Any = None, **kwargs, @@ -90,7 +55,7 @@ def init_on_receiver( receiver._model_ref = model_id self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create a no-op transport. 
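For context, every scheme in this patch now funnels the public init_on_sender()/init_on_receiver() calls through the per-scheme _init_on_sender_impl()/_init_on_receiver_impl() hooks shown above. A minimal usage sketch follows; it is hedged, not authoritative: `collector`, `inner_collector` and `worker_idx=0` are illustrative placeholders, the import path is assumed, and any extra kwargs a given scheme needs (worker queues, pipes, counts) are expected to come from the context object.

# Illustrative only: `collector` / `inner_collector` stand in for whatever a real
# collector passes as `context`; scheme-specific kwargs are omitted.
from torchrl.weight_update import MultiProcessWeightSyncScheme  # import path assumed

scheme = MultiProcessWeightSyncScheme(strategy="tensordict")

# Main / trainer process side.
scheme.init_on_sender(model_id="policy", context=collector)
sender = scheme.get_sender()
sender.send()  # weights=None -> extracted from the model resolved through the context

# Worker process side.
scheme.init_on_receiver(model_id="policy", context=inner_collector, worker_idx=0)
receiver = scheme.get_receiver()
receiver.receive(timeout=0.001)  # non-blocking poll; applies weights and ACKs if supported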
diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 0dff3db7417..a7a16999574 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any, Literal, overload +from typing import Any, Literal from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( @@ -22,40 +22,29 @@ class RayWeightSyncScheme(WeightSyncScheme): as multiprocess collectors. """ - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport( + self, + *, + remote_collector=None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + **kwargs, + ) -> TransportBackend: """Create Ray-based transport for a specific remote collector. Args: - pipe_or_context: The Ray actor handle for the remote collector. + remote_collector: The Ray actor handle for the remote collector. + tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). + **kwargs: Additional transport configuration. Returns: RayTransport configured for this specific remote collector. """ - return RayTransport(remote_collector=pipe_or_context) - - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - *, - remote_collectors: list = ..., - num_workers: int | None = None, - source_model: Any | None = None, - **kwargs, - ) -> None: - ... + return RayTransport( + remote_collector=remote_collector, + tensor_transport=tensor_transport, + ) - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, @@ -87,9 +76,12 @@ def init_on_sender( sender = WeightSender(self) sender._model_id = model_id - # Register each Ray actor - _register_worker will create the transport + # Register each Ray actor with explicit transport kwargs for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) + sender._register_worker( + worker_idx, + remote_collector=remote_collector, + ) # Set context with weak reference to avoid circular refs if context is not None: @@ -103,27 +95,7 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - model_id: str, - context: None = None, - *, - model: Any | None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, model_id: str, context: Any = None, @@ -155,7 +127,118 @@ def init_on_receiver( receiver._set_context(weakref.ref(context)) self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True + + +class RayModuleTransformReceiver(WeightReceiver): + """Specialized receiver for RayModuleTransform actors. + + This receiver handles weight updates within Ray actors. + Since Ray actors receive weights through direct method calls, + this receiver primarily validates and applies weights locally. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + + def _register_worker_transport( + self, actor_or_context: Any = None, **transport_kwargs + ) -> None: + """Register the Ray actor's transport (internal). + + This is now handled by init_on_receiver(). Only kept for internal use. 
+ + Args: + actor_or_context: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration (e.g., actor_ref=...). + """ + # Support legacy actor_or_context for backward compatibility + if actor_or_context is not None and not transport_kwargs: + transport_kwargs = {"actor_ref": actor_or_context} + self._transport = self._scheme.create_transport(**transport_kwargs) + + def apply_weights(self, weights: Any, inplace: bool = True) -> None: + """Apply received weights to registered model. + + For Ray actors, weights are applied directly to the module + within the actor's process space. + + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=inplace) + + +class RayModuleTransformSender(WeightSender): + """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. + + This sender handles weight updates for models hosted within Ray actors. + Unlike the base WeightSender which uses pipes for multiprocessing, + this sender directly communicates with Ray actors via their remote methods. + + For Ray actors, there is typically only one shared actor instance, so we + store a single transport rather than per-worker transports. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + self._actor_ref = None + self._single_transport = None + self._context_ref = None + self._model_id_str = None + + def _set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution (internal). + + This is now handled by init_on_sender(). Only kept for internal use. + + Args: + context: The collector instance. + model_id: String path to the Ray actor (e.g., "env.transform[0]"). + """ + self._context_ref = weakref.ref(context) + self._model_id_str = model_id + + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op (internal). + + Ray actors are shared across all workers, so we don't need per-worker + transports. The actor reference is resolved lazily on first use. + """ + + def update_weights(self, weights: Any) -> None: + """Send weights to the Ray actor. + + Args: + weights: Weights to send. + """ + if self._single_transport is None: + self._initialize_transport() + + if self._single_transport is not None: + self._single_transport.send_weights(weights) + + def _initialize_transport(self) -> None: + """Lazily initialize the transport by resolving the actor reference.""" + if self._context_ref is None or self._model_id_str is None: + return + + context = self._context_ref() + if context is None: + return + + model = _resolve_model(context, self._model_id_str) + if hasattr(model, "_actor"): + self._actor_ref = model._actor + self._single_transport = self._scheme.create_transport(actor_ref=model) + elif type(model).__name__ == "ActorHandle": + self._actor_ref = model + self._single_transport = self._scheme.create_transport(actor_ref=model) class RayModuleTransformScheme(WeightSyncScheme): @@ -170,21 +253,44 @@ class RayModuleTransformScheme(WeightSyncScheme): Default is "tensordict". 
""" + _sender_cls = RayModuleTransformSender + _receiver_cls = RayModuleTransformReceiver + def __init__(self, strategy: str = "tensordict"): super().__init__(strategy) - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport( + self, + *, + actor_ref=None, + update_method: str | None = None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + **kwargs, + ) -> TransportBackend: """Create RayActorTransport for the given actor. Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. + actor_ref: Ray actor reference or context object with _actor attribute. + update_method: Weight update method ("tensordict" or "state_dict"). + If None, uses self.strategy. + tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). + **kwargs: Additional transport configuration. Returns: RayActorTransport configured with the actor reference. """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + # Extract actor reference if needed + if actor_ref is not None and hasattr(actor_ref, "_actor"): + actor_ref = actor_ref._actor + + if update_method is None: + update_method = self.strategy + + return RayActorTransport( + actor_ref=actor_ref, + update_method=update_method, + tensor_transport=tensor_transport, + ) def _extract_actor_ref(self, pipe_or_context: Any) -> Any: """Extract the Ray actor reference from the context. @@ -208,29 +314,6 @@ def create_receiver(self) -> RayModuleTransformReceiver: """Create a specialized receiver for Ray actor communication.""" return RayModuleTransformReceiver(self) - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - *, - actor_refs: list | None = None, - actors: list | None = None, - remote_collectors: list | None = None, - source_model: Any | None = None, - **kwargs, - ) -> None: - ... - def _init_on_sender_impl( self, model_id: str, @@ -268,9 +351,12 @@ def _init_on_sender_impl( sender = self.create_sender() sender._model_id = model_id - # Register all actors - _register_worker will create the transport + # Register all actors with explicit transport kwargs for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) + sender._register_worker( + worker_idx, + actor_ref=actor_ref, + ) # Set context with weak reference if context is not None: @@ -284,29 +370,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( + def _init_on_receiver_impl( self, - model_id: str, - context: None = None, *, - actor_ref: Any | None = None, - model: Any | None = None, - **kwargs, - ) -> None: - ... 
- - def init_on_receiver( - self, model_id: str, context: Any = None, **kwargs, @@ -322,11 +388,10 @@ def init_on_receiver( receiver = self.create_receiver() # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: + actor_ref_arg = kwargs.get("actor_ref") or context + if actor_ref_arg is not None: # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) + receiver._register_worker_transport(actor_ref=actor_ref_arg) # Register model if provided model = kwargs.get("model") or ( @@ -342,7 +407,7 @@ def init_on_receiver( receiver._set_context(weakref.ref(context)) self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True class RayTransport: @@ -415,7 +480,7 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for RayTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for RayTransport - weights are received via remote method calls.""" return None @@ -519,111 +584,6 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for RayActorTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for RayActorTransport - weights are received via remote method calls.""" return None - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). - - This is now handled by init_on_receiver(). Only kept for internal use. - - Args: - actor_or_context: Either a Ray actor reference or a context object. - """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any, inplace: bool = True) -> None: - """Apply received weights to registered model. - - For Ray actors, weights are applied directly to the module - within the actor's process space. - - Args: - weights: The weights to apply. - inplace: Whether to apply weights in place. Default is `True`. - """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights, inplace=inplace) - - -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. - - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. - - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. 
- """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None - - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). - - This is now handled by init_on_sender(). Only kept for internal use. - - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). - """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id - - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). - - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. - """ - - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. - - Args: - weights: Weights to send. - """ - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - self._single_transport.send_weights(weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: - return - - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 9290b23aa05..cf5797048c2 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -2,6 +2,7 @@ from typing import Any +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -10,6 +11,52 @@ ) +class RPCWeightReceiver(WeightReceiver): + """Weight receiver for RPC-based distributed systems. + + Receives weight updates from the main process via torch.distributed primitives. + This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + def receive(self, timeout: float = 0.001) -> Any: + """Receive weights from the main process using torch.distributed.recv(). + + Args: + timeout: Not used for RPC receivers (included for interface compatibility). + + Returns: + The received weights as a TensorDict. + """ + from tensordict import TensorDict + + # Dereference the weakref to get the actual context + context = self._context_ref() if hasattr(self, "_context_ref") else None + if context is None: + return None + + # Get the policy to determine the structure of weights to receive + if hasattr(context, "policy") and context.policy is not None: + policy = context.policy + # Create an empty TensorDict with the same structure as the policy weights + weights = TensorDict.from_module(policy) + # Receive weights from rank 0 (the main/trainer process) + weights.recv(0) + + # Apply the received weights to the policy + self._strategy.apply_weights(policy, weights) + return weights + + return None + + +class RPCWeightSender(WeightSender): + """Weight sender for RPC-based distributed systems. + + Sends weight updates to remote collectors via torch.distributed.rpc calls. 
+ This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + class RPCWeightSyncScheme(WeightSyncScheme): """Weight synchronization for torch.distributed.rpc. @@ -18,106 +65,218 @@ class RPCWeightSyncScheme(WeightSyncScheme): same pattern as multiprocess collectors. """ - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + _sender_cls = RPCWeightSender + _receiver_cls = RPCWeightReceiver + + def _init_on_receiver_impl(self, *args, **kwargs) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - worker_idx: int | None # worker index (optional) + """ + import weakref + + context = kwargs.pop("context", None) + model_id = kwargs.pop("model_id") + worker_idx = kwargs.pop("worker_idx", None) + + if context is None: + raise ValueError( + "RPCWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Create receiver instance + receiver = self._receiver_cls(self) + receiver._model_id = model_id + receiver._worker_idx = worker_idx + + # Attach context so we can resolve string model refs like "policy" + receiver._context_ref = weakref.ref(context) + + # Resolve the target model on this worker + from torchrl.weight_update.utils import _resolve_model + + model = _resolve_model(context, model_id) + receiver._register_model(model) + + # Note: For RPC, we don't create a transport on the receiver side + # The receiver just needs to call recv() when signaled + receiver._transport = None + + # Store receiver on scheme so get_receiver() works as expected + self._receiver = receiver + + def create_transport( + self, + *, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + **kwargs, + ) -> TransportBackend: """Create RPC-based transport for a specific remote collector. Args: - pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) - for the remote collector. + collector_info: RPC worker info for the remote collector. + collector_rref: RPC remote reference to the collector. + collector_class: Class of the remote collector. + worker_rank: The torch.distributed rank of the remote worker. + **kwargs: Additional transport configuration. Returns: RPCTransport configured for this specific remote collector. 
""" - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: - collector_info, collector_rref, collector_class = pipe_or_context - return RPCTransport( - collector_info=collector_info, - collector_rref=collector_rref, + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + worker_rank=worker_rank, + ) + + def _init_on_sender_impl(self, *args, **kwargs): + model_id = kwargs["model_id"] + num_workers = kwargs["num_workers"] + collector_infos = kwargs["collector_infos"] + collector_rrefs = kwargs["collector_rrefs"] + collector_class = kwargs["collector_class"] + context = kwargs["context"] + + sender = self.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + # worker_rank is i+1 because rank 0 is the main/trainer process + for i in range(num_workers): + worker_rank = i + 1 + transport = self.create_transport( + collector_info=collector_infos[i], + collector_rref=collector_rrefs[i], collector_class=collector_class, + worker_rank=worker_rank, ) - # If just passed the info directly - return RPCTransport(collector_info=pipe_or_context) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "_set_context"): + sender._set_context(context, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(context, "policy") + and context.policy is not None + ): + sender._source_model = context.policy + else: + sender._source_model = _resolve_model(context, model_id) class RPCTransport: """RPC transport for communicating with a single RPC remote collector. This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. + torch.distributed primitives (send/recv) with RPC used for signaling. + Multiple transports are created for multiple collectors, following the same + pattern as the DistributedDataCollector. """ - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): + def __init__( + self, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + ): self._collector_info = collector_info self._collector_rref = collector_rref self._collector_class = collector_class + self._worker_rank = worker_rank # The torch.distributed rank of this worker + self._pending_future = None + self._pending_send = None def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via RPC.""" + """Send weights to the remote collector using torch.distributed. + + Uses torch.distributed.send() for the actual weight transfer and RPC + for signaling the remote collector to receive. + + Order is critical to avoid deadlock: + 1. Signal receiver via RPC to start recv() (non-blocking) + 2. 
Send weights via torch.distributed (blocking until recv completes)
+        """
         if self._collector_info is None or self._collector_rref is None:
             return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
 
-        from torch.distributed import rpc
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        future = self._collector_rref.rpc_async()._receive_weights_scheme()
 
-        # Send weights to the remote collector and wait for completion
-        rpc.rpc_sync(
-            self._collector_info,
-            self._collector_class.update_policy_weights_,
-            args=(self._collector_rref, weights),
-        )
+        # Step 2: Send weights via torch.distributed (blocks until receiver calls recv())
+        weights.send(self._worker_rank)
+
+        # Step 3: Wait for RPC to complete (receiver has applied weights)
+        future.wait()
 
     def send_weights_async(self, weights: Any) -> None:
-        """Send weights to remote collector without waiting for completion.
+        """Send weights to remote collector asynchronously.
+
+        Uses torch.distributed.isend() for the actual weight transfer and RPC
+        for signaling. Use wait_ack() to wait for completion.
 
-        Use wait_ack() to wait for completion after sending to all workers.
+        Order is critical to avoid deadlock:
+        1. Signal receiver via RPC to start recv() (non-blocking)
+        2. Send weights via torch.distributed.isend() (non-blocking)
+        3. wait_ack() waits for both to complete
         """
         if self._collector_info is None or self._collector_rref is None:
             return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
 
-        from torch.distributed import rpc
-
-        # Send weights asynchronously
-        self._pending_future = rpc.rpc_async(
-            self._collector_info,
-            self._collector_class.update_policy_weights_,
-            args=(self._collector_rref, weights),
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        self._pending_future = (
+            self._collector_rref.rpc_async()._receive_weights_scheme()
         )
 
+        # Step 2: Send weights asynchronously via torch.distributed
+        # No handle kept: the RPC future awaited in wait_ack() implies the recv completed
+        weights.isend(self._worker_rank)
+
     def wait_ack(self) -> None:
-        """Wait for the RPC call to complete."""
-        if hasattr(self, "_pending_future"):
+        """Wait for both the RPC call and the distributed send to complete."""
+        # Wait for the RPC call to complete (this also implies the recv has finished)
+        if hasattr(self, "_pending_future") and self._pending_future is not None:
             self._pending_future.wait()
             del self._pending_future
 
     def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None:
-        """RPC workers typically don't receive weights through this transport."""
+        """Receive weights from sender using torch.distributed.recv()."""
+        # In RPC, we don't typically call this directly - instead, the receiver
+        # scheme's receive() method should handle the recv() call.
+        # This is here for completeness but may not be used in the RPC pattern.
return None def check_connection(self) -> bool: - """Check if RPC is initialized.""" + """Check if both RPC and torch.distributed are initialized.""" + import torch.distributed from torch.distributed import rpc - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + rpc_initialized = ( + rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + ) + dist_initialized = torch.distributed.is_initialized() + return rpc_initialized and dist_initialized def synchronize_weights_on_sender(self) -> None: """No-op for RPCTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RPCTransport - weights are received via RPC calls.""" + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via receive().""" return None - - -class RPCWeightReceiver(WeightReceiver): - """Weight receiver for RPC-based distributed systems. - - Receives weight updates from the main process via torch.distributed.rpc. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ - - -class RPCWeightSender(WeightSender): - """Weight sender for RPC-based distributed systems. - - Sends weight updates to remote collectors via torch.distributed.rpc calls. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index d12292c95ba..790182e80dc 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -2,7 +2,7 @@ import weakref from collections.abc import Callable -from typing import Any, overload +from typing import Any import torch import torch.distributed @@ -71,7 +71,7 @@ def synchronize_weights_on_sender(self) -> None: weights = self._params_map[worker_idx] queue.put(weights) - def synchronize_weights_on_worker( + def synchronize_weights_on_receiver( self, worker_idx: int, timeout: float = 10.0 ) -> TensorDictBase: """Receive shared memory buffer reference from sender via their per-worker queues. @@ -137,6 +137,30 @@ def check_connection(self) -> bool: return True +class SharedMemWeightReceiver(WeightReceiver): + """Weight receiver for shared memory systems. + + Receives weight updates via shared memory buffers. Workers automatically + see weight updates without explicit message passing, providing zero-copy + weight synchronization. This is typically instantiated and managed by + :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None + + +class SharedMemWeightSender(WeightSender): + """Weight sender for shared memory systems. + + Sends weight updates by writing directly to shared memory buffers. + All workers automatically see updates without explicit communication, + providing zero-copy weight synchronization. This is typically instantiated + and managed by :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None + + class SharedMemWeightSyncScheme(WeightSyncScheme): """Weight synchronization using shared memory. 
@@ -152,6 +176,9 @@ class SharedMemWeightSyncScheme(WeightSyncScheme): >>> # Weights are initialized via init_on_sender() """ + _sender_cls = SharedMemWeightSender + _receiver_cls = SharedMemWeightReceiver + def __init__( self, strategy: str = "tensordict", @@ -283,19 +310,6 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - def synchronize_weights(self): - """Method to be called once the workers have started. - - Triggers a rendez-vous for the workers to receive their copy of the weights. - - This is a convenience method that delegates to the sender's synchronize_weights(). - """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" - ) - self._sender.synchronize_weights() - def _get_params_map( self, context: Any = None, @@ -403,25 +417,7 @@ def _get_params_map( "Either params_map, model_id + context or model/weights + devices must be provided." ) - @overload - def init_on_receiver( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - *, - model: Any, - worker_idx: int, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, *, model_id: str | None = None, @@ -466,7 +462,7 @@ def init_on_receiver( receiver._worker_idx = worker_idx self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def get_weight_queues(self): """Get the per-worker weight initialization queues. @@ -535,27 +531,3 @@ def prepare_weights( # Fall back to default behavior return super().prepare_weights(weights, model_id, strategy, context) - - -class SharedMemWeightReceiver(WeightReceiver): - """Weight receiver for shared memory systems. - - Receives weight updates via shared memory buffers. Workers automatically - see weight updates without explicit message passing, providing zero-copy - weight synchronization. This is typically instantiated and managed by - :class:`SharedMemWeightSyncScheme`. - """ - - _transport: SharedMemTransport | None - - -class SharedMemWeightSender(WeightSender): - """Weight sender for shared memory systems. - - Sends weight updates by writing directly to shared memory buffers. - All workers automatically see updates without explicit communication, - providing zero-copy weight synchronization. This is typically instantiated - and managed by :class:`SharedMemWeightSyncScheme`. - """ - - _transport: SharedMemTransport | None diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 735c9e59804..4842aca7f79 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -187,13 +187,11 @@ def __init__( self.num_threads = num_threads self.strategy_name = strategy - def create_transport( - self, pipe_or_context: Any = None - ) -> VLLMDoubleBufferTransport: + def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport: """Create transport for double-buffered storage. Args: - pipe_or_context: Not used for file-based transport (kept for API compatibility). + **kwargs: Not used for file-based transport (kept for API compatibility). Returns: A VLLMDoubleBufferTransport instance. 
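The keyword-only create_transport(**kwargs) signature used above (and declared abstract on the base class further down) is the contract a custom scheme/transport pair would implement. A hypothetical sketch follows: only the method names and signatures come from this patch; `MailboxTransport`, `MailboxWeightSyncScheme`, the `mailbox` dict and the hard-coded "policy" model id are invented for illustration.

from __future__ import annotations

from typing import Any

from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme


class MailboxTransport:
    """Toy transport: a plain dict stands in for a real channel (pipe, queue, RPC, ...)."""

    def __init__(self, *, mailbox: dict, worker_idx: int):
        self._mailbox = mailbox
        self._worker_idx = worker_idx

    def send_weights(self, weights: Any) -> None:
        # Store a (model_id, weights) pair, mirroring what receivers expect to unpack.
        self._mailbox[self._worker_idx] = ("policy", weights)

    def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None:
        return self._mailbox.pop(self._worker_idx, None)

    def check_connection(self) -> bool:
        return True

    def synchronize_weights_on_sender(self) -> None:
        """Nothing to pre-stage for this toy transport."""

    def synchronize_weights_on_receiver(self, worker_idx: int) -> Any:
        """Initial weights arrive through receive_weights(); nothing to do here."""
        return None


class MailboxWeightSyncScheme(WeightSyncScheme):
    """Hypothetical scheme illustrating the keyword-only create_transport() contract."""

    def create_transport(self, *, mailbox: dict, worker_idx: int = 0, **kwargs) -> MailboxTransport:
        return MailboxTransport(mailbox=mailbox, worker_idx=worker_idx)


scheme = MailboxWeightSyncScheme()
transport = scheme.create_transport(mailbox={}, worker_idx=0)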
diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index f57883e5cd8..ed5e969f4b4 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -441,7 +441,7 @@ def __init__( s.bind(("", 0)) self.master_port = s.getsockname()[1] - def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: + def create_transport(self, **kwargs) -> VLLMCollectiveTransport: """Create transport for collective communication. For vLLM, this creates a transport but requires additional setup via init_all_workers_group(). @@ -449,7 +449,7 @@ def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: is more complex and typically handled by sender/receiver initialization. Args: - pipe_or_context: Not used for vLLM (kept for API compatibility). + **kwargs: Not used for vLLM (kept for API compatibility). Returns: A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called). diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 13a11b7b24b..22a0b6dbf6c 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -14,6 +14,7 @@ from tensordict import TensorDict, TensorDictBase from torch import nn +from torchrl._utils import logger as torchrl_logger __all__ = [ "TransportBackend", @@ -23,6 +24,7 @@ "WeightSyncScheme", ] +from torchrl.collectors.utils import _cast from torchrl.weight_update.utils import _resolve_model @@ -55,7 +57,7 @@ def synchronize_weights_on_sender(self) -> None: """ ... - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """Synchronize weights on worker side before collection starts. This is called once in each worker after initialization to receive @@ -257,18 +259,25 @@ def _set_context(self, context: Any, model_id: str | None = None) -> None: if model_id is not None: self._model_id = model_id - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + def _register_worker( + self, worker_idx: int, pipe_or_context: Any = None, **transport_kwargs + ) -> None: """Register a worker's communication pipe (internal). This is now handled by init_on_sender(). Only kept for internal use. Args: worker_idx: The worker index. - pipe_or_context: The pipe connection for this worker. + pipe_or_context: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration. 
""" if worker_idx not in self._transports: + # Support legacy pipe_or_context for backward compatibility + if pipe_or_context is not None and not transport_kwargs: + # Legacy mode: try to infer kwargs from pipe_or_context + transport_kwargs = {"pipe": pipe_or_context} self._transports[worker_idx] = self._scheme.create_transport( - pipe_or_context + **transport_kwargs ) def _iterate_transports( @@ -328,6 +337,7 @@ def send( context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights + torchrl_logger.debug("Preparing weights") prepared_weights = self._scheme.prepare_weights( weights=weights, model_id=self._model_id, @@ -337,15 +347,22 @@ def send( transports = list(self._iterate_transports(worker_ids)) + if not transports: + raise RuntimeError("No transports available.") + # Send to all workers first (non-blocking if transport supports it) + torchrl_logger.debug(f"Sending over transports {transports}") for transport in transports: if hasattr(transport, "send_weights_async"): + torchrl_logger.debug(f"Sending through {transport} asynchronously.") transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send + torchrl_logger.debug(f"Sending through {transport} synchronously.") transport.send_weights(prepared_weights) # Wait for all acknowledgments + torchrl_logger.debug("Waiting for acknowledgement") for transport in transports: if hasattr(transport, "wait_ack"): transport.wait_ack() @@ -417,7 +434,7 @@ def wait_async(self) -> None: self._pending_async = False self._pending_transports = None - def synchronize_weights(self) -> None: + def synchronize_weights(self, worker_idx: int | None = None) -> None: """Synchronize weights with workers before collection starts. This method is called once after workers are initialized to send @@ -429,7 +446,9 @@ def synchronize_weights(self) -> None: update weights. """ # For other schemes (SharedMemWeightSyncScheme, etc.), use transport's method - for transport in self._iterate_transports(): + for idx, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and idx != worker_idx: + continue transport.synchronize_weights_on_sender() def update_weights(self, weights: Any) -> None: @@ -495,15 +514,19 @@ def _register_model(self, model_ref: Any) -> None: """ self._model_ref = model_ref - def _register_worker_transport(self, pipe: Any) -> None: + def _register_worker_transport(self, pipe: Any = None, **transport_kwargs) -> None: """Register this worker's communication pipe (internal). This is now handled by init_on_receiver(). Only kept for internal use. Args: - pipe: The pipe connection for this worker. + pipe: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration. """ - self._transport = self._scheme.create_transport(pipe) + # Support legacy pipe parameter for backward compatibility + if pipe is not None and not transport_kwargs: + transport_kwargs = {"pipe": pipe} + self._transport = self._scheme.create_transport(**transport_kwargs) def receive(self, timeout: float = 0.001) -> bool: """Check for and apply new weights (non-blocking). 
@@ -527,6 +550,7 @@ def receive(self, timeout: float = 0.001) -> bool: return False # Try to receive weights + torchrl_logger.debug(f"Calling receive_weights on transport {self._transport}") result = self._transport.receive_weights(timeout=timeout) if result is None: return False @@ -538,10 +562,12 @@ def receive(self, timeout: float = 0.001) -> bool: raise ValueError("No model registered") model = self._resolve_model_ref() + torchrl_logger.debug(f"Applying {weights=} on {model=}") self._strategy.apply_weights(model, weights) # Send acknowledgment if transport supports it if hasattr(self._transport, "send_ack"): + torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") self._transport.send_ack("updated") return True @@ -569,7 +595,7 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: worker_idx = getattr(self, "_worker_idx", None) # Call transport's synchronize method if available - weights = self._transport.synchronize_weights_on_worker(worker_idx) + weights = self._transport.synchronize_weights_on_receiver(worker_idx) # Apply weights to model if received (SharedMemTransport case) # For other transports (MPTransport, etc.), weights is None and synchronization @@ -635,12 +661,15 @@ class WeightSyncScheme(metaclass=abc.ABCMeta): The collector maintains a dict of {model_id: scheme} pairs. """ + _receiver_cls = WeightReceiver + _sender_cls = WeightSender + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): self.strategy = strategy self._sender = None self._receiver = None self._initialized_on_sender = False - self._initialized_on_worker = False + self._initialized_on_receiver = False @overload def init_on_sender( @@ -737,8 +766,8 @@ def init_on_sender( This method is called once in the collector's _run_processes() method, after workers have been started and are ready to receive messages. """ - result = self._init_on_sender_impl(*args, **kwargs) self._initialized_on_sender = True + result = self._init_on_sender_impl(*args, **kwargs) return result def _init_on_sender_impl(self, *args, **kwargs): @@ -748,9 +777,45 @@ def _init_on_sender_impl(self, *args, **kwargs): def initialized_on_sender(self): return getattr(self, "_initialized_on_sender", False) + @property + def initialized_on_receiver(self): + return getattr(self, "_initialized_on_receiver", False) + + def apply_weights(self, weights: TensorDictBase) -> None: + """Apply weights to the model.""" + if not self.initialized_on_receiver: + if self.initialized_on_sender: + raise RuntimeError("apply_weights() called on a sender side.") + raise RuntimeError( + "apply_weights() called before init_on_receiver has been called." + ) + return self._receiver.apply_weights(weights) + + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload def init_on_receiver( self, model_id: str, + context: None = None, + *, + worker_idx: int = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( + self, + *, + model_id: str, context: Any = None, **kwargs, ) -> None: @@ -765,8 +830,70 @@ def init_on_receiver( - .get_model(model_id: str) -> nn.Module **kwargs: Alternative to context (pipe, model, etc.) 
""" + self._initialized_on_receiver = True + result = self._init_on_receiver_impl( + model_id=model_id, context=context, **kwargs + ) + return result + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: raise NotImplementedError + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: + if isinstance(model, torch.nn.Module): + td = TensorDict.from_module(model) + td = td.data.apply(_cast, td) + return td + # Return an empty TD + return TensorDict() + + def synchronize_weights(self, worker_idx: int | None = None) -> None: + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's or receiver synchronize_weights(). + """ + if self._initialized_on_sender: + self.synchronized_on_sender = True + if self._sender is None: + raise RuntimeError( + "self._sender is None. Check that init_on_sender() has been called." + ) + self._sender.synchronize_weights(worker_idx=worker_idx) + elif self._initialized_on_receiver: + self.synchronized_on_receiver = True + if self._receiver is None: + raise RuntimeError( + "self._receiver is None. Check that init_on_receiver() has been called." + ) + self._receiver.synchronize_weights(worker_idx=worker_idx) + else: + raise RuntimeError( + "Neither init_on_sender nor init_on_receiver have abeen called." + ) + + @property + def synchronized_on_sender(self): + return getattr(self, "_synchronized_on_sender", False) + + @synchronized_on_sender.setter + def synchronized_on_sender(self, value: bool): + self._synchronized_on_sender = value + + @property + def synchronized_on_receiver(self): + return getattr(self, "_synchronized_on_receiver", False) + + @synchronized_on_receiver.setter + def synchronized_on_receiver(self, value: bool): + self._synchronized_on_receiver = value + def get_sender(self) -> WeightSender: """Get the sender instance. @@ -791,7 +918,7 @@ def get_receiver(self) -> WeightReceiver: Raises: RuntimeError: If init_on_receiver() hasn't been called yet """ - if not self._initialized_on_worker or self._receiver is None: + if not self._initialized_on_receiver or self._receiver is None: raise RuntimeError( f"Must call init_on_receiver() before get_receiver() on {type(self).__name__}" ) @@ -809,7 +936,7 @@ def __getstate__(self): state["_sender"] = None state["_receiver"] = None state["_initialized_on_sender"] = False - state["_initialized_on_worker"] = False + state["_initialized_on_receiver"] = False return state def __setstate__(self, state): @@ -817,11 +944,11 @@ def __setstate__(self, state): self.__dict__.update(state) @abc.abstractmethod - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create transport for communication. Args: - pipe_or_context: Either a pipe connection or context object to extract pipe from. + **kwargs: Transport-specific configuration parameters. Returns: A transport backend instance. @@ -840,7 +967,8 @@ def create_sender(self) -> WeightSender: Note: Typically you should use init_on_sender() followed by get_sender() instead. """ - return WeightSender(self) + self._sender = self._sender_cls(self) + return self._sender def create_receiver(self) -> WeightReceiver: """Create a receiver for this scheme. 
@@ -851,7 +979,8 @@ def create_receiver(self) -> WeightReceiver:
         """
-        return WeightReceiver(self)
+        self._receiver = self._receiver_cls(self)
+        return self._receiver
 
     def prepare_weights(
         self,
@@ -901,3 +1030,24 @@ def prepare_weights(
         else:
             # Already extracted weights (TensorDict, dict, etc.)
             return weights
+
+    def send(
+        self,
+        weights: Any = None,
+        worker_ids: int | list[int] | None = None,
+    ) -> Any:
+        """Send the given weights to specified workers.
+
+        Args:
+            weights: Weights to send (None to extract from source model)
+            worker_ids: Worker IDs to send to (None for all workers)
+        """
+        if not self.initialized_on_sender:
+            raise RuntimeError("Sender must be initialized before sending weights")
+        self._sender.send(weights=weights, worker_ids=worker_ids)
+
+    def receive(self) -> Any:
+        """Receive and apply pending weights on the receiver side."""
+        if not self.initialized_on_receiver:
+            raise RuntimeError("Receiver must be initialized before receiving weights")
+        return self._receiver.receive()
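To round off the scheme-level API added above, the non-blocking update path on the sender pairs send_async() with a mandatory wait_async() before the next update. A short hedged sketch; `scheme` is assumed to already be initialized on the sender side (init_on_sender, as in the earlier example) and `new_weights` to be a TensorDict of fresh parameters, both illustrative.

# Assumptions: `scheme` initialized via init_on_sender(); `new_weights` is a TensorDict.
sender = scheme.get_sender()

# Non-blocking path: enqueue the update, overlap other work, then block on the ACKs.
sender.send_async(weights=new_weights)
# ... e.g. run the next optimization step here ...
sender.wait_async()  # required before the next send() / send_async()

# Blocking alternative, restricted to a subset of workers.
sender.send(weights=new_weights, worker_ids=[0, 1])

# Convenience wrappers added on the scheme itself:
scheme.send(weights=new_weights)  # sender side
scheme.receive()                  # receiver side, only valid after init_on_receiver()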