Skip to content

Commit e98b8c0

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent b4bbd1d commit e98b8c0

File tree

5 files changed

+59
-20
lines changed

5 files changed

+59
-20
lines changed

examples/collectors/async_batched_collector.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
run many environments in parallel while automatically batching policy inference
55
through an :class:`~torchrl.modules.InferenceServer`.
66
7+
Architecture:
8+
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel using
9+
the chosen backend (``"multiprocessing"`` by default for true parallelism,
10+
or ``"threading"``/``"asyncio"``).
11+
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
12+
and runs a single forward pass.
13+
- A lightweight coordinator thread bridges the two: when an env finishes
14+
stepping, its observation is submitted to the server, and when an action is
15+
ready the env is sent back for stepping -- all without synchronisation
16+
barriers.
17+
718
The user only supplies:
819
- A list of environment factories
920
- A policy (or policy factory)
10-
11-
The collector creates the ``AsyncEnvPool``, ``InferenceServer``, and
12-
``ThreadingTransport`` internally -- no manual wiring required.
1321
"""
1422
import torch.nn as nn
1523
from tensordict.nn import TensorDictModule

test/test_inference_server.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ def test_basic_collection(self):
733733
frames_per_batch=frames_per_batch,
734734
total_frames=total_frames,
735735
max_batch_size=num_envs,
736+
env_backend="threading",
736737
)
737738
total_collected = 0
738739
for batch in collector:
@@ -750,6 +751,7 @@ def test_policy_factory(self):
750751
frames_per_batch=10,
751752
total_frames=20,
752753
max_batch_size=num_envs,
754+
env_backend="threading",
753755
)
754756
total_collected = 0
755757
for batch in collector:
@@ -789,6 +791,7 @@ def test_yield_completed_trajectories(self):
789791
total_frames=30,
790792
yield_completed_trajectories=True,
791793
max_batch_size=num_envs,
794+
env_backend="threading",
792795
)
793796
count = 0
794797
for batch in collector:
@@ -806,6 +809,7 @@ def test_shutdown_idempotent(self):
806809
policy=policy,
807810
frames_per_batch=10,
808811
total_frames=10,
812+
env_backend="threading",
809813
)
810814
# Consume one batch to start
811815
for _batch in collector:
@@ -821,6 +825,7 @@ def test_endless_collector(self):
821825
policy=policy,
822826
frames_per_batch=10,
823827
total_frames=-1,
828+
env_backend="threading",
824829
)
825830
collected = 0
826831
for batch in collector:
@@ -830,18 +835,16 @@ def test_endless_collector(self):
830835
collector.shutdown()
831836
assert collected >= 50
832837

833-
def test_env_property(self):
834-
"""The env property returns an AsyncEnvPool."""
835-
from torchrl.envs import AsyncEnvPool
836-
838+
def test_num_envs(self):
839+
"""The collector knows the number of environments."""
837840
policy = _make_counting_policy()
838841
collector = AsyncBatchedCollector(
839842
create_env_fn=[_counting_env_factory] * 2,
840843
policy=policy,
841844
frames_per_batch=10,
842845
total_frames=10,
843846
)
844-
assert isinstance(collector.env, AsyncEnvPool)
847+
assert collector._num_envs == 2
845848
collector.shutdown()
846849

847850
def test_postproc(self):
@@ -859,6 +862,7 @@ def postproc(td):
859862
frames_per_batch=10,
860863
total_frames=20,
861864
postproc=postproc,
865+
env_backend="threading",
862866
)
863867
for _ in collector:
864868
pass

torchrl/modules/inference_server/_monarch.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,22 @@ def __init__(self, client: _MonarchInferenceClient, req_id: int):
2929
self._req_id = req_id
3030
self._result: Any = _SENTINEL
3131

32+
def done(self) -> bool:
33+
"""Return ``True`` if the result is available without blocking."""
34+
if self._result is not _SENTINEL:
35+
return True
36+
try:
37+
self._result = self._client._get_result(self._req_id, timeout=0)
38+
except queue.Empty:
39+
return False
40+
return True
41+
3242
def result(self, timeout: float | None = None) -> TensorDictBase:
3343
"""Block until the result is available."""
3444
if self._result is _SENTINEL:
35-
item = self._client._get_result(self._req_id, timeout=timeout)
36-
if isinstance(item, BaseException):
37-
raise item
38-
self._result = item
45+
self._result = self._client._get_result(self._req_id, timeout=timeout)
46+
if isinstance(self._result, BaseException):
47+
raise self._result
3948
return self._result
4049

4150

torchrl/modules/inference_server/_mp.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ def __init__(self, client: _MPInferenceClient, req_id: int):
3333
self._req_id = req_id
3434
self._result: Any = _SENTINEL
3535

36+
def done(self) -> bool:
37+
"""Return ``True`` if the result is available without blocking."""
38+
if self._result is not _SENTINEL:
39+
return True
40+
try:
41+
self._result = self._client._get_result(self._req_id, timeout=0)
42+
except queue.Empty:
43+
return False
44+
return True
45+
3646
def result(self, timeout: float | None = None) -> TensorDictBase:
3747
"""Block until the result is available.
3848
@@ -44,10 +54,9 @@ def result(self, timeout: float | None = None) -> TensorDictBase:
4454
Exception: if the server set an exception instead of a result.
4555
"""
4656
if self._result is _SENTINEL:
47-
item = self._client._get_result(self._req_id, timeout=timeout)
48-
if isinstance(item, BaseException):
49-
raise item
50-
self._result = item
57+
self._result = self._client._get_result(self._req_id, timeout=timeout)
58+
if isinstance(self._result, BaseException):
59+
raise self._result
5160
return self._result
5261

5362

torchrl/modules/inference_server/_ray.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,22 @@ def __init__(self, client: _RayInferenceClient, req_id: int):
3232
self._req_id = req_id
3333
self._result: Any = _SENTINEL
3434

35+
def done(self) -> bool:
36+
"""Return ``True`` if the result is available without blocking."""
37+
if self._result is not _SENTINEL:
38+
return True
39+
try:
40+
self._result = self._client._get_result(self._req_id, timeout=0)
41+
except queue.Empty:
42+
return False
43+
return True
44+
3545
def result(self, timeout: float | None = None) -> TensorDictBase:
3646
"""Block until the result is available."""
3747
if self._result is _SENTINEL:
38-
item = self._client._get_result(self._req_id, timeout=timeout)
39-
if isinstance(item, BaseException):
40-
raise item
41-
self._result = item
48+
self._result = self._client._get_result(self._req_id, timeout=timeout)
49+
if isinstance(self._result, BaseException):
50+
raise self._result
4251
return self._result
4352

4453

0 commit comments

Comments
 (0)