Commit 253b8dd

[Refactor,Test] Move compile test to dedicated folder
ghstack-source-id: a4d56bd
Pull-Request: #3314

1 parent eaaf97a · commit 253b8dd

File tree

6 files changed (+284, -204 lines)


test/compile/test_collectors.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of collectors."""
from __future__ import annotations

import functools
import sys

import pytest
import torch
from packaging import version
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.collectors import Collector, MultiAsyncCollector, MultiSyncCollector
from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
)
class TestCompile:
    @pytest.mark.parametrize(
        "collector_cls",
        # Clearing compiled policies causes segfault on machines with cuda
        [Collector, MultiAsyncCollector, MultiSyncCollector]
        if not torch.cuda.is_available()
        else [Collector],
    )
    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
    @pytest.mark.parametrize(
        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
    )
    def test_compiled_policy(self, collector_cls, compile_policy, device):
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            torch._dynamo.reset_code_caches()
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
            del collector

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
    @pytest.mark.parametrize(
        "collector_cls",
        [Collector],
    )
    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
        device = torch.device("cuda:0")
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
            del collector


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
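For orientation, the API surface these tests exercise is small: per the parametrization above, compile_policy accepts True or a kwargs dict (apparently forwarded to torch.compile), and the collector then exposes a compiled_policy flag. A minimal sketch, assuming the same torchrl entry points used in the diff; this is illustrative only and not part of the commit:

# Minimal sketch of the compile_policy pattern exercised above; assumes the
# torchrl entry points shown in the diff (Collector, the mock env) exist as used.
import torch
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.collectors import Collector
from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv

policy = TensorDictModule(
    nn.Linear(7, 7), in_keys=["observation"], out_keys=["action"]
)
collector = Collector(
    ContinuousActionVecMockEnv(),
    policy,
    frames_per_batch=10,
    total_frames=30,
    compile_policy={"mode": "default"},  # True or a torch.compile kwargs dict
)
assert collector.compiled_policy
try:
    for batch in collector:  # three batches of 10 frames each
        assert batch is not None
finally:
    collector.shutdown()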

test/compile/test_objectives.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of objectives-related modules."""
from __future__ import annotations

import sys

import pytest
import torch

from packaging import version
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule, set_composite_lp_aggregate

from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
)
@set_composite_lp_aggregate(False)
def test_exploration_compile():
    try:
        torch._dynamo.reset_code_caches()
    except Exception:
        # older versions of PT don't have that function
        pass
    m = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=torch.distributions.Normal,
    )

    # class set_exploration_type_random(set_exploration_type):
    #     __init__ = object.__init__
    #     type = ExplorationType.RANDOM
    it = exploration_type()

    @torch.compile(fullgraph=True)
    def func(t):
        with set_exploration_type(ExplorationType.RANDOM):
            t0 = m(t.clone())
            t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] != t1["sample"]).any()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    def func(t):
        with set_exploration_type(ExplorationType.MEAN):
            t0 = m(t.clone())
            t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] == t1["sample"]).all()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    @set_exploration_type(ExplorationType.RANDOM)
    def func(t):
        t0 = m(t.clone())
        t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] != t1["sample"]).any()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    @set_exploration_type(ExplorationType.MEAN)
    def func(t):
        t0 = m(t.clone())
        t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] == t1["sample"]).all()
    assert it == exploration_type()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
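The four blocks above check one invariant from two directions: under ExplorationType.RANDOM two passes through the module must sample differently, under ExplorationType.MEAN they must agree, and in every case the ambient exploration type recorded in it must be restored on exit. A condensed sketch of the decorator form, assuming the same torchrl utilities as in the diff:

# Condensed, illustrative sketch of the decorator form tested above; assumes
# the torchrl utilities shown in the diff and is not part of this commit.
import torch
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule
from torchrl.envs.utils import ExplorationType, set_exploration_type

m = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["sample"],
    distribution_class=torch.distributions.Normal,
)

@torch.compile(fullgraph=True)
@set_exploration_type(ExplorationType.MEAN)
def mean_samples(t):
    # Under MEAN, sampling is deterministic, so both calls should agree.
    return m(t.clone()), m(t.clone())

t0, t1 = mean_samples(TensorDict(loc=torch.randn(3), scale=torch.rand(3)))
assert (t0["sample"] == t1["sample"]).all()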

test/compile/test_utils.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of utility functions."""
from __future__ import annotations

import sys

import pytest
import torch
from packaging import version

from torchrl.testing import capture_log_records

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(
    sys.version_info >= (3, 14),
    reason="torch.compile is not supported on Python 3.14+",
)
def test_capture_log_records_recompile():
    torch.compiler.reset()

    # This function recompiles each time it is called with a different string
    # input.
    @torch.compile
    def str_to_tensor(s):
        return bytes(s, "utf8")

    str_to_tensor("a")

    try:
        torch._logging.set_logs(recompiles=True)
        records = []
        capture_log_records(records, "torch._dynamo", "recompiles")
        str_to_tensor("b")

    finally:
        torch._logging.set_logs()

    assert len(records) == 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
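The single expected record follows from torch.compile specializing on the string argument: "a" produces the initial compile, and "b" triggers exactly one recompile. As a sketch of what capture_log_records is doing under the hood, one can attach a plain logging handler to the same logger; this assumes the "recompiles" artifact propagates up to the "torch._dynamo" logger, as the helper's arguments suggest:

# Hedged sketch: capture dynamo recompile logs with a standard logging handler.
# Assumption: the "recompiles" artifact records propagate to "torch._dynamo".
import logging

import torch

class ListHandler(logging.Handler):
    """Appends every record it receives to a user-supplied list."""

    def __init__(self, store):
        super().__init__()
        self.store = store

    def emit(self, record):
        self.store.append(record)

records = []
handler = ListHandler(records)
logging.getLogger("torch._dynamo").addHandler(handler)

@torch.compile
def str_to_bytes(s):
    return bytes(s, "utf8")

torch._logging.set_logs(recompiles=True)
try:
    str_to_bytes("a")  # first compile, no recompile record
    str_to_bytes("b")  # new string constant -> recompile logged
finally:
    torch._logging.set_logs()  # restore default log settings
    logging.getLogger("torch._dynamo").removeHandler(handler)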

test/test_collectors.py

Lines changed: 0 additions & 90 deletions
@@ -3693,96 +3693,6 @@ def test_dynamic_multiasync_collector(self):
         assert data.names[-1] == "time"
 
 
-@pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
-)
-@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
-@pytest.mark.skipif(
-    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
-)
-class TestCompile:
-    @pytest.mark.parametrize(
-        "collector_cls",
-        # Clearing compiled policies causes segfault on machines with cuda
-        [Collector, MultiAsyncCollector, MultiSyncCollector]
-        if not torch.cuda.is_available()
-        else [Collector],
-    )
-    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
-    @pytest.mark.parametrize(
-        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
-    )
-    def test_compiled_policy(self, collector_cls, compile_policy, device):
-        policy = TensorDictModule(
-            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
-        )
-        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
-        if collector_cls is Collector:
-            torch._dynamo.reset_code_caches()
-            collector = Collector(
-                make_env(),
-                policy,
-                frames_per_batch=10,
-                total_frames=30,
-                compile_policy=compile_policy,
-            )
-            assert collector.compiled_policy
-        else:
-            collector = collector_cls(
-                [make_env] * 2,
-                policy,
-                frames_per_batch=10,
-                total_frames=30,
-                compile_policy=compile_policy,
-            )
-            assert collector.compiled_policy
-        try:
-            for data in collector:
-                assert data is not None
-        finally:
-            collector.shutdown()
-            del collector
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
-    @pytest.mark.parametrize(
-        "collector_cls",
-        [Collector],
-    )
-    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
-    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
-        device = torch.device("cuda:0")
-        policy = TensorDictModule(
-            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
-        )
-        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
-        if collector_cls is Collector:
-            collector = Collector(
-                make_env(),
-                policy,
-                frames_per_batch=30,
-                total_frames=120,
-                cudagraph_policy=cudagraph_policy,
-                device=device,
-            )
-            assert collector.cudagraphed_policy
-        else:
-            collector = collector_cls(
-                [make_env] * 2,
-                policy,
-                frames_per_batch=30,
-                total_frames=120,
-                cudagraph_policy=cudagraph_policy,
-                device=device,
-            )
-            assert collector.cudagraphed_policy
-        try:
-            for data in collector:
-                assert data is not None
-        finally:
-            collector.shutdown()
-            del collector
-
-
 @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
 class TestCollectorsNonTensor:
     class AddNontTensorData(Transform):
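With the move complete, the compile-specific suite can be targeted on its own; the per-file __main__ blocks shown in the diffs above serve the same purpose for a single file. A usage sketch, with the path taken from this commit's file tree:

# Run the relocated compile tests programmatically; equivalent to invoking
# pytest on the new test/compile folder from the repository root.
import pytest

raise SystemExit(pytest.main(["test/compile", "-v"]))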
