Skip to content

Commit 744f061

Browse files
authored
[BE] Strong limit on warning appearance for failing C++ binaries (#3115)
1 parent 79217c8 commit 744f061

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

torchrl/data/replay_buffers/samplers.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tensordict.utils import NestedKey
2121
from torch.utils._pytree import tree_map
2222
from torchrl._extension import EXTENSION_WARNING
23-
from torchrl._utils import _replace_last, logger
23+
from torchrl._utils import _replace_last, logger, RL_WARNINGS
2424
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
2525
from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
2626

@@ -32,7 +32,11 @@
3232
SumSegmentTreeFp64,
3333
)
3434
except ImportError:
35-
logger.warning(EXTENSION_WARNING)
35+
# Make default values
36+
MinSegmentTreeFp32 = None
37+
MinSegmentTreeFp64 = None
38+
SumSegmentTreeFp32 = None
39+
SumSegmentTreeFp64 = None
3640

3741
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
3842

@@ -418,6 +422,8 @@ def __init__(
418422
self.dtype = dtype
419423
self._max_priority_within_buffer = max_priority_within_buffer
420424
self._init()
425+
if RL_WARNINGS and SumSegmentTreeFp32 is None:
426+
logger.warning(EXTENSION_WARNING)
421427

422428
def __repr__(self):
423429
return f"{self.__class__.__name__}(alpha={self._alpha}, beta={self._beta}, eps={self._eps}, reduction={self.reduction})"
@@ -450,6 +456,22 @@ def __getstate__(self):
450456
return super().__getstate__()
451457

452458
def _init(self) -> None:
459+
if SumSegmentTreeFp32 is None:
460+
raise RuntimeError(
461+
"SumSegmentTreeFp32 is not available. See warning above."
462+
)
463+
if MinSegmentTreeFp32 is None:
464+
raise RuntimeError(
465+
"MinSegmentTreeFp32 is not available. See warning above."
466+
)
467+
if SumSegmentTreeFp64 is None:
468+
raise RuntimeError(
469+
"SumSegmentTreeFp64 is not available. See warning above."
470+
)
471+
if MinSegmentTreeFp64 is None:
472+
raise RuntimeError(
473+
"MinSegmentTreeFp64 is not available. See warning above."
474+
)
453475
if self.dtype in (torch.float, torch.FloatType, torch.float32):
454476
self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
455477
self._min_tree = MinSegmentTreeFp32(self._max_capacity)

0 commit comments

Comments
 (0)