Skip to content

Commit d192a07

Browse files
vmoens and cursoragent committed
[Feature] AsyncBatchedCollector: backend params and performance optimizations (#3511)
- Three-tier backend system: `backend` (global default), `env_backend` (env pool override), `policy_backend` (transport override), mirroring the device parameter pattern. - Lock-free SlotTransport: per-env slots with no shared lock, replacing ThreadingTransport as the default for in-process threading. - min_batch_size parameter for InferenceServer to accumulate requests. - Batch drain from result queue (get_nowait after first blocking get). - Remove redundant .copy() in ProcessorAsyncEnvPool._env_exec. Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: 3e3cd93 Pull-Request: #3511 Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 0dca89a commit d192a07

File tree

9 files changed

+531
-40
lines changed

9 files changed

+531
-40
lines changed

benchmarks/bench_collectors.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
2. Collector (ParallelEnv x N) -- single-process, N envs in sub-procs
1313
3. MultiCollector (sync, x N) -- N sub-processes, sync delivery
1414
4. MultiCollector (async, x N) -- N sub-processes, async delivery
15-
5. AsyncBatchedCollector (threading) -- AsyncEnvPool + InferenceServer
15+
5. AsyncBatched (env=thread, pol=thread) -- threading pool + threading transport
16+
6. AsyncBatched (env=mp, pol=thread) -- multiprocessing pool + threading transport
1617
"""
1718
from __future__ import annotations
1819

@@ -368,33 +369,33 @@ def policy_factory():
368369
)
369370
)
370371

371-
# 5. AsyncBatchedCollector (threading backend)
372+
# 5. AsyncBatchedCollector (env=threading, policy=threading)
372373
results.append(
373374
bench(
374-
f"AsyncBatchedCollector threading (x{num_envs})",
375+
f"AsyncBatched env=thread pol=thread (x{num_envs})",
375376
lambda: AsyncBatchedCollector(
376377
create_env_fn=[make_env_fn] * num_envs,
377378
policy=policy_factory(),
378379
frames_per_batch=frames_per_batch,
379380
total_frames=-1,
380381
max_batch_size=num_envs,
381-
backend="threading",
382+
env_backend="threading",
382383
),
383384
target_frames=total_frames,
384385
)
385386
)
386387

387-
# 6. AsyncBatchedCollector (multiprocessing backend)
388+
# 6. AsyncBatchedCollector (env=multiprocessing, policy=threading)
388389
results.append(
389390
bench(
390-
f"AsyncBatchedCollector mp (x{num_envs})",
391+
f"AsyncBatched env=mp pol=thread (x{num_envs})",
391392
lambda: AsyncBatchedCollector(
392393
create_env_fn=[make_env_fn] * num_envs,
393394
policy=policy_factory(),
394395
frames_per_batch=frames_per_batch,
395396
total_frames=-1,
396397
max_batch_size=num_envs,
397-
backend="multiprocessing",
398+
env_backend="multiprocessing",
398399
),
399400
target_frames=total_frames,
400401
)

docs/source/reference/modules_inference_server.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Transport Backends
2828
:template: rl_template_noinherit.rst
2929

3030
ThreadingTransport
31+
SlotTransport
3132
MPTransport
3233
RayTransport
3334
MonarchTransport

examples/collectors/async_batched_collector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@
66
77
Architecture:
88
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel
9-
using the chosen backend (``"threading"`` or ``"multiprocessing"``).
9+
using the chosen ``env_backend`` (``"threading"`` or ``"multiprocessing"``).
1010
- One lightweight coordinator thread per environment owns a slot in the pool
1111
and an inference client.
1212
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
13-
and runs a single forward pass.
13+
and runs a single forward pass. The communication layer (transport) is
14+
controlled by ``policy_backend`` (``"threading"``, ``"multiprocessing"``,
15+
``"ray"``, or ``"monarch"``).
1416
- There is no global synchronisation barrier -- fast envs keep stepping
1517
while slow ones wait for inference.
1618
19+
Backend parameters:
20+
- ``backend`` -- global default for both env pool and policy transport.
21+
- ``env_backend`` -- override for the env pool (falls back to ``backend``).
22+
- ``policy_backend`` -- override for the transport (falls back to ``backend``).
23+
1724
The user only supplies:
1825
- A list of environment factories
1926
- A policy (or policy factory)

test/test_inference_server.py

Lines changed: 189 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
InferenceTransport,
2323
MPTransport,
2424
RayTransport,
25+
SlotTransport,
2526
ThreadingTransport,
2627
)
2728
from torchrl.modules.inference_server._monarch import MonarchTransport
@@ -728,7 +729,7 @@ def test_basic_collection(self):
728729
frames_per_batch=frames_per_batch,
729730
total_frames=total_frames,
730731
max_batch_size=num_envs,
731-
backend="threading",
732+
env_backend="threading",
732733
)
733734
total_collected = 0
734735
for batch in collector:
@@ -746,7 +747,7 @@ def test_policy_factory(self):
746747
frames_per_batch=10,
747748
total_frames=20,
748749
max_batch_size=num_envs,
749-
backend="threading",
750+
env_backend="threading",
750751
)
751752
total_collected = 0
752753
for batch in collector:
@@ -786,7 +787,7 @@ def test_yield_completed_trajectories(self):
786787
total_frames=30,
787788
yield_completed_trajectories=True,
788789
max_batch_size=num_envs,
789-
backend="threading",
790+
env_backend="threading",
790791
)
791792
count = 0
792793
for batch in collector:
@@ -804,7 +805,7 @@ def test_shutdown_idempotent(self):
804805
policy=policy,
805806
frames_per_batch=10,
806807
total_frames=10,
807-
backend="threading",
808+
env_backend="threading",
808809
)
809810
# Consume one batch to start
810811
for _batch in collector:
@@ -820,7 +821,7 @@ def test_endless_collector(self):
820821
policy=policy,
821822
frames_per_batch=10,
822823
total_frames=-1,
823-
backend="threading",
824+
env_backend="threading",
824825
)
825826
collected = 0
826827
for batch in collector:
@@ -857,9 +858,191 @@ def postproc(td):
857858
frames_per_batch=10,
858859
total_frames=20,
859860
postproc=postproc,
860-
backend="threading",
861+
env_backend="threading",
861862
)
862863
for _ in collector:
863864
pass
864865
collector.shutdown()
865866
assert called["count"] >= 1
867+
868+
869+
# =============================================================================
870+
# Tests: SlotTransport
871+
# =============================================================================
872+
873+
874+
class TestSlotTransport:
875+
def test_single_request(self):
876+
transport = SlotTransport(num_slots=4)
877+
policy = _make_policy()
878+
with InferenceServer(policy, transport, max_batch_size=4):
879+
client = transport.client()
880+
td = TensorDict({"observation": torch.randn(4)})
881+
result = client(td)
882+
assert "action" in result.keys()
883+
assert result["action"].shape == (2,)
884+
885+
def test_concurrent_actors(self):
886+
"""Multiple threads submit concurrently via slot clients."""
887+
n_actors = 4
888+
n_requests = 30
889+
transport = SlotTransport(num_slots=n_actors)
890+
policy = _make_policy()
891+
892+
results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
893+
clients = [transport.client() for _ in range(n_actors)]
894+
895+
def actor_fn(actor_id):
896+
for _ in range(n_requests):
897+
td = TensorDict({"observation": torch.randn(4)})
898+
result = clients[actor_id](td)
899+
results_per_actor[actor_id].append(result)
900+
901+
with InferenceServer(policy, transport, max_batch_size=n_actors):
902+
with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
903+
futs = [pool.submit(actor_fn, i) for i in range(n_actors)]
904+
concurrent.futures.wait(futs)
905+
for f in futs:
906+
f.result()
907+
908+
for actor_results in results_per_actor:
909+
assert len(actor_results) == n_requests
910+
for r in actor_results:
911+
assert "action" in r.keys()
912+
assert r["action"].shape == (2,)
913+
914+
def test_too_many_clients_raises(self):
915+
"""Creating more clients than slots raises RuntimeError."""
916+
transport = SlotTransport(num_slots=2)
917+
transport.client()
918+
transport.client()
919+
with pytest.raises(RuntimeError, match="slots"):
920+
transport.client()
921+
922+
def test_submit_raises(self):
923+
"""Direct submit() on SlotTransport is not supported."""
924+
transport = SlotTransport(num_slots=1)
925+
td = TensorDict({"observation": torch.randn(4)})
926+
with pytest.raises(NotImplementedError):
927+
transport.submit(td)
928+
929+
def test_exception_propagates(self):
930+
"""Model exceptions propagate through SlotTransport."""
931+
932+
def bad_model(td):
933+
raise ValueError("slot model error")
934+
935+
transport = SlotTransport(num_slots=1)
936+
with InferenceServer(bad_model, transport, max_batch_size=4):
937+
client = transport.client()
938+
td = TensorDict({"observation": torch.randn(4)})
939+
with pytest.raises(ValueError, match="slot model error"):
940+
client(td)
941+
942+
943+
# =============================================================================
944+
# Tests: min_batch_size
945+
# =============================================================================
946+
947+
948+
class TestMinBatchSize:
949+
def test_min_batch_size_accumulates(self):
950+
"""With min_batch_size > 1, the server waits for enough items."""
951+
min_bs = 4
952+
seen_sizes = []
953+
954+
def tracking_collate(items):
955+
seen_sizes.append(len(items))
956+
return lazy_stack(items)
957+
958+
transport = ThreadingTransport()
959+
policy = _make_policy()
960+
n = 8
961+
962+
with InferenceServer(
963+
policy,
964+
transport,
965+
max_batch_size=16,
966+
min_batch_size=min_bs,
967+
collate_fn=tracking_collate,
968+
timeout=1.0,
969+
):
970+
client = transport.client()
971+
# Submit items from threads to give the server time to accumulate
972+
with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
973+
futs = [
974+
pool.submit(
975+
lambda: client(TensorDict({"observation": torch.randn(4)}))
976+
)
977+
for _ in range(n)
978+
]
979+
for f in futs:
980+
f.result(timeout=10.0)
981+
982+
# At least one batch should have >= min_batch_size items
983+
assert any(s >= min_bs for s in seen_sizes)
984+
985+
986+
# =============================================================================
987+
# Tests: bugfix regressions
988+
# =============================================================================
989+
990+
991+
class TestShutdownPendingFutures:
992+
def test_shutdown_resolves_pending_futures(self):
993+
"""Pending futures receive an exception on shutdown (no hang)."""
994+
transport = ThreadingTransport()
995+
policy = _make_policy()
996+
server = InferenceServer(policy, transport, max_batch_size=1024)
997+
server.start()
998+
futures = [
999+
transport.submit(TensorDict({"observation": torch.randn(4)}))
1000+
for _ in range(5)
1001+
]
1002+
time.sleep(0.05)
1003+
server.shutdown(timeout=5.0)
1004+
for f in futures:
1005+
try:
1006+
f.result(timeout=2.0)
1007+
except Exception:
1008+
pass # exception is acceptable; hanging is not
1009+
1010+
1011+
class TestThreadingTransportNoLostSignals:
1012+
def test_rapid_submit_no_lost_signals(self):
1013+
"""Rapid submits from many threads don't lose signals."""
1014+
transport = ThreadingTransport()
1015+
policy = _make_policy()
1016+
n = 100
1017+
with InferenceServer(policy, transport, max_batch_size=4, timeout=0.001):
1018+
client = transport.client()
1019+
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
1020+
futs = [
1021+
pool.submit(
1022+
lambda: client(TensorDict({"observation": torch.randn(4)}))
1023+
)
1024+
for _ in range(n)
1025+
]
1026+
results = [f.result(timeout=10.0) for f in futs]
1027+
assert len(results) == n
1028+
for r in results:
1029+
assert "action" in r.keys()
1030+
1031+
1032+
class TestWorkerCrashPropagation:
1033+
def test_worker_crash_propagates(self):
1034+
"""If the model always fails, the collector propagates the error."""
1035+
1036+
def bad_model(td):
1037+
raise RuntimeError("model crash")
1038+
1039+
collector = AsyncBatchedCollector(
1040+
create_env_fn=[_counting_env_factory] * 2,
1041+
policy=bad_model,
1042+
frames_per_batch=10,
1043+
total_frames=100,
1044+
)
1045+
with pytest.raises(RuntimeError, match="worker thread"):
1046+
for _ in collector:
1047+
pass
1048+
collector.shutdown()

0 commit comments

Comments
 (0)