|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +"""Tests for torch.compile compatibility of value estimation functions.""" |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | + |
| 11 | +from torchrl.objectives.value.functional import ( |
| 12 | + generalized_advantage_estimate, |
| 13 | + td_lambda_return_estimate, |
| 14 | + vec_generalized_advantage_estimate, |
| 15 | + vec_td_lambda_return_estimate, |
| 16 | +) |
| 17 | + |
| 18 | + |
| 19 | +class TestValueFunctionCompile: |
| 20 | + """Test compilation of value estimation functions.""" |
| 21 | + |
| 22 | + @pytest.fixture |
| 23 | + def value_data(self): |
| 24 | + """Create test data for value functions.""" |
| 25 | + batch_size = 32 |
| 26 | + time_steps = 15 |
| 27 | + feature_dim = 1 |
| 28 | + |
| 29 | + return { |
| 30 | + "gamma": 0.99, |
| 31 | + "lmbda": 0.95, |
| 32 | + "state_value": torch.randn(batch_size, time_steps, feature_dim), |
| 33 | + "next_state_value": torch.randn(batch_size, time_steps, feature_dim), |
| 34 | + "reward": torch.randn(batch_size, time_steps, feature_dim), |
| 35 | + "done": torch.zeros(batch_size, time_steps, feature_dim, dtype=torch.bool), |
| 36 | + "terminated": torch.zeros( |
| 37 | + batch_size, time_steps, feature_dim, dtype=torch.bool |
| 38 | + ), |
| 39 | + } |
| 40 | + |
| 41 | + def test_td_lambda_return_estimate_compiles_fullgraph(self, value_data): |
| 42 | + """Test that td_lambda_return_estimate (non-vectorized) compiles with fullgraph=True.""" |
| 43 | + result_eager = td_lambda_return_estimate( |
| 44 | + gamma=value_data["gamma"], |
| 45 | + lmbda=value_data["lmbda"], |
| 46 | + next_state_value=value_data["next_state_value"], |
| 47 | + reward=value_data["reward"], |
| 48 | + done=value_data["done"], |
| 49 | + terminated=value_data["terminated"], |
| 50 | + ) |
| 51 | + |
| 52 | + compiled_fn = torch.compile( |
| 53 | + td_lambda_return_estimate, |
| 54 | + fullgraph=True, |
| 55 | + backend="inductor", |
| 56 | + ) |
| 57 | + |
| 58 | + result_compiled = compiled_fn( |
| 59 | + gamma=value_data["gamma"], |
| 60 | + lmbda=value_data["lmbda"], |
| 61 | + next_state_value=value_data["next_state_value"], |
| 62 | + reward=value_data["reward"], |
| 63 | + done=value_data["done"], |
| 64 | + terminated=value_data["terminated"], |
| 65 | + ) |
| 66 | + |
| 67 | + torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4) |
| 68 | + |
| 69 | + def test_generalized_advantage_estimate_compiles_fullgraph(self, value_data): |
| 70 | + """Test that generalized_advantage_estimate (non-vectorized) compiles with fullgraph=True.""" |
| 71 | + advantage_eager, value_target_eager = generalized_advantage_estimate( |
| 72 | + gamma=value_data["gamma"], |
| 73 | + lmbda=value_data["lmbda"], |
| 74 | + state_value=value_data["state_value"], |
| 75 | + next_state_value=value_data["next_state_value"], |
| 76 | + reward=value_data["reward"], |
| 77 | + done=value_data["done"], |
| 78 | + terminated=value_data["terminated"], |
| 79 | + ) |
| 80 | + |
| 81 | + compiled_fn = torch.compile( |
| 82 | + generalized_advantage_estimate, |
| 83 | + fullgraph=True, |
| 84 | + backend="inductor", |
| 85 | + ) |
| 86 | + |
| 87 | + advantage_compiled, value_target_compiled = compiled_fn( |
| 88 | + gamma=value_data["gamma"], |
| 89 | + lmbda=value_data["lmbda"], |
| 90 | + state_value=value_data["state_value"], |
| 91 | + next_state_value=value_data["next_state_value"], |
| 92 | + reward=value_data["reward"], |
| 93 | + done=value_data["done"], |
| 94 | + terminated=value_data["terminated"], |
| 95 | + ) |
| 96 | + |
| 97 | + torch.testing.assert_close( |
| 98 | + advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4 |
| 99 | + ) |
| 100 | + torch.testing.assert_close( |
| 101 | + value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4 |
| 102 | + ) |
| 103 | + |
| 104 | + def test_vec_td_lambda_return_estimate_fails_fullgraph(self, value_data): |
| 105 | + """Test that vec_td_lambda_return_estimate fails with fullgraph=True due to data-dependent shapes.""" |
| 106 | + compiled_fn = torch.compile( |
| 107 | + vec_td_lambda_return_estimate, |
| 108 | + fullgraph=True, |
| 109 | + backend="inductor", |
| 110 | + ) |
| 111 | + |
| 112 | + # This should fail because of data-dependent shapes in _get_num_per_traj |
| 113 | + with pytest.raises(Exception): |
| 114 | + compiled_fn( |
| 115 | + gamma=value_data["gamma"], |
| 116 | + lmbda=value_data["lmbda"], |
| 117 | + next_state_value=value_data["next_state_value"], |
| 118 | + reward=value_data["reward"], |
| 119 | + done=value_data["done"], |
| 120 | + terminated=value_data["terminated"], |
| 121 | + ) |
| 122 | + |
| 123 | + def test_vec_generalized_advantage_estimate_fails_fullgraph(self, value_data): |
| 124 | + """Test that vec_generalized_advantage_estimate fails with fullgraph=True due to data-dependent shapes.""" |
| 125 | + compiled_fn = torch.compile( |
| 126 | + vec_generalized_advantage_estimate, |
| 127 | + fullgraph=True, |
| 128 | + backend="inductor", |
| 129 | + ) |
| 130 | + |
| 131 | + # This should fail because of data-dependent shapes in _get_num_per_traj |
| 132 | + with pytest.raises(Exception): |
| 133 | + compiled_fn( |
| 134 | + gamma=value_data["gamma"], |
| 135 | + lmbda=value_data["lmbda"], |
| 136 | + state_value=value_data["state_value"], |
| 137 | + next_state_value=value_data["next_state_value"], |
| 138 | + reward=value_data["reward"], |
| 139 | + done=value_data["done"], |
| 140 | + terminated=value_data["terminated"], |
| 141 | + ) |
| 142 | + |
| 143 | + def test_td_lambda_with_tensor_gamma_compiles_fullgraph(self, value_data): |
| 144 | + """Test that td_lambda_return_estimate compiles with 0-d tensor gamma (fullgraph=True). |
| 145 | +
|
| 146 | + This tests the fix for PendingUnbackedSymbolNotFound error that occurred when |
| 147 | + torch.full_like received a 0-d tensor and internally called .item(). |
| 148 | + """ |
| 149 | + # Use 0-d tensor gamma/lmbda - this was the problematic case |
| 150 | + gamma_tensor = torch.tensor(value_data["gamma"]) |
| 151 | + lmbda_tensor = torch.tensor(value_data["lmbda"]) |
| 152 | + |
| 153 | + result_eager = td_lambda_return_estimate( |
| 154 | + gamma=gamma_tensor, |
| 155 | + lmbda=lmbda_tensor, |
| 156 | + next_state_value=value_data["next_state_value"], |
| 157 | + reward=value_data["reward"], |
| 158 | + done=value_data["done"], |
| 159 | + terminated=value_data["terminated"], |
| 160 | + ) |
| 161 | + |
| 162 | + compiled_fn = torch.compile( |
| 163 | + td_lambda_return_estimate, |
| 164 | + fullgraph=True, |
| 165 | + backend="inductor", |
| 166 | + ) |
| 167 | + |
| 168 | + result_compiled = compiled_fn( |
| 169 | + gamma=gamma_tensor, |
| 170 | + lmbda=lmbda_tensor, |
| 171 | + next_state_value=value_data["next_state_value"], |
| 172 | + reward=value_data["reward"], |
| 173 | + done=value_data["done"], |
| 174 | + terminated=value_data["terminated"], |
| 175 | + ) |
| 176 | + |
| 177 | + torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4) |
| 178 | + |
| 179 | + def test_gae_with_tensor_gamma_compiles_fullgraph(self, value_data): |
| 180 | + """Test that generalized_advantage_estimate compiles with 0-d tensor gamma (fullgraph=True). |
| 181 | +
|
| 182 | + This tests the fix for PendingUnbackedSymbolNotFound error that occurred when |
| 183 | + torch.full_like received a 0-d tensor and internally called .item(). |
| 184 | + """ |
| 185 | + # Use 0-d tensor gamma/lmbda - this was the problematic case |
| 186 | + gamma_tensor = torch.tensor(value_data["gamma"]) |
| 187 | + lmbda_tensor = torch.tensor(value_data["lmbda"]) |
| 188 | + |
| 189 | + advantage_eager, value_target_eager = generalized_advantage_estimate( |
| 190 | + gamma=gamma_tensor, |
| 191 | + lmbda=lmbda_tensor, |
| 192 | + state_value=value_data["state_value"], |
| 193 | + next_state_value=value_data["next_state_value"], |
| 194 | + reward=value_data["reward"], |
| 195 | + done=value_data["done"], |
| 196 | + terminated=value_data["terminated"], |
| 197 | + ) |
| 198 | + |
| 199 | + compiled_fn = torch.compile( |
| 200 | + generalized_advantage_estimate, |
| 201 | + fullgraph=True, |
| 202 | + backend="inductor", |
| 203 | + ) |
| 204 | + |
| 205 | + advantage_compiled, value_target_compiled = compiled_fn( |
| 206 | + gamma=gamma_tensor, |
| 207 | + lmbda=lmbda_tensor, |
| 208 | + state_value=value_data["state_value"], |
| 209 | + next_state_value=value_data["next_state_value"], |
| 210 | + reward=value_data["reward"], |
| 211 | + done=value_data["done"], |
| 212 | + terminated=value_data["terminated"], |
| 213 | + ) |
| 214 | + |
| 215 | + torch.testing.assert_close( |
| 216 | + advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4 |
| 217 | + ) |
| 218 | + torch.testing.assert_close( |
| 219 | + value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4 |
| 220 | + ) |
| 221 | + |
| 222 | + |
| 223 | +class TestTDLambdaEstimatorCompile: |
| 224 | + """Test TDLambdaEstimator compile-friendly vectorized property.""" |
| 225 | + |
| 226 | + def test_vectorized_property_returns_true_in_eager_mode(self): |
| 227 | + """Test that TDLambdaEstimator.vectorized returns True in eager mode when set to True.""" |
| 228 | + from tensordict.nn import TensorDictModule |
| 229 | + from torch import nn |
| 230 | + |
| 231 | + from torchrl.objectives.value.advantages import TDLambdaEstimator |
| 232 | + |
| 233 | + value_net = TensorDictModule( |
| 234 | + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] |
| 235 | + ) |
| 236 | + estimator = TDLambdaEstimator( |
| 237 | + gamma=0.99, |
| 238 | + lmbda=0.95, |
| 239 | + value_network=value_net, |
| 240 | + vectorized=True, |
| 241 | + ) |
| 242 | + |
| 243 | + assert estimator.vectorized is True |
| 244 | + assert estimator._vectorized is True |
| 245 | + |
| 246 | + def test_vectorized_property_returns_false_in_eager_mode_when_set_false(self): |
| 247 | + """Test that TDLambdaEstimator.vectorized returns False in eager mode when set to False.""" |
| 248 | + from tensordict.nn import TensorDictModule |
| 249 | + from torch import nn |
| 250 | + |
| 251 | + from torchrl.objectives.value.advantages import TDLambdaEstimator |
| 252 | + |
| 253 | + value_net = TensorDictModule( |
| 254 | + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] |
| 255 | + ) |
| 256 | + estimator = TDLambdaEstimator( |
| 257 | + gamma=0.99, |
| 258 | + lmbda=0.95, |
| 259 | + value_network=value_net, |
| 260 | + vectorized=False, |
| 261 | + ) |
| 262 | + |
| 263 | + assert estimator.vectorized is False |
| 264 | + assert estimator._vectorized is False |
| 265 | + |
| 266 | + def test_vectorized_setter_works(self): |
| 267 | + """Test that TDLambdaEstimator.vectorized setter works correctly.""" |
| 268 | + from tensordict.nn import TensorDictModule |
| 269 | + from torch import nn |
| 270 | + |
| 271 | + from torchrl.objectives.value.advantages import TDLambdaEstimator |
| 272 | + |
| 273 | + value_net = TensorDictModule( |
| 274 | + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] |
| 275 | + ) |
| 276 | + estimator = TDLambdaEstimator( |
| 277 | + gamma=0.99, |
| 278 | + lmbda=0.95, |
| 279 | + value_network=value_net, |
| 280 | + vectorized=True, |
| 281 | + ) |
| 282 | + |
| 283 | + assert estimator.vectorized is True |
| 284 | + |
| 285 | + estimator.vectorized = False |
| 286 | + assert estimator.vectorized is False |
| 287 | + assert estimator._vectorized is False |
| 288 | + |
| 289 | + estimator.vectorized = True |
| 290 | + assert estimator.vectorized is True |
| 291 | + assert estimator._vectorized is True |
| 292 | + |
| 293 | + |
| 294 | +if __name__ == "__main__": |
| 295 | + pytest.main([__file__, "-v"]) |
0 commit comments