Skip to content

Commit 56b0b9a

Browse files
author
Vincent Moens
committed
[BugFix] Fix old deps tests
ghstack-source-id: 134d129 Pull Request resolved: #2500
1 parent 9f6c21f commit 56b0b9a

File tree

8 files changed

+58
-17
lines changed

8 files changed

+58
-17
lines changed

test/test_collector.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
MultiKeyCountingEnvPolicy,
4444
NestedCountingEnv,
4545
)
46+
from packaging import version
4647
from tensordict import (
4748
assert_allclose_td,
4849
LazyStackedTensorDict,
@@ -106,6 +107,7 @@
106107
IS_OSX = sys.platform == "darwin"
107108
PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
108109
PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
110+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
109111

110112

111113
class WrappablePolicy(nn.Module):
@@ -2654,6 +2656,9 @@ def test_dynamic_multiasync_collector(self):
26542656
assert data.names[-1] == "time"
26552657

26562658

2659+
@pytest.mark.skipif(
2660+
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
2661+
)
26572662
class TestCompile:
26582663
@pytest.mark.parametrize(
26592664
"collector_cls",
@@ -2996,8 +3001,9 @@ def __deepcopy_error__(*args, **kwargs):
29963001
raise RuntimeError("deepcopy not allowed")
29973002

29983003

2999-
@pytest.mark.filterwarnings("error")
3000-
@pytest.mark.filterwarnings("ignore:Tensordict is registered in PyTree")
3004+
@pytest.mark.filterwarnings(
3005+
"error::UserWarning", "ignore:Tensordict is registered in PyTree:UserWarning"
3006+
)
30013007
@pytest.mark.parametrize(
30023008
"collector_type",
30033009
[
@@ -3016,6 +3022,8 @@ def test_no_deepcopy_policy(collector_type):
30163022
# If the policy is not a nn.Module or has no parameter, policy_device should warn (we don't know what to do but we
30173023
# can trust that the user knows what to do).
30183024

3025+
# warnings.warn("Tensordict is registered in PyTree", category=UserWarning)
3026+
30193027
shared_device = torch.device("cpu")
30203028
if torch.cuda.is_available():
30213029
original_device = torch.device("cuda:0")

test/test_cost.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
get_default_devices,
4747
)
4848
from mocking_classes import ContinuousActionConvMockEnv
49+
from packaging import version
4950

5051
# from torchrl.data.postprocs.utils import expand_as_right
5152
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
@@ -146,7 +147,7 @@
146147
_split_and_pad_sequence,
147148
)
148149

149-
TORCH_VERSION = torch.__version__
150+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
150151

151152
# Capture all warnings
152153
pytestmark = [
@@ -15731,7 +15732,9 @@ def __init__(self):
1573115732
assert p.device == dest
1573215733

1573315734

15734-
@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5")
15735+
@pytest.mark.skipif(
15736+
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
15737+
)
1573515738
def test_exploration_compile():
1573615739
m = ProbabilisticTensorDictModule(
1573715740
in_keys=["loc", "scale"],

test/test_distributions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from _utils_internal import get_default_devices
1414
from tensordict import TensorDictBase
1515
from torch import autograd, nn
16+
from torch.utils._pytree import tree_map
1617
from torchrl.modules import (
1718
NormalParamWrapper,
1819
OneHotCategorical,
@@ -182,7 +183,7 @@ class TestTruncatedNormal:
182183
@pytest.mark.parametrize("device", get_default_devices())
183184
def test_truncnormal(self, min, max, vecs, upscale, shape, device):
184185
torch.manual_seed(0)
185-
*vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map(
186+
*vecs, min, max, vecs, upscale = tree_map(
186187
lambda t: torch.as_tensor(t, device=device),
187188
(*vecs, min, max, vecs, upscale),
188189
)

test/test_exploration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def test_consistent_dropout(self, dropout_p, parallel_spec, device):
757757

758758
# NOTE: Please only put a module with one dropout layer.
759759
# That's how this test is constructed anyways.
760-
@torch.no_grad
760+
@torch.no_grad()
761761
def inner_verify_routine(module, env):
762762
# Perform transitions.
763763
collector = SyncDataCollector(

test/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
make_dqn_actor,
5151
)
5252

53-
TORCH_VERSION = version.parse(torch.__version__)
53+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
5454
if TORCH_VERSION < version.parse("1.12.0"):
5555
UNSQUEEZE_SINGLETON = True
5656
else:

torchrl/collectors/collectors.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
)
3636
from tensordict.base import NO_DEFAULT
3737
from tensordict.nn import CudaGraphModule, TensorDictModule
38+
from tensordict.utils import Buffer
3839
from torch import multiprocessing as mp
40+
from torch.nn import Parameter
3941
from torch.utils.data import IterableDataset
4042

4143
from torchrl._utils import (
@@ -202,17 +204,17 @@ def map_weight(
202204
policy_device=policy_device,
203205
):
204206

205-
is_param = isinstance(weight, nn.Parameter)
206-
is_buffer = isinstance(weight, nn.Buffer)
207+
is_param = isinstance(weight, Parameter)
208+
is_buffer = isinstance(weight, Buffer)
207209
weight = weight.data
208210
if weight.device != policy_device:
209211
weight = weight.to(policy_device)
210212
elif weight.device.type in ("cpu", "mps"):
211213
weight = weight.share_memory_()
212214
if is_param:
213-
weight = nn.Parameter(weight, requires_grad=False)
215+
weight = Parameter(weight, requires_grad=False)
214216
elif is_buffer:
215-
weight = nn.Buffer(weight)
217+
weight = Buffer(weight)
216218
return weight
217219

218220
# Create a stateless policy, then populate this copy with params on device
@@ -3089,12 +3091,12 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
30893091

30903092

30913093
def _make_meta_params(param):
3092-
is_param = isinstance(param, nn.Parameter)
3094+
is_param = isinstance(param, Parameter)
30933095

30943096
pd = param.detach().to("meta")
30953097

30963098
if is_param:
3097-
pd = nn.Parameter(pd, requires_grad=False)
3099+
pd = Parameter(pd, requires_grad=False)
30983100
return pd
30993101

31003102

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensordict.nn.utils import _set_dispatch_td_nn_modules
3232
from tensordict.utils import expand_as_right, expand_right
3333
from torch import Tensor
34+
from torch.utils._pytree import tree_map
3435

3536
from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
3637
from torchrl.data.replay_buffers.samplers import (
@@ -319,9 +320,7 @@ def dim_extend(self, value):
319320
def _transpose(self, data):
320321
if is_tensor_collection(data):
321322
return data.transpose(self.dim_extend, 0)
322-
return torch.utils._pytree.tree_map(
323-
lambda x: x.transpose(self.dim_extend, 0), data
324-
)
323+
return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)
325324

326325
def _get_collate_fn(self, collate_fn):
327326
self._collate_fn = (

torchrl/data/replay_buffers/storages.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1367,16 +1367,44 @@ def _collate_list_tensordict(x):
13671367
return out
13681368

13691369

1370+
@implement_for("torch", "2.4")
13701371
def _stack_anything(data):
13711372
if is_tensor_collection(data[0]):
13721373
return LazyStackedTensorDict.maybe_dense_stack(data)
1373-
return torch.utils._pytree.tree_map(
1374+
return tree_map(
13741375
lambda *x: torch.stack(x),
13751376
*data,
13761377
is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x),
13771378
)
13781379

13791380

1381+
@implement_for("torch", None, "2.4")
1382+
def _stack_anything(data): # noqa: F811
1383+
from tensordict import _pytree
1384+
1385+
if not _pytree.PYTREE_REGISTERED_TDS:
1386+
raise RuntimeError(
1387+
"TensorDict is not registered within PyTree. "
1388+
"If you see this error, it means tensordicts instances cannot be natively stacked using tree_map. "
1389+
"To solve this issue, (a) upgrade pytorch to a version > 2.4, or (b) make sure TensorDict is registered in PyTree. "
1390+
"If this error persists, open an issue on https://github.com/pytorch/rl/issues"
1391+
)
1392+
if is_tensor_collection(data[0]):
1393+
return LazyStackedTensorDict.maybe_dense_stack(data)
1394+
flat_trees = []
1395+
spec = None
1396+
for d in data:
1397+
flat_tree, spec = tree_flatten(d)
1398+
flat_trees.append(flat_tree)
1399+
1400+
leaves = []
1401+
for leaf in zip(*flat_trees):
1402+
leaf = torch.stack(leaf)
1403+
leaves.append(leaf)
1404+
1405+
return tree_unflatten(leaves, spec)
1406+
1407+
13801408
def _collate_id(x):
13811409
return x
13821410

0 commit comments

Comments
 (0)