@@ -733,6 +733,7 @@ def test_basic_collection(self):
733733 frames_per_batch = frames_per_batch ,
734734 total_frames = total_frames ,
735735 max_batch_size = num_envs ,
736+ env_backend = "threading" ,
736737 )
737738 total_collected = 0
738739 for batch in collector :
@@ -750,6 +751,7 @@ def test_policy_factory(self):
750751 frames_per_batch = 10 ,
751752 total_frames = 20 ,
752753 max_batch_size = num_envs ,
754+ env_backend = "threading" ,
753755 )
754756 total_collected = 0
755757 for batch in collector :
@@ -789,6 +791,7 @@ def test_yield_completed_trajectories(self):
789791 total_frames = 30 ,
790792 yield_completed_trajectories = True ,
791793 max_batch_size = num_envs ,
794+ env_backend = "threading" ,
792795 )
793796 count = 0
794797 for batch in collector :
@@ -806,6 +809,7 @@ def test_shutdown_idempotent(self):
806809 policy = policy ,
807810 frames_per_batch = 10 ,
808811 total_frames = 10 ,
812+ env_backend = "threading" ,
809813 )
810814 # Consume one batch to start
811815 for _batch in collector :
@@ -821,6 +825,7 @@ def test_endless_collector(self):
821825 policy = policy ,
822826 frames_per_batch = 10 ,
823827 total_frames = - 1 ,
828+ env_backend = "threading" ,
824829 )
825830 collected = 0
826831 for batch in collector :
@@ -830,18 +835,16 @@ def test_endless_collector(self):
830835 collector .shutdown ()
831836 assert collected >= 50
832837
833- def test_env_property (self ):
834- """The env property returns an AsyncEnvPool."""
835- from torchrl .envs import AsyncEnvPool
836-
838+ def test_num_envs (self ):
839+ """The collector knows the number of environments."""
837840 policy = _make_counting_policy ()
838841 collector = AsyncBatchedCollector (
839842 create_env_fn = [_counting_env_factory ] * 2 ,
840843 policy = policy ,
841844 frames_per_batch = 10 ,
842845 total_frames = 10 ,
843846 )
844- assert isinstance ( collector .env , AsyncEnvPool )
847+ assert collector ._num_envs == 2
845848 collector .shutdown ()
846849
847850 def test_postproc (self ):
@@ -859,6 +862,7 @@ def postproc(td):
859862 frames_per_batch = 10 ,
860863 total_frames = 20 ,
861864 postproc = postproc ,
865+ env_backend = "threading" ,
862866 )
863867 for _ in collector :
864868 pass
0 commit comments