2222 InferenceTransport ,
2323 MPTransport ,
2424 RayTransport ,
25+ SlotTransport ,
2526 ThreadingTransport ,
2627)
2728from torchrl .modules .inference_server ._monarch import MonarchTransport
@@ -728,7 +729,7 @@ def test_basic_collection(self):
728729 frames_per_batch = frames_per_batch ,
729730 total_frames = total_frames ,
730731 max_batch_size = num_envs ,
731- backend = "threading" ,
732+ env_backend = "threading" ,
732733 )
733734 total_collected = 0
734735 for batch in collector :
@@ -746,7 +747,7 @@ def test_policy_factory(self):
746747 frames_per_batch = 10 ,
747748 total_frames = 20 ,
748749 max_batch_size = num_envs ,
749- backend = "threading" ,
750+ env_backend = "threading" ,
750751 )
751752 total_collected = 0
752753 for batch in collector :
@@ -786,7 +787,7 @@ def test_yield_completed_trajectories(self):
786787 total_frames = 30 ,
787788 yield_completed_trajectories = True ,
788789 max_batch_size = num_envs ,
789- backend = "threading" ,
790+ env_backend = "threading" ,
790791 )
791792 count = 0
792793 for batch in collector :
@@ -804,7 +805,7 @@ def test_shutdown_idempotent(self):
804805 policy = policy ,
805806 frames_per_batch = 10 ,
806807 total_frames = 10 ,
807- backend = "threading" ,
808+ env_backend = "threading" ,
808809 )
809810 # Consume one batch to start
810811 for _batch in collector :
@@ -820,7 +821,7 @@ def test_endless_collector(self):
820821 policy = policy ,
821822 frames_per_batch = 10 ,
822823 total_frames = - 1 ,
823- backend = "threading" ,
824+ env_backend = "threading" ,
824825 )
825826 collected = 0
826827 for batch in collector :
@@ -857,9 +858,126 @@ def postproc(td):
857858 frames_per_batch = 10 ,
858859 total_frames = 20 ,
859860 postproc = postproc ,
860- backend = "threading" ,
861+ env_backend = "threading" ,
861862 )
862863 for _ in collector :
863864 pass
864865 collector .shutdown ()
865866 assert called ["count" ] >= 1
867+
868+
869+ # =============================================================================
870+ # Tests: SlotTransport
871+ # =============================================================================
872+
873+
class TestSlotTransport:
    """Tests for the slot-based transport used by the inference server."""

    def test_single_request(self):
        """One client call round-trips and yields an action of the expected shape."""
        slot_transport = SlotTransport(num_slots=4)
        model = _make_policy()
        with InferenceServer(model, slot_transport, max_batch_size=4):
            slot_client = slot_transport.client()
            request = TensorDict({"observation": torch.randn(4)})
            reply = slot_client(request)
            assert "action" in reply.keys()
            assert reply["action"].shape == (2,)

    def test_concurrent_actors(self):
        """Multiple threads submit concurrently via slot clients."""
        n_actors = 4
        n_requests = 30
        slot_transport = SlotTransport(num_slots=n_actors)
        model = _make_policy()

        # One result bucket and one pre-created client per actor thread.
        collected: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
        slot_clients = [slot_transport.client() for _ in range(n_actors)]

        def run_actor(idx):
            # Each actor issues n_requests calls through its own client.
            for _ in range(n_requests):
                request = TensorDict({"observation": torch.randn(4)})
                collected[idx].append(slot_clients[idx](request))

        with InferenceServer(model, slot_transport, max_batch_size=n_actors):
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
                futures = [pool.submit(run_actor, idx) for idx in range(n_actors)]
                concurrent.futures.wait(futures)
                # Surface any exception raised inside an actor thread.
                for fut in futures:
                    fut.result()

        for per_actor in collected:
            assert len(per_actor) == n_requests
            for reply in per_actor:
                assert "action" in reply.keys()
                assert reply["action"].shape == (2,)

    def test_too_many_clients_raises(self):
        """Creating more clients than slots raises RuntimeError."""
        slot_transport = SlotTransport(num_slots=2)
        # Exhaust both available slots first.
        for _ in range(2):
            slot_transport.client()
        with pytest.raises(RuntimeError, match="slots"):
            slot_transport.client()

    def test_submit_raises(self):
        """Direct submit() on SlotTransport is not supported."""
        slot_transport = SlotTransport(num_slots=1)
        request = TensorDict({"observation": torch.randn(4)})
        with pytest.raises(NotImplementedError):
            slot_transport.submit(request)

    def test_exception_propagates(self):
        """Model exceptions propagate through SlotTransport to the caller."""

        def failing_model(td):
            raise ValueError("slot model error")

        slot_transport = SlotTransport(num_slots=1)
        with InferenceServer(failing_model, slot_transport, max_batch_size=4):
            slot_client = slot_transport.client()
            request = TensorDict({"observation": torch.randn(4)})
            with pytest.raises(ValueError, match="slot model error"):
                slot_client(request)
941+
942+
943+ # =============================================================================
944+ # Tests: min_batch_size
945+ # =============================================================================
946+
947+
class TestMinBatchSize:
    """Tests for the server's min_batch_size accumulation behavior."""

    def test_min_batch_size_accumulates(self):
        """With min_batch_size > 1, the server waits for enough items."""
        min_bs = 4
        # Record the size of every batch the server hands to collate.
        batch_sizes: list[int] = []

        def counting_collate(items):
            batch_sizes.append(len(items))
            return lazy_stack(items)

        transport = ThreadingTransport()
        model = _make_policy()
        total_requests = 8

        with InferenceServer(
            model,
            transport,
            max_batch_size=16,
            min_batch_size=min_bs,
            collate_fn=counting_collate,
            timeout=1.0,
        ):
            client = transport.client()

            def one_request():
                return client(TensorDict({"observation": torch.randn(4)}))

            # Submit items from threads to give the server time to accumulate
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=total_requests
            ) as pool:
                futures = [pool.submit(one_request) for _ in range(total_requests)]
                for fut in futures:
                    fut.result(timeout=10.0)

        # At least one batch should have >= min_batch_size items
        assert any(size >= min_bs for size in batch_sizes)
0 commit comments