|
20 | 20 | from tensordict.utils import NestedKey
|
21 | 21 | from torch.utils._pytree import tree_map
|
22 | 22 | from torchrl._extension import EXTENSION_WARNING
|
23 | | -from torchrl._utils import _replace_last, logger
| 23 | +from torchrl._utils import _replace_last, logger, RL_WARNINGS
24 | 24 | from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
|
25 | 25 | from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
|
26 | 26 |
|
|
32 | 32 |     SumSegmentTreeFp64,
|
33 | 33 | )
|
34 | 34 | except ImportError:
|
35 | | -    logger.warning(EXTENSION_WARNING)
| 35 | +    # Make default values
| 36 | +    MinSegmentTreeFp32 = None
| 37 | +    MinSegmentTreeFp64 = None
| 38 | +    SumSegmentTreeFp32 = None
| 39 | +    SumSegmentTreeFp64 = None
36 | 40 |
|
37 | 41 | _EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
|
38 | 42 |
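This hunk replaces the import-time warning with an optional-import fallback: the segment-tree classes default to `None` when the compiled extension cannot be imported, and the check is deferred until they are actually needed. A minimal, self-contained sketch of that pattern (hypothetical names `fast_ext` and `SumTree`, not torchrl symbols):

```python
# Sketch of the optional-import pattern used in this hunk.
try:
    from fast_ext import SumTree  # compiled extension; may be absent
except ImportError:
    SumTree = None  # placeholder so the module still imports cleanly

class Sampler:
    def __init__(self, capacity: int) -> None:
        if SumTree is None:
            # Fail only when the feature is actually used, not at import time.
            raise RuntimeError("SumTree is unavailable; install the extension.")
        self._tree = SumTree(capacity)
```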
|
@@ -418,6 +422,8 @@ def __init__(
|
418 | 422 |         self.dtype = dtype

419 | 423 |         self._max_priority_within_buffer = max_priority_within_buffer

420 | 424 |         self._init()

| 425 | +        if RL_WARNINGS and SumSegmentTreeFp32 is None:
| 426 | +            logger.warning(EXTENSION_WARNING)
421 | 427 |
|
422 | 428 |     def __repr__(self):

423 | 429 |         return f"{self.__class__.__name__}(alpha={self._alpha}, beta={self._beta}, eps={self._eps}, reduction={self.reduction})"
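The new construction-time warning is gated on `RL_WARNINGS`. A hedged sketch of how it could be silenced, assuming (as the name suggests) that the flag in `torchrl._utils` is read once from an environment variable called `RL_WARNINGS` at import time:

```python
# Assumption: RL_WARNINGS is controlled by the "RL_WARNINGS" env var,
# so it must be set before torchrl is first imported.
import os

os.environ["RL_WARNINGS"] = "0"

from torchrl.data.replay_buffers.samplers import PrioritizedSampler  # noqa: E402
```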
|
@@ -450,6 +456,22 @@ def __getstate__(self):
|
450 | 456 |         return super().__getstate__()
|
451 | 457 |
|
452 | 458 |     def _init(self) -> None:

| 459 | +        if SumSegmentTreeFp32 is None:
| 460 | +            raise RuntimeError(
| 461 | +                "SumSegmentTreeFp32 is not available. See warning above."
| 462 | +            )
| 463 | +        if MinSegmentTreeFp32 is None:
| 464 | +            raise RuntimeError(
| 465 | +                "MinSegmentTreeFp32 is not available. See warning above."
| 466 | +            )
| 467 | +        if SumSegmentTreeFp64 is None:
| 468 | +            raise RuntimeError(
| 469 | +                "SumSegmentTreeFp64 is not available. See warning above."
| 470 | +            )
| 471 | +        if MinSegmentTreeFp64 is None:
| 472 | +            raise RuntimeError(
| 473 | +                "MinSegmentTreeFp64 is not available. See warning above."
| 474 | +            )
453 | 475 |         if self.dtype in (torch.float, torch.FloatType, torch.float32):

454 | 476 |             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)

455 | 477 |             self._min_tree = MinSegmentTreeFp32(self._max_capacity)
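With these guards, a missing extension surfaces when the sampler is built rather than at import time. A hedged usage sketch of that failure mode (keyword names assumed from this diff and the surrounding file, values illustrative):

```python
# When the C++ extension is absent, the segment-tree classes are None,
# so _init() raises as soon as a PrioritizedSampler is constructed.
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

try:
    sampler = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.9)
except RuntimeError as err:
    print(err)  # e.g. "SumSegmentTreeFp32 is not available. See warning above."
```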
|
|