Skip to content

Commit c226646

Browse files
authored
[BE] Avoid warning users who don't care about PRB (#3099)
1 parent 93fcb02 commit c226646

File tree

8 files changed

+9
-182
lines changed

8 files changed

+9
-182
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ __pycache__/
1414
.installed.cfg
1515
MANIFEST
1616
build/
17-
data/
17+
./data/
1818
develop-eggs/
1919
dist/
2020
downloads/

build_tools/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

build_tools/setup_helpers/__init__.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

build_tools/setup_helpers/extension.py

Lines changed: 0 additions & 161 deletions
This file was deleted.

test/_utils_internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import sys
1212
import time
1313
import unittest
14-
import warnings
1514
from functools import wraps
1615
from typing import Callable
1716

@@ -24,6 +23,7 @@
2423

2524
from torchrl._utils import (
2625
implement_for,
26+
logger,
2727
logger as torchrl_logger,
2828
RL_WARNINGS,
2929
seed_generator,
@@ -792,7 +792,7 @@ def _call_value_nets(
792792
)
793793
else:
794794
if RL_WARNINGS:
795-
warnings.warn(
795+
logger.warning(
796796
"Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
797797
"This warning can be turned off by setting the environment variable RL_WARNINGS to False."
798798
)

torchrl/_extension.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _is_nightly(version):
4848
" - make sure ninja and cmake were installed\n"
4949
" - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
5050
" - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
51-
f"setup. This package was built with PyTorch {pytorch_version}."
51+
f"setup. This package was built with PyTorch {pytorch_version}. You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
5252
)
5353

5454
else:
@@ -59,5 +59,6 @@ def _is_nightly(version):
5959
"prioritized replay buffers can only be used with the PyTorch version they were built against. "
6060
f"This package was built with PyTorch {pytorch_version}. "
6161
"Workarounds include: (1) upgrading/downgrading PyTorch or TorchRL to compatible versions, "
62-
"or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`."
62+
"or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`. "
63+
"You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
6364
)

torchrl/data/replay_buffers/samplers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919
from tensordict import MemoryMappedTensor, TensorDict
2020
from tensordict.utils import NestedKey
21-
2221
from torch.utils._pytree import tree_map
2322
from torchrl._extension import EXTENSION_WARNING
2423
from torchrl._utils import _replace_last, logger
@@ -33,7 +32,7 @@
3332
SumSegmentTreeFp64,
3433
)
3534
except ImportError:
36-
warnings.warn(EXTENSION_WARNING)
35+
logger.warning(EXTENSION_WARNING)
3736

3837
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
3938

torchrl/objectives/value/advantages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tensordict.utils import NestedKey, unravel_key
2828
from torch import Tensor
2929

30-
from torchrl._utils import RL_WARNINGS
30+
from torchrl._utils import logger, RL_WARNINGS
3131
from torchrl.envs.utils import step_mdp
3232
from torchrl.objectives.utils import (
3333
_maybe_get_or_select,
@@ -452,7 +452,7 @@ def _call_value_nets(
452452
ndim = list(data.names).index("time") + 1
453453
except ValueError:
454454
if RL_WARNINGS:
455-
warnings.warn(
455+
logger.warning(
456456
"Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
457457
"This warning can be turned off by setting the environment variable RL_WARNINGS to False."
458458
)

0 commit comments

Comments
 (0)