diff --git a/test/test_exploration.py b/test/test_exploration.py index 86e883cef88..7330957f7f6 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import math import os import pytest @@ -891,3 +892,360 @@ def test_consistent_dropout_primer(self): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + + +@pytest.mark.parametrize("device", get_default_devices()) +class TestNoisyLinear: + """Tests for NoisyLinear layer based on NoisyNet paper specifications.""" + + def test_noisy_linear_initialization(self, device): + """Test that NoisyLinear initializes with correct parameters.""" + from torchrl.modules.models.exploration import NoisyLinear + + in_features, out_features = 10, 5 + layer = NoisyLinear(in_features, out_features, device=device) + + # Check that mu and sigma parameters exist + assert hasattr(layer, "weight_mu") + assert hasattr(layer, "weight_sigma") + assert hasattr(layer, "bias_mu") + assert hasattr(layer, "bias_sigma") + + # Check parameter shapes + assert layer.weight_mu.shape == (out_features, in_features) + assert layer.weight_sigma.shape == (out_features, in_features) + assert layer.bias_mu.shape == (out_features,) + assert layer.bias_sigma.shape == (out_features,) + + # Check that sigma values are positive + assert (layer.weight_sigma > 0).all() + assert (layer.bias_sigma > 0).all() + + # Check initialization ranges (from paper) + mu_range = 1 / math.sqrt(in_features) + assert (layer.weight_mu >= -mu_range).all() + assert (layer.weight_mu <= mu_range).all() + assert (layer.bias_mu >= -mu_range).all() + assert (layer.bias_mu <= mu_range).all() + + def test_noisy_linear_training_vs_eval(self, device): + """Test that NoisyLinear behaves differently in training vs eval mode.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + x = torch.randn(3, 10, device=device) + + # Get outputs in training mode + layer.train() + y_train_1 = layer(x) + layer.reset_noise() # Reset noise + y_train_2 = layer(x) + + # Get outputs in eval mode + layer.eval() + y_eval_1 = layer(x) + layer.reset_noise() # Reset noise + y_eval_2 = layer(x) + + # Training outputs should be different due to noise + assert not torch.allclose(y_train_1, y_train_2, atol=1e-6) + + # Eval outputs should be identical (no noise) + torch.testing.assert_close(y_eval_1, y_eval_2) + + # Training and eval outputs should be different + assert not torch.allclose(y_train_1, y_eval_1, atol=1e-6) + + def test_noise_consistency_within_episode(self, device): + """Test that noise remains consistent within an episode (no reset).""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + layer.train() + x = torch.randn(3, 10, device=device) + + # First forward pass + y1 = layer(x) + + # Multiple forward passes without resetting noise + y2 = layer(x) + y3 = layer(x) + y4 = layer(x) + + # All outputs should be identical (same noise) + assert torch.allclose(y1, y2, atol=1e-6) + assert torch.allclose(y1, y3, atol=1e-6) + assert torch.allclose(y1, y4, atol=1e-6) + + def test_noise_change_after_reset(self, device): + """Test that noise changes after reset_noise() is called.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + layer.train() + x = torch.randn(3, 10, device=device) + + # First episode + y1 = layer(x) + + # Reset noise (simulating new episode) + layer.reset_noise() + y2 = layer(x) + + # Reset noise again + layer.reset_noise() + y3 = layer(x) + + # Outputs should be different after each reset + assert not torch.allclose(y1, y2, atol=1e-6) + assert not torch.allclose(y1, y3, atol=1e-6) + assert not torch.allclose(y2, y3, atol=1e-6) + + def test_factorized_gaussian_noise(self, device): + """Test that the noise follows factorized Gaussian distribution.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + layer.train() + + # Get noise samples + noise_samples = [] + for _ in range(1000): + layer.reset_noise() + # Extract the actual noise used + weight_noise = layer.weight - layer.weight_mu + noise_samples.append(weight_noise.flatten()) + + noise_samples = torch.stack(noise_samples) + + # Check that noise has approximately zero mean + assert abs(noise_samples.mean()) < 0.1 + + # Check that noise has reasonable variance + noise_std = noise_samples.std() + expected_std = layer.std_init / math.sqrt(10) # Based on initialization + assert 0.5 * expected_std < noise_std < 2.0 * expected_std + + def test_weight_property_behavior(self, device): + """Test that weight property returns correct values in train/eval modes.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + + # Training mode + layer.train() + layer.reset_noise() + weight_train = layer.weight + bias_train = layer.bias + + # Should include noise + assert not torch.allclose(weight_train, layer.weight_mu, atol=1e-6) + assert not torch.allclose(bias_train, layer.bias_mu, atol=1e-6) + + # Eval mode + layer.eval() + weight_eval = layer.weight + bias_eval = layer.bias + + # Should be exactly the mean weights + assert torch.allclose(weight_eval, layer.weight_mu, atol=1e-6) + assert torch.allclose(bias_eval, layer.bias_mu, atol=1e-6) + + def test_noisy_linear_in_network(self, device): + """Test NoisyLinear in a complete network setup.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + + # Create a simple network with NoisyLinear + network = nn.Sequential( + nn.Linear(10, 20), nn.ReLU(), NoisyLinear(20, 5, device=device) + ).to(device) + + x = torch.randn(3, 10, device=device) + + # Training mode + network.train() + y_train_1 = network(x) + network[-1].reset_noise() # Reset noise in NoisyLinear layer + y_train_2 = network(x) + + # Eval mode + network.eval() + y_eval_1 = network(x) + y_eval_2 = network(x) + + # Training outputs should be different + assert not torch.allclose(y_train_1, y_train_2, atol=1e-6) + + # Eval outputs should be identical + assert torch.allclose(y_eval_1, y_eval_2, atol=1e-6) + + def test_noise_reset_function(self, device): + """Test the reset_noise utility function.""" + from torchrl.modules.models.exploration import NoisyLinear, reset_noise + + torch.manual_seed(0) + + # Create network with multiple NoisyLinear layers + network = nn.Sequential( + NoisyLinear(10, 20, device=device), + nn.ReLU(), + NoisyLinear(20, 5, device=device), + ).to(device) + + network.train() + x = torch.randn(3, 10, device=device) + + # First forward pass + network(x) + + # Reset noise using utility function + reset_noise(network) + network(x) + + # Outputs should be different (but might be the same if noise is very small) + # Let's check that at least one of the layers changed + changed = False + for module in network.modules(): + if hasattr(module, "weight_mu"): + # Check if the actual weights changed + if not torch.allclose(module.weight, module.weight_mu, atol=1e-6): + changed = True + break + + # If no noise is present, the test should still pass + if not changed: + # Check that we're in eval mode or noise is very small + assert network.training == False or all( + hasattr(m, "weight_sigma") and m.weight_sigma.max() < 1e-3 + for m in network.modules() + if hasattr(m, "weight_sigma") + ) + + def test_noisy_linear_gradients(self, device): + """Test that gradients flow through NoisyLinear parameters.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + layer.train() + + x = torch.randn(3, 10, device=device, requires_grad=True) + y = layer(x) + loss = y.sum() + + # Backward pass + loss.backward() + + # Check that gradients exist for all parameters + assert layer.weight_mu.grad is not None + assert layer.weight_sigma.grad is not None + assert layer.bias_mu.grad is not None + assert layer.bias_sigma.grad is not None + + # Check that gradients are not zero + assert not torch.allclose( + layer.weight_mu.grad, torch.zeros_like(layer.weight_mu.grad) + ) + assert not torch.allclose( + layer.weight_sigma.grad, torch.zeros_like(layer.weight_sigma.grad) + ) + + def test_noisy_linear_parameter_learning(self, device): + """Test that sigma parameters actually learn during training.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + layer.train() + + # Store initial sigma values + initial_weight_sigma = layer.weight_sigma.clone() + initial_bias_sigma = layer.bias_sigma.clone() + + # Simple training loop + optimizer = torch.optim.Adam(layer.parameters(), lr=0.01) + x = torch.randn(100, 10, device=device) + target = torch.randn(100, 5, device=device) + + for _ in range(10): + optimizer.zero_grad() + layer.reset_noise() # Reset noise each iteration + y = layer(x) + loss = torch.nn.functional.mse_loss(y, target) + loss.backward() + optimizer.step() + + # Check that sigma values have changed + assert not torch.allclose(layer.weight_sigma, initial_weight_sigma, atol=1e-6) + assert not torch.allclose(layer.bias_sigma, initial_bias_sigma, atol=1e-6) + + def test_noisy_linear_std_init_effect(self, device): + """Test that different std_init values affect noise magnitude.""" + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + + # Create layers with different std_init values + layer_small = NoisyLinear(10, 5, std_init=0.01, device=device) + layer_large = NoisyLinear(10, 5, std_init=1.0, device=device) + + layer_small.train() + layer_large.train() + + x = torch.randn(3, 10, device=device) + + # Get outputs with different noise levels + layer_small.reset_noise() + layer_large.reset_noise() + + # Get multiple samples to measure noise variance + noise_samples_small = [] + noise_samples_large = [] + + for _ in range(10): + layer_small.reset_noise() + layer_large.reset_noise() + y_small = layer_small(x) + y_large = layer_large(x) + noise_samples_small.append(y_small) + noise_samples_large.append(y_large) + + noise_samples_small = torch.stack(noise_samples_small) + noise_samples_large = torch.stack(noise_samples_large) + + # Calculate noise variance + noise_var_small = noise_samples_small.var(dim=0).mean() + noise_var_large = noise_samples_large.var(dim=0).mean() + + # Large std_init should produce larger noise variance + assert noise_var_large > noise_var_small + + def test_noisy_linear_serialization(self, device): + """Test that NoisyLinear can be saved and loaded correctly.""" + import os + import tempfile + + from torchrl.modules.models.exploration import NoisyLinear + + torch.manual_seed(0) + layer = NoisyLinear(10, 5, device=device) + + # Save and load + with tempfile.NamedTemporaryFile(delete=False) as f: + torch.save(layer.state_dict(), f.name) + layer_loaded = NoisyLinear(10, 5, device=device) + layer_loaded.load_state_dict(torch.load(f.name)) + os.unlink(f.name) + + # Check that parameters are the same + assert torch.allclose(layer.weight_mu, layer_loaded.weight_mu, atol=1e-6) + assert torch.allclose(layer.weight_sigma, layer_loaded.weight_sigma, atol=1e-6) + assert torch.allclose(layer.bias_mu, layer_loaded.bias_mu, atol=1e-6) + assert torch.allclose(layer.bias_sigma, layer_loaded.bias_sigma, atol=1e-6) diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 1ae6234a844..51fdc60d57c 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -160,3 +160,20 @@ def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool: "_reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset" ) return any_reset + +def primers_from_module(module: nn.Module, target_cls: T) -> list[TensorDictPrimer]: + """Get primers from a module. + + Iterates over the module's children and returns the primers of the children that are instances of the target class. + These primers will write some data to be used by the models at reset time. + The tensors are set within the model during the policy call by + + + Args: + module (nn.Module): the module to get primers from. + + Returns: + list[TensorDictPrimer]: the primers from the module. + """ + # + ... \ No newline at end of file