2222 InferenceTransport ,
2323 MPTransport ,
2424 RayTransport ,
25+ SlotTransport ,
2526 ThreadingTransport ,
2627)
2728from torchrl .modules .inference_server ._monarch import MonarchTransport
@@ -728,7 +729,7 @@ def test_basic_collection(self):
728729 frames_per_batch = frames_per_batch ,
729730 total_frames = total_frames ,
730731 max_batch_size = num_envs ,
731- backend = "threading" ,
732+ env_backend = "threading" ,
732733 )
733734 total_collected = 0
734735 for batch in collector :
@@ -746,7 +747,7 @@ def test_policy_factory(self):
746747 frames_per_batch = 10 ,
747748 total_frames = 20 ,
748749 max_batch_size = num_envs ,
749- backend = "threading" ,
750+ env_backend = "threading" ,
750751 )
751752 total_collected = 0
752753 for batch in collector :
@@ -786,7 +787,7 @@ def test_yield_completed_trajectories(self):
786787 total_frames = 30 ,
787788 yield_completed_trajectories = True ,
788789 max_batch_size = num_envs ,
789- backend = "threading" ,
790+ env_backend = "threading" ,
790791 )
791792 count = 0
792793 for batch in collector :
@@ -804,7 +805,7 @@ def test_shutdown_idempotent(self):
804805 policy = policy ,
805806 frames_per_batch = 10 ,
806807 total_frames = 10 ,
807- backend = "threading" ,
808+ env_backend = "threading" ,
808809 )
809810 # Consume one batch to start
810811 for _batch in collector :
@@ -820,7 +821,7 @@ def test_endless_collector(self):
820821 policy = policy ,
821822 frames_per_batch = 10 ,
822823 total_frames = - 1 ,
823- backend = "threading" ,
824+ env_backend = "threading" ,
824825 )
825826 collected = 0
826827 for batch in collector :
@@ -857,9 +858,191 @@ def postproc(td):
857858 frames_per_batch = 10 ,
858859 total_frames = 20 ,
859860 postproc = postproc ,
860- backend = "threading" ,
861+ env_backend = "threading" ,
861862 )
862863 for _ in collector :
863864 pass
864865 collector .shutdown ()
865866 assert called ["count" ] >= 1
867+
868+
869+ # =============================================================================
870+ # Tests: SlotTransport
871+ # =============================================================================
872+
873+
class TestSlotTransport:
    """Tests for slot-based request routing (``SlotTransport``)."""

    def test_single_request(self):
        """A single slot client round-trips one request through the server."""
        transport = SlotTransport(num_slots=4)
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            slot_client = transport.client()
            request = TensorDict({"observation": torch.randn(4)})
            reply = slot_client(request)
            assert "action" in reply.keys()
            assert reply["action"].shape == (2,)

    def test_concurrent_actors(self):
        """Multiple threads submit concurrently via slot clients."""
        n_actors = 4
        n_requests = 30
        transport = SlotTransport(num_slots=n_actors)
        policy = _make_policy()

        collected: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
        slot_clients = [transport.client() for _ in range(n_actors)]

        def run_actor(idx):
            # Each actor drives its own pre-created client and bucket.
            submit = slot_clients[idx]
            bucket = collected[idx]
            for _ in range(n_requests):
                bucket.append(submit(TensorDict({"observation": torch.randn(4)})))

        with InferenceServer(policy, transport, max_batch_size=n_actors):
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
                pending = [pool.submit(run_actor, i) for i in range(n_actors)]
                concurrent.futures.wait(pending)
                # Re-raise any actor-side exception.
                for fut in pending:
                    fut.result()

        for per_actor in collected:
            assert len(per_actor) == n_requests
            for td_out in per_actor:
                assert "action" in td_out.keys()
                assert td_out["action"].shape == (2,)

    def test_too_many_clients_raises(self):
        """Creating more clients than slots raises RuntimeError."""
        transport = SlotTransport(num_slots=2)
        # Exhaust both slots, then expect the third request to fail.
        for _ in range(2):
            transport.client()
        with pytest.raises(RuntimeError, match="slots"):
            transport.client()

    def test_submit_raises(self):
        """Direct submit() on SlotTransport is not supported."""
        transport = SlotTransport(num_slots=1)
        request = TensorDict({"observation": torch.randn(4)})
        with pytest.raises(NotImplementedError):
            transport.submit(request)

    def test_exception_propagates(self):
        """Model exceptions propagate through SlotTransport."""

        def bad_model(td):
            raise ValueError("slot model error")

        transport = SlotTransport(num_slots=1)
        with InferenceServer(bad_model, transport, max_batch_size=4):
            slot_client = transport.client()
            request = TensorDict({"observation": torch.randn(4)})
            with pytest.raises(ValueError, match="slot model error"):
                slot_client(request)
