 import subprocess
 import sys
 import time
+from contextlib import nullcontext
 from unittest.mock import patch
 
 import numpy as np
 import pytest
 import torch
+
+import torchrl.collectors._runner
 from packaging import version
 from tensordict import (
     assert_allclose_td,
     TensorDictSequential,
 )
 from torch import nn
-
 from torchrl._utils import (
     _make_ordinal_device,
     _replace_last,
     SyncDataCollector,
     WeightUpdaterBase,
 )
-from torchrl.collectors.collectors import _Interruptor
+from torchrl.collectors._constants import _Interruptor
 
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
@@ -1487,12 +1489,14 @@ def env_fn(seed):
         assert_allclose_td(data10, data20)
 
     @pytest.mark.parametrize("use_async", [False, True])
-    @pytest.mark.parametrize("cudagraph", [False, True])
+    @pytest.mark.parametrize(
+        "cudagraph", [False, True] if torch.cuda.is_available() else [False]
+    )
     @pytest.mark.parametrize(
         "weight_sync_scheme",
         [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
     )
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
+    # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found")
     def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
         def create_env():
             return ContinuousActionVecMockEnv()
@@ -1509,11 +1513,12 @@ def create_env():
         kwargs = {}
         if weight_sync_scheme is not None:
             kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
         collector = collector_class(
             [create_env] * 3,
             policy=policy,
-            device=[torch.device("cuda:0")] * 3,
-            storing_device=[torch.device("cuda:0")] * 3,
+            device=[torch.device(device)] * 3,
+            storing_device=[torch.device(device)] * 3,
             frames_per_batch=20,
             cat_results="stack",
             cudagraph_policy=cudagraph,
@@ -1544,7 +1549,9 @@ def create_env():
         # check they don't match
         for worker in range(3):
             for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-                with pytest.raises(AssertionError):
+                with pytest.raises(
+                    AssertionError
+                ) if torch.cuda.is_available() else nullcontext():
                     torch.testing.assert_close(
                         state_dict[f"worker{worker}"]["policy_state_dict"][k],
                         policy_state_dict[k].cpu(),
@@ -2401,7 +2408,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs):
         policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
         with pytest.raises(
             TypeError,
-            match=("Arguments to policy.forward are incompatible with entries in"),
+            match=(
+                "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
+            ),
         ):
             collector_class(
                 **self._create_collector_kwargs(
@@ -2980,6 +2989,94 @@ def test_param_sync_mixed_device(
         col.shutdown()
         del col
 
+    @pytest.mark.skipif(
+        not torch.cuda.is_available() or torch.cuda.device_count() < 3,
+        reason="requires at least 3 CUDA devices",
+    )
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme],
+    )
+    def test_shared_device_weight_update(self, weight_sync_scheme):
+        """Test that weight updates work correctly when multiple workers share the same device.
+
+        This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme.
+        When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through
+        dedicated queues, preventing race conditions that could occur with a single shared queue.
+        """
+        # Create policy on cuda:0
+        policy = TensorDictModule(
+            nn.Linear(7, 7, device="cuda:0"),
+            in_keys=["observation"],
+            out_keys=["action"],
+        )
+
+        def make_env():
+            return ContinuousActionVecMockEnv()
+
+        # Create collector with workers on cuda:2, cuda:1, cuda:2
+        # Workers 0 and 2 share cuda:2 - this is the key test case
+        collector = MultiaSyncDataCollector(
+            [make_env, make_env, make_env],
+            policy=policy,
+            frames_per_batch=30,
+            total_frames=300,
+            device=["cuda:2", "cuda:1", "cuda:2"],
+            storing_device=["cuda:2", "cuda:1", "cuda:2"],
+            weight_sync_schemes={"policy": weight_sync_scheme()},
+        )
+
+        try:
+            # Collect first batch to initialize workers
+            for _ in collector:
+                break
+
+            # Get initial weights
+            old_weight = policy.module.weight.data.clone()
+
+            # Modify policy weights on cuda:0
+            for p in policy.parameters():
+                p.data += torch.randn_like(p)
+
+            new_weight = policy.module.weight.data.clone()
+            assert not torch.allclose(
+                old_weight, new_weight
+            ), "Weights should have changed"
+
+            # Update weights - this should propagate to all workers via their dedicated queues
+            collector.update_policy_weights_()
+
+            # Collect more batches to ensure weights are propagated
+            for i, _ in enumerate(collector):
+                if i >= 2:
+                    break
+
+            # Get state dict from all workers
+            state_dict = collector.state_dict()
+
+            # Verify all workers have the new weights, including both workers on cuda:2
+            for worker_idx in range(3):
+                worker_key = f"worker{worker_idx}"
+                assert (
+                    "policy_state_dict" in state_dict[worker_key]
+                ), f"Worker {worker_idx} should have policy_state_dict"
+                worker_weight = state_dict[worker_key]["policy_state_dict"][
+                    "module.weight"
+                ]
+                torch.testing.assert_close(
+                    worker_weight.cpu(),
+                    new_weight.cpu(),
+                    msg=(
+                        f"Worker {worker_idx} weights don't match expected weights. "
+                        f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. "
+                        f"This test validates that the per-worker queue system correctly "
+                        f"distributes weights even when multiple workers share a device."
+                    ),
+                )
+        finally:
+            collector.shutdown()
+            del collector
+
 
 class TestAggregateReset:
     def test_aggregate_reset_to_root(self):
@@ -3176,11 +3273,11 @@ class TestLibThreading:
         reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
-        from torchrl.collectors import collectors
+        pass
 
-        _main_async_collector_saved = collectors._main_async_collector
-        collectors._main_async_collector = decorate_thread_sub_func(
-            collectors._main_async_collector, num_threads=3
+        _main_async_collector_saved = torchrl.collectors._runner._main_async_collector
+        torchrl.collectors._runner._main_async_collector = decorate_thread_sub_func(
+            torchrl.collectors._runner._main_async_collector, num_threads=3
         )
         num_threads = torch.get_num_threads()
         try:
@@ -3204,7 +3301,9 @@ def test_num_threads(self):
         except Exception:
             torchrl_logger.info("Failed to shut down collector")
         # reset vals
-        collectors._main_async_collector = _main_async_collector_saved
+        torchrl.collectors._runner._main_async_collector = (
+            _main_async_collector_saved
+        )
         torch.set_num_threads(num_threads)
 
     @pytest.mark.skipif(
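
As background for the `pytest.raises(...) if torch.cuda.is_available() else nullcontext()` change in `test_update_weights` above: wrapping an expectation in `contextlib.nullcontext` lets a test require a failure only when the relevant condition holds, and do nothing otherwise. The sketch below is a minimal, standalone illustration of that pattern, not code from this commit; the helper name `maybe_raises` is made up for the example.

    from contextlib import nullcontext

    import pytest
    import torch


    def maybe_raises(expect_mismatch: bool):
        # When a mismatch is expected, require that AssertionError is raised;
        # otherwise fall back to a no-op context manager.
        return pytest.raises(AssertionError) if expect_mismatch else nullcontext()


    def test_conditional_expectation():
        a = torch.zeros(3)
        b = torch.ones(3)
        # With expect_mismatch=True, the comparison must fail for the test to pass.
        with maybe_raises(expect_mismatch=True):
            torch.testing.assert_close(a, b)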