5 files changed, +15 -26 lines
--- File 1 of 5 (likely test/test_collectors.py, given the collectors import below) ---

@@ -2361,7 +2361,7 @@ def make_env():
 class TestLibThreading:
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
         from torchrl.collectors import collectors
@@ -2396,7 +2396,7 @@ def test_num_threads(self):

     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_auto_num_threads(self):
         init_threads = torch.get_num_threads()
--- File 2 of 5 (likely test/test_env.py, given the batched_envs import below) ---

@@ -2337,7 +2337,7 @@ def test_terminated_or_truncated_spec(self):
 class TestLibThreading:
     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_num_threads(self):
         from torchrl.envs import batched_envs
@@ -2363,18 +2363,18 @@ def test_num_threads(self):

     @pytest.mark.skipif(
         IS_OSX,
-        reason="setting different threads across workeres can randomly fail on OSX.",
+        reason="setting different threads across workers can randomly fail on OSX.",
     )
     def test_auto_num_threads(self):
         init_threads = torch.get_num_threads()

         try:
-            env3 = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
+            env3 = ParallelEnv(3, ContinuousActionVecMockEnv)
             env3.rollout(2)

             assert torch.get_num_threads() == max(1, init_threads - 3)

-            env2 = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
+            env2 = ParallelEnv(2, ContinuousActionVecMockEnv)
             env2.rollout(2)

             assert torch.get_num_threads() == max(1, init_threads - 5)
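The assertions above encode the accounting rule that each live worker claims one thread from the parent process, with a floor of one. A minimal sketch of that rule, using a hypothetical helper name (not torchrl API):

    def expected_threads(init_threads: int, live_workers: int) -> int:
        # One thread is handed to each live worker; the parent keeps at least one.
        return max(1, init_threads - live_workers)

    # With an 8-thread pool: 3 workers leave 5 threads, then 3 + 2 workers leave 3.
    assert expected_threads(8, 3) == 5
    assert expected_threads(8, 5) == 3
    # The floor kicks in once workers outnumber the pool.
    assert expected_threads(2, 5) == 1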
--- File 3 of 5 (likely torchrl/__init__.py, where _THREAD_POOL_INIT is defined) ---

@@ -51,4 +51,3 @@
 filter_warnings_subprocess = True

 _THREAD_POOL_INIT = torch.get_num_threads()
-_THREAD_POOL = torch.get_num_threads()
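Dropping the module-level mirror is the heart of this change: a cached copy of the thread count drifts as soon as anything resizes the pool directly, while _THREAD_POOL_INIT is safe to keep because it is only ever read as an upper bound. A small sketch of the drift (illustrative, not torchrl code):

    import torch

    _THREAD_POOL = torch.get_num_threads()  # cached mirror (the old approach)

    torch.set_num_threads(max(1, _THREAD_POOL - 2))  # any code may resize the pool

    # The mirror can now disagree with the live pool, so arithmetic based on
    # it (e.g. _THREAD_POOL - num_workers) over-budgets the remaining threads.
    print(_THREAD_POOL, torch.get_num_threads())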
--- File 4 of 5 (likely torchrl/collectors/collectors.py) ---

@@ -1607,18 +1607,12 @@ def _queue_len(self) -> int:

     def _run_processes(self) -> None:
         if self.num_threads is None:
-            import torchrl
-
             total_workers = self._total_workers_from_env(self.create_env_fn)
             self.num_threads = max(
-                1, torchrl._THREAD_POOL - total_workers
+                1, torch.get_num_threads() - total_workers
             )  # 1 more thread for this proc

         torch.set_num_threads(self.num_threads)
-        assert torch.get_num_threads() == self.num_threads
-        import torchrl
-
-        torchrl._THREAD_POOL = self.num_threads
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self.pipes = []
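After the change, the reserve side reads the live pool at call time instead of the cached global. Condensed into a standalone helper (the function name is illustrative):

    import torch

    def reserve_threads_for_workers(total_workers: int) -> int:
        # Budget one thread per worker out of the live pool, keeping at
        # least one for the parent process, then shrink the pool to match.
        num_threads = max(1, torch.get_num_threads() - total_workers)
        torch.set_num_threads(num_threads)
        return num_threads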
@@ -1727,11 +1721,12 @@ def _shutdown_main(self) -> None:
         finally:
             import torchrl

-            torchrl._THREAD_POOL = min(
+            num_threads = min(
                 torchrl._THREAD_POOL_INIT,
-                torchrl._THREAD_POOL + self._total_workers_from_env(self.create_env_fn),
+                torch.get_num_threads()
+                + self._total_workers_from_env(self.create_env_fn),
             )
-            torch.set_num_threads(torchrl._THREAD_POOL)
+            torch.set_num_threads(num_threads)

         for proc in self.procs:
             if proc.is_alive():
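The shutdown path is the mirror image: threads are handed back, capped at the pool size recorded at import time so that repeated start/stop cycles cannot inflate the pool. As a standalone sketch (function name illustrative; the cap corresponds to torchrl._THREAD_POOL_INIT in the diff):

    import torch

    def restore_threads_after_shutdown(total_workers: int, init_threads: int) -> None:
        # Return the workers' threads to the pool, but never grow past the
        # size observed at startup.
        num_threads = min(init_threads, torch.get_num_threads() + total_workers)
        torch.set_num_threads(num_threads)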
--- File 5 of 5 (likely torchrl/envs/batched_envs.py) ---

@@ -633,10 +633,10 @@ def close(self) -> None:
         self.is_closed = True
         import torchrl

-        torchrl._THREAD_POOL = min(
-            torchrl._THREAD_POOL_INIT, torchrl._THREAD_POOL + self.num_workers
+        num_threads = min(
+            torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
         )
-        torch.set_num_threads(torchrl._THREAD_POOL)
+        torch.set_num_threads(num_threads)

     def _shutdown_workers(self) -> None:
         raise NotImplementedError
@@ -1015,16 +1015,11 @@ def _start_workers(self) -> None:
         from torchrl.envs.env_creator import EnvCreator

         if self.num_threads is None:
-            import torchrl
-
             self.num_threads = max(
-                1, torchrl._THREAD_POOL - self.num_workers
+                1, torch.get_num_threads() - self.num_workers
             )  # 1 more thread for this proc

         torch.set_num_threads(self.num_threads)
-        import torchrl
-
-        torchrl._THREAD_POOL = self.num_threads

         ctx = mp.get_context("spawn")

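In both the collectors and the batched envs, the auto-budgeting only runs when num_threads is unset, so an explicit user-provided value is honored as-is. A condensed sketch of that dispatch (helper name illustrative):

    import torch

    def resolve_num_threads(num_threads, num_workers: int) -> int:
        # An explicit value wins; otherwise budget one thread per worker
        # out of the live pool, keeping at least one for the parent.
        if num_threads is None:
            num_threads = max(1, torch.get_num_threads() - num_workers)
        return num_threads

    torch.set_num_threads(resolve_num_threads(None, 4))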