Skip to content

Commit 50fb5ba

Browse files
committed
Update
[ghstack-poisoned]
2 parents 113917d + f2402c1 commit 50fb5ba

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

torchrl/modules/inference_server/_monarch.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,22 @@ def __init__(self, client: _MonarchInferenceClient, req_id: int):
2929
self._req_id = req_id
3030
self._result: Any = _SENTINEL
3131

32+
def done(self) -> bool:
33+
"""Return ``True`` if the result is available without blocking."""
34+
if self._result is not _SENTINEL:
35+
return True
36+
try:
37+
self._result = self._client._get_result(self._req_id, timeout=0)
38+
except queue.Empty:
39+
return False
40+
return True
41+
3242
def result(self, timeout: float | None = None) -> TensorDictBase:
3343
"""Block until the result is available."""
3444
if self._result is _SENTINEL:
35-
item = self._client._get_result(self._req_id, timeout=timeout)
36-
if isinstance(item, BaseException):
37-
raise item
38-
self._result = item
45+
self._result = self._client._get_result(self._req_id, timeout=timeout)
46+
if isinstance(self._result, BaseException):
47+
raise self._result
3948
return self._result
4049

4150

torchrl/modules/inference_server/_mp.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ def __init__(self, client: _MPInferenceClient, req_id: int):
3333
self._req_id = req_id
3434
self._result: Any = _SENTINEL
3535

36+
def done(self) -> bool:
37+
"""Return ``True`` if the result is available without blocking."""
38+
if self._result is not _SENTINEL:
39+
return True
40+
try:
41+
self._result = self._client._get_result(self._req_id, timeout=0)
42+
except queue.Empty:
43+
return False
44+
return True
45+
3646
def result(self, timeout: float | None = None) -> TensorDictBase:
3747
"""Block until the result is available.
3848
@@ -44,10 +54,9 @@ def result(self, timeout: float | None = None) -> TensorDictBase:
4454
Exception: if the server set an exception instead of a result.
4555
"""
4656
if self._result is _SENTINEL:
47-
item = self._client._get_result(self._req_id, timeout=timeout)
48-
if isinstance(item, BaseException):
49-
raise item
50-
self._result = item
57+
self._result = self._client._get_result(self._req_id, timeout=timeout)
58+
if isinstance(self._result, BaseException):
59+
raise self._result
5160
return self._result
5261

5362

torchrl/modules/inference_server/_ray.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,22 @@ def __init__(self, client: _RayInferenceClient, req_id: int):
3232
self._req_id = req_id
3333
self._result: Any = _SENTINEL
3434

35+
def done(self) -> bool:
36+
"""Return ``True`` if the result is available without blocking."""
37+
if self._result is not _SENTINEL:
38+
return True
39+
try:
40+
self._result = self._client._get_result(self._req_id, timeout=0)
41+
except queue.Empty:
42+
return False
43+
return True
44+
3545
def result(self, timeout: float | None = None) -> TensorDictBase:
3646
"""Block until the result is available."""
3747
if self._result is _SENTINEL:
38-
item = self._client._get_result(self._req_id, timeout=timeout)
39-
if isinstance(item, BaseException):
40-
raise item
41-
self._result = item
48+
self._result = self._client._get_result(self._req_id, timeout=timeout)
49+
if isinstance(self._result, BaseException):
50+
raise self._result
4251
return self._result
4352

4453

0 commit comments

Comments
 (0)