Skip to content

Commit 1c724c3

Browse files
committed
Update
[ghstack-poisoned]
1 parent 15fdb75 commit 1c724c3

File tree

1 file changed

+295
-0
lines changed

1 file changed

+295
-0
lines changed

test/compile/test_value.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Tests for torch.compile compatibility of value estimation functions."""
6+
from __future__ import annotations
7+
8+
import pytest
9+
import torch
10+
11+
from torchrl.objectives.value.functional import (
12+
generalized_advantage_estimate,
13+
td_lambda_return_estimate,
14+
vec_generalized_advantage_estimate,
15+
vec_td_lambda_return_estimate,
16+
)
17+
18+
19+
class TestValueFunctionCompile:
20+
"""Test compilation of value estimation functions."""
21+
22+
@pytest.fixture
23+
def value_data(self):
24+
"""Create test data for value functions."""
25+
batch_size = 32
26+
time_steps = 15
27+
feature_dim = 1
28+
29+
return {
30+
"gamma": 0.99,
31+
"lmbda": 0.95,
32+
"state_value": torch.randn(batch_size, time_steps, feature_dim),
33+
"next_state_value": torch.randn(batch_size, time_steps, feature_dim),
34+
"reward": torch.randn(batch_size, time_steps, feature_dim),
35+
"done": torch.zeros(batch_size, time_steps, feature_dim, dtype=torch.bool),
36+
"terminated": torch.zeros(
37+
batch_size, time_steps, feature_dim, dtype=torch.bool
38+
),
39+
}
40+
41+
def test_td_lambda_return_estimate_compiles_fullgraph(self, value_data):
42+
"""Test that td_lambda_return_estimate (non-vectorized) compiles with fullgraph=True."""
43+
result_eager = td_lambda_return_estimate(
44+
gamma=value_data["gamma"],
45+
lmbda=value_data["lmbda"],
46+
next_state_value=value_data["next_state_value"],
47+
reward=value_data["reward"],
48+
done=value_data["done"],
49+
terminated=value_data["terminated"],
50+
)
51+
52+
compiled_fn = torch.compile(
53+
td_lambda_return_estimate,
54+
fullgraph=True,
55+
backend="inductor",
56+
)
57+
58+
result_compiled = compiled_fn(
59+
gamma=value_data["gamma"],
60+
lmbda=value_data["lmbda"],
61+
next_state_value=value_data["next_state_value"],
62+
reward=value_data["reward"],
63+
done=value_data["done"],
64+
terminated=value_data["terminated"],
65+
)
66+
67+
torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)
68+
69+
def test_generalized_advantage_estimate_compiles_fullgraph(self, value_data):
70+
"""Test that generalized_advantage_estimate (non-vectorized) compiles with fullgraph=True."""
71+
advantage_eager, value_target_eager = generalized_advantage_estimate(
72+
gamma=value_data["gamma"],
73+
lmbda=value_data["lmbda"],
74+
state_value=value_data["state_value"],
75+
next_state_value=value_data["next_state_value"],
76+
reward=value_data["reward"],
77+
done=value_data["done"],
78+
terminated=value_data["terminated"],
79+
)
80+
81+
compiled_fn = torch.compile(
82+
generalized_advantage_estimate,
83+
fullgraph=True,
84+
backend="inductor",
85+
)
86+
87+
advantage_compiled, value_target_compiled = compiled_fn(
88+
gamma=value_data["gamma"],
89+
lmbda=value_data["lmbda"],
90+
state_value=value_data["state_value"],
91+
next_state_value=value_data["next_state_value"],
92+
reward=value_data["reward"],
93+
done=value_data["done"],
94+
terminated=value_data["terminated"],
95+
)
96+
97+
torch.testing.assert_close(
98+
advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4
99+
)
100+
torch.testing.assert_close(
101+
value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4
102+
)
103+
104+
def test_vec_td_lambda_return_estimate_fails_fullgraph(self, value_data):
105+
"""Test that vec_td_lambda_return_estimate fails with fullgraph=True due to data-dependent shapes."""
106+
compiled_fn = torch.compile(
107+
vec_td_lambda_return_estimate,
108+
fullgraph=True,
109+
backend="inductor",
110+
)
111+
112+
# This should fail because of data-dependent shapes in _get_num_per_traj
113+
with pytest.raises(Exception):
114+
compiled_fn(
115+
gamma=value_data["gamma"],
116+
lmbda=value_data["lmbda"],
117+
next_state_value=value_data["next_state_value"],
118+
reward=value_data["reward"],
119+
done=value_data["done"],
120+
terminated=value_data["terminated"],
121+
)
122+
123+
def test_vec_generalized_advantage_estimate_fails_fullgraph(self, value_data):
124+
"""Test that vec_generalized_advantage_estimate fails with fullgraph=True due to data-dependent shapes."""
125+
compiled_fn = torch.compile(
126+
vec_generalized_advantage_estimate,
127+
fullgraph=True,
128+
backend="inductor",
129+
)
130+
131+
# This should fail because of data-dependent shapes in _get_num_per_traj
132+
with pytest.raises(Exception):
133+
compiled_fn(
134+
gamma=value_data["gamma"],
135+
lmbda=value_data["lmbda"],
136+
state_value=value_data["state_value"],
137+
next_state_value=value_data["next_state_value"],
138+
reward=value_data["reward"],
139+
done=value_data["done"],
140+
terminated=value_data["terminated"],
141+
)
142+
143+
def test_td_lambda_with_tensor_gamma_compiles_fullgraph(self, value_data):
144+
"""Test that td_lambda_return_estimate compiles with 0-d tensor gamma (fullgraph=True).
145+
146+
This tests the fix for PendingUnbackedSymbolNotFound error that occurred when
147+
torch.full_like received a 0-d tensor and internally called .item().
148+
"""
149+
# Use 0-d tensor gamma/lmbda - this was the problematic case
150+
gamma_tensor = torch.tensor(value_data["gamma"])
151+
lmbda_tensor = torch.tensor(value_data["lmbda"])
152+
153+
result_eager = td_lambda_return_estimate(
154+
gamma=gamma_tensor,
155+
lmbda=lmbda_tensor,
156+
next_state_value=value_data["next_state_value"],
157+
reward=value_data["reward"],
158+
done=value_data["done"],
159+
terminated=value_data["terminated"],
160+
)
161+
162+
compiled_fn = torch.compile(
163+
td_lambda_return_estimate,
164+
fullgraph=True,
165+
backend="inductor",
166+
)
167+
168+
result_compiled = compiled_fn(
169+
gamma=gamma_tensor,
170+
lmbda=lmbda_tensor,
171+
next_state_value=value_data["next_state_value"],
172+
reward=value_data["reward"],
173+
done=value_data["done"],
174+
terminated=value_data["terminated"],
175+
)
176+
177+
torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)
178+
179+
def test_gae_with_tensor_gamma_compiles_fullgraph(self, value_data):
180+
"""Test that generalized_advantage_estimate compiles with 0-d tensor gamma (fullgraph=True).
181+
182+
This tests the fix for PendingUnbackedSymbolNotFound error that occurred when
183+
torch.full_like received a 0-d tensor and internally called .item().
184+
"""
185+
# Use 0-d tensor gamma/lmbda - this was the problematic case
186+
gamma_tensor = torch.tensor(value_data["gamma"])
187+
lmbda_tensor = torch.tensor(value_data["lmbda"])
188+
189+
advantage_eager, value_target_eager = generalized_advantage_estimate(
190+
gamma=gamma_tensor,
191+
lmbda=lmbda_tensor,
192+
state_value=value_data["state_value"],
193+
next_state_value=value_data["next_state_value"],
194+
reward=value_data["reward"],
195+
done=value_data["done"],
196+
terminated=value_data["terminated"],
197+
)
198+
199+
compiled_fn = torch.compile(
200+
generalized_advantage_estimate,
201+
fullgraph=True,
202+
backend="inductor",
203+
)
204+
205+
advantage_compiled, value_target_compiled = compiled_fn(
206+
gamma=gamma_tensor,
207+
lmbda=lmbda_tensor,
208+
state_value=value_data["state_value"],
209+
next_state_value=value_data["next_state_value"],
210+
reward=value_data["reward"],
211+
done=value_data["done"],
212+
terminated=value_data["terminated"],
213+
)
214+
215+
torch.testing.assert_close(
216+
advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4
217+
)
218+
torch.testing.assert_close(
219+
value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4
220+
)
221+
222+
223+
class TestTDLambdaEstimatorCompile:
224+
"""Test TDLambdaEstimator compile-friendly vectorized property."""
225+
226+
def test_vectorized_property_returns_true_in_eager_mode(self):
227+
"""Test that TDLambdaEstimator.vectorized returns True in eager mode when set to True."""
228+
from tensordict.nn import TensorDictModule
229+
from torch import nn
230+
231+
from torchrl.objectives.value.advantages import TDLambdaEstimator
232+
233+
value_net = TensorDictModule(
234+
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
235+
)
236+
estimator = TDLambdaEstimator(
237+
gamma=0.99,
238+
lmbda=0.95,
239+
value_network=value_net,
240+
vectorized=True,
241+
)
242+
243+
assert estimator.vectorized is True
244+
assert estimator._vectorized is True
245+
246+
def test_vectorized_property_returns_false_in_eager_mode_when_set_false(self):
247+
"""Test that TDLambdaEstimator.vectorized returns False in eager mode when set to False."""
248+
from tensordict.nn import TensorDictModule
249+
from torch import nn
250+
251+
from torchrl.objectives.value.advantages import TDLambdaEstimator
252+
253+
value_net = TensorDictModule(
254+
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
255+
)
256+
estimator = TDLambdaEstimator(
257+
gamma=0.99,
258+
lmbda=0.95,
259+
value_network=value_net,
260+
vectorized=False,
261+
)
262+
263+
assert estimator.vectorized is False
264+
assert estimator._vectorized is False
265+
266+
def test_vectorized_setter_works(self):
267+
"""Test that TDLambdaEstimator.vectorized setter works correctly."""
268+
from tensordict.nn import TensorDictModule
269+
from torch import nn
270+
271+
from torchrl.objectives.value.advantages import TDLambdaEstimator
272+
273+
value_net = TensorDictModule(
274+
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
275+
)
276+
estimator = TDLambdaEstimator(
277+
gamma=0.99,
278+
lmbda=0.95,
279+
value_network=value_net,
280+
vectorized=True,
281+
)
282+
283+
assert estimator.vectorized is True
284+
285+
estimator.vectorized = False
286+
assert estimator.vectorized is False
287+
assert estimator._vectorized is False
288+
289+
estimator.vectorized = True
290+
assert estimator.vectorized is True
291+
assert estimator._vectorized is True
292+
293+
294+
if __name__ == "__main__":
295+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)