941+
942+
943+ # =============================================================================
944+ # Tests: min_batch_size
945+ # =============================================================================
946+
947+
class TestMinBatchSize:
    """Tests for the server's ``min_batch_size`` accumulation behavior."""

    def test_min_batch_size_accumulates(self):
        """With min_batch_size > 1, the server waits for enough items."""
        min_bs = 4
        n = 8
        # Record the size of every batch the server hands to the collate fn.
        seen_sizes = []

        def tracking_collate(items):
            seen_sizes.append(len(items))
            return lazy_stack(items)

        transport = ThreadingTransport()
        policy = _make_policy()

        server_kwargs = dict(
            max_batch_size=16,
            min_batch_size=min_bs,
            collate_fn=tracking_collate,
            timeout=1.0,
        )
        with InferenceServer(policy, transport, **server_kwargs):
            client = transport.client()

            def one_request():
                return client(TensorDict({"observation": torch.randn(4)}))

            # Submit items from threads to give the server time to accumulate.
            with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
                pending = [pool.submit(one_request) for _ in range(n)]
                for fut in pending:
                    fut.result(timeout=10.0)

        # At least one batch should have >= min_batch_size items.
        assert any(size >= min_bs for size in seen_sizes)
984+
985+
986+ # =============================================================================
987+ # Tests: bugfix regressions
988+ # =============================================================================
989+
990+
class TestShutdownPendingFutures:
    """Regression test: shutdown must not leave submitted futures hanging."""

    def test_shutdown_resolves_pending_futures(self):
        """Pending futures receive an exception on shutdown (no hang)."""
        transport = ThreadingTransport()
        policy = _make_policy()
        # max_batch_size is deliberately huge so the batch never fills and
        # the submitted items stay pending when shutdown is requested.
        server = InferenceServer(policy, transport, max_batch_size=1024)
        server.start()
        pending = []
        for _ in range(5):
            pending.append(transport.submit(TensorDict({"observation": torch.randn(4)})))
        time.sleep(0.05)
        server.shutdown(timeout=5.0)
        for fut in pending:
            # exception is acceptable; hanging is not
            try:
                fut.result(timeout=2.0)
            except Exception:
                pass
1009+
1010+
class TestThreadingTransportNoLostSignals:
    """Regression test: wakeup signals must survive rapid concurrent submits."""

    def test_rapid_submit_no_lost_signals(self):
        """Rapid submits from many threads don't lose signals."""
        transport = ThreadingTransport()
        policy = _make_policy()
        total = 100
        # A tiny timeout forces many server wakeups, maximizing the chance
        # of exposing a lost-signal race.
        with InferenceServer(policy, transport, max_batch_size=4, timeout=0.001):
            client = transport.client()

            def one_request():
                return client(TensorDict({"observation": torch.randn(4)}))

            with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
                pending = [pool.submit(one_request) for _ in range(total)]
                outputs = [fut.result(timeout=10.0) for fut in pending]

        assert len(outputs) == total
        for td_out in outputs:
            assert "action" in td_out.keys()
1030+
1031+
class TestWorkerCrashPropagation:
    """Regression test: worker-side model crashes surface to the caller."""

    def test_worker_crash_propagates(self):
        """If the model always fails, the collector propagates the error."""

        def bad_model(td):
            raise RuntimeError("model crash")

        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=bad_model,
            frames_per_batch=10,
            total_frames=100,
        )
        # Iterating should fail with the worker-thread error, not hang.
        with pytest.raises(RuntimeError, match="worker thread"):
            for _ in collector:
                pass
        collector.shutdown()
0 commit comments