Skip to content

Commit cf72cc0

Browse files
authored
fix(pu): fix noise layer's usage based on the original paper (#866)
* fix(pu): fix noise layer's usage * polish(pu): polish comments * polish(pu): polish noisy_net config * fix(pu): fix reset_noise bug in noisy_net option * fix(pu): fix enable_noise bug in rainbow * style(pu): yapf format * style(pu): yapf format * style(pu): flake8 format * style(pu): yapf format * polish(pu): polish set_noise_mode when self._cfg.noisy_net is False * feature(pu): add unittest for noise_linear_layer --------- Co-authored-by: puyuan <[email protected]>
1 parent c290a67 commit cf72cc0

File tree

7 files changed

+174
-9
lines changed

7 files changed

+174
-9
lines changed

ding/model/template/q_learning.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
norm_type: Optional[str] = None,
3838
dropout: Optional[float] = None,
3939
init_bias: Optional[float] = None,
40+
noise: bool = False,
4041
) -> None:
4142
"""
4243
Overview:
@@ -57,6 +58,8 @@ def __init__(
5758
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
5859
if ``None`` then default disable dropout layer.
5960
- init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network. \
61+
- noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` to boost exploration in \
62+
Q networks' MLP. Default to ``False``.
6063
"""
6164
super(DQN, self).__init__()
6265
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -90,7 +93,8 @@ def __init__(
9093
layer_num=head_layer_num,
9194
activation=activation,
9295
norm_type=norm_type,
93-
dropout=dropout
96+
dropout=dropout,
97+
noise=noise,
9498
)
9599
else:
96100
self.head = head_cls(
@@ -99,7 +103,8 @@ def __init__(
99103
head_layer_num,
100104
activation=activation,
101105
norm_type=norm_type,
102-
dropout=dropout
106+
dropout=dropout,
107+
noise=noise,
103108
)
104109
if init_bias is not None and head_cls == DuelingHead:
105110
# Zero the last layer bias of advantage head

ding/policy/common_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
from typing import List, Any, Dict, Callable
22
import torch
3+
import torch.nn as nn
34
import numpy as np
45
import treetensor.torch as ttorch
56
from ding.utils.data import default_collate
67
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
8+
from ding.torch_utils import NoiseLinearLayer
9+
10+
11+
def set_noise_mode(module: nn.Module, noise_enabled: bool):
    """
    Overview:
        Write ``noise_enabled`` into the ``enable_noise`` attribute of every ``NoiseLinearLayer`` yielded by \
        ``module.modules()``. Modules of any other type are left untouched.
        This is typically used by algorithms such as NoisyNet and Rainbow: the flag is switched on for \
        training (noise-driven exploration) and switched off for inference or evaluation \
        (deterministic behavior).

    Arguments:
        - module (:obj:`nn.Module`): The root module to search for ``NoiseLinearLayer`` instances.
        - noise_enabled (:obj:`bool`): The value assigned to ``enable_noise`` on each matching layer.
    """
    # Lazily pick out the noisy layers, then flip the flag on each of them.
    noisy_layers = (sub for sub in module.modules() if isinstance(sub, NoiseLinearLayer))
    for layer in noisy_layers:
        layer.enable_noise = noise_enabled
726

827

928
def default_preprocess_learn(

ding/policy/dqn.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ding.utils.data import default_collate, default_decollate
1111

1212
from .base_policy import Policy
13-
from .common_utils import default_preprocess_learn
13+
from .common_utils import default_preprocess_learn, set_noise_mode
1414

1515

1616
@POLICY_REGISTRY.register('dqn')
@@ -97,6 +97,8 @@ class DQNPolicy(Policy):
9797
discount_factor=0.97,
9898
# (int) The number of steps for calculating target q_value.
9999
nstep=1,
100+
# (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is False.
101+
noisy_net=False,
100102
model=dict(
101103
# (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer.
102104
encoder_hidden_size_list=[128, 128, 64],
@@ -248,6 +250,21 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
248250
.. note::
249251
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
250252
"""
253+
# Set noise mode for NoisyNet for exploration in learning if enabled in config
254+
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
255+
# phases (learn/collect/eval).
256+
if self._cfg.noisy_net:
257+
set_noise_mode(self._learn_model, True)
258+
set_noise_mode(self._target_model, True)
259+
260+
# A noisy network agent samples a new set of parameters after every step of optimisation.
261+
# Between optimisation steps, the agent acts according to a fixed set of parameters (weights and biases).
262+
# This ensures that the agent always acts according to parameters that are drawn from
263+
# the current noise distribution.
264+
if self._cfg.noisy_net:
265+
self._reset_noise(self._learn_model)
266+
self._reset_noise(self._target_model)
267+
251268
# Data preprocessing operations, such as stack data, cpu to cuda device
252269
data = default_preprocess_learn(
253270
data,
@@ -380,10 +397,17 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
380397
.. note::
381398
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
382399
"""
400+
# Set noise mode for NoisyNet for exploration in collecting if enabled in config.
401+
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
402+
# phases (learn/collect/eval).
403+
if self._cfg.noisy_net:
404+
set_noise_mode(self._collect_model, True)
405+
383406
data_id = list(data.keys())
384407
data = default_collate(list(data.values()))
385408
if self._cuda:
386409
data = to_device(data, self._device)
410+
387411
self._collect_model.eval()
388412
with torch.no_grad():
389413
output = self._collect_model.forward(data, eps=eps)
@@ -472,10 +496,16 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
472496
.. note::
473497
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
474498
"""
499+
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
500+
# phases (learn/collect/eval).
501+
# Ensure that in evaluation mode noise is disabled.
502+
set_noise_mode(self._eval_model, False)
503+
475504
data_id = list(data.keys())
476505
data = default_collate(list(data.values()))
477506
if self._cuda:
478507
data = to_device(data, self._device)
508+
479509
self._eval_model.eval()
480510
with torch.no_grad():
481511
output = self._eval_model.forward(data)
@@ -533,6 +563,18 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F
533563
)
534564
return {'priority': td_error_per_sample.abs().tolist()}
535565

566+
def _reset_noise(self, model: torch.nn.Module):
567+
r"""
568+
Overview:
569+
Reset the noise of model.
570+
571+
Arguments:
572+
- model (:obj:`torch.nn.Module`): the model to reset, must contain reset_noise method
573+
"""
574+
for m in model.modules():
575+
if hasattr(m, 'reset_noise'):
576+
m.reset_noise()
577+
536578

537579
@POLICY_REGISTRY.register('dqn_stdim')
538580
class DQNSTDIMPolicy(DQNPolicy):

ding/policy/rainbow.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ding.utils import POLICY_REGISTRY
99
from ding.utils.data import default_collate, default_decollate
1010
from .dqn import DQNPolicy
11-
from .common_utils import default_preprocess_learn
11+
from .common_utils import default_preprocess_learn, set_noise_mode
1212

1313

1414
@POLICY_REGISTRY.register('rainbow')
@@ -86,8 +86,9 @@ class RainbowDQNPolicy(DQNPolicy):
8686
discount_factor=0.99,
8787
# (int) N-step reward for target q_value estimation
8888
nstep=3,
89+
# (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is True.
90+
noisy_net=True,
8991
learn=dict(
90-
9192
# How many updates(iterations) to train after collector's one collection.
9293
# Bigger "update_per_collect" means bigger off-policy.
9394
# collect data -> update policy-> collect data -> ...
@@ -201,6 +202,11 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
201202
# ====================
202203
self._learn_model.train()
203204
self._target_model.train()
205+
206+
# Enable NoisyNet noise on both the learn and target models for learning (applied unconditionally here; the ``noisy_net`` config flag is not checked).
207+
set_noise_mode(self._learn_model, True)
208+
set_noise_mode(self._target_model, True)
209+
204210
# reset noise of noisenet for both main model and target model
205211
self._reset_noise(self._learn_model)
206212
self._reset_noise(self._target_model)
@@ -262,12 +268,16 @@ def _forward_collect(self, data: dict, eps: float) -> dict:
262268
ReturnsKeys
263269
- necessary: ``action``
264270
"""
271+
# Enable NoisyNet noise for exploration while collecting (applied unconditionally here; the ``noisy_net`` config flag is not checked).
272+
# We need to reset set_noise_mode every _forward_xxx because the model is reused across
273+
# different phases (learn/collect/eval).
274+
set_noise_mode(self._collect_model, True)
275+
265276
data_id = list(data.keys())
266277
data = default_collate(list(data.values()))
267278
if self._cuda:
268279
data = to_device(data, self._device)
269280
self._collect_model.eval()
270-
self._reset_noise(self._collect_model)
271281
with torch.no_grad():
272282
output = self._collect_model.forward(data, eps=eps)
273283
if self._cuda:

ding/torch_utils/network/nn_module.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,10 @@ class NoiseLinearLayer(nn.Module):
637637
def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> None:
638638
"""
639639
Overview:
640-
Initialize the NoiseLinearLayer class.
640+
Initialize the NoiseLinearLayer class. The 'enable_noise' attribute enables external control over whether \
641+
noise is applied.
642+
- If enable_noise is True, the layer adds noise even if the module is in evaluation mode.
643+
- If enable_noise is False, no noise is added regardless of self.training.
641644
Arguments:
642645
- in_channels (:obj:`int`): Number of channels in the input tensor.
643646
- out_channels (:obj:`int`): Number of channels in the output tensor.
@@ -654,6 +657,7 @@ def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> No
654657
self.register_buffer("weight_eps", torch.empty(out_channels, in_channels))
655658
self.register_buffer("bias_eps", torch.empty(out_channels))
656659
self.sigma0 = sigma0
660+
self.enable_noise = False
657661
self.reset_parameters()
658662
self.reset_noise()
659663

@@ -703,7 +707,8 @@ def forward(self, x: torch.Tensor):
703707
Returns:
704708
- output (:obj:`torch.Tensor`): The output tensor with noise.
705709
"""
706-
if self.training:
710+
# Determine whether to add noise:
711+
if self.enable_noise:
707712
return F.linear(
708713
x,
709714
self.weight_mu + self.weight_sigma * self.weight_eps,

ding/torch_utils/network/tests/test_nn_module.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ding.torch_utils import build_activation
66
from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
77
ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
8-
normed_linear, normed_conv2d
8+
normed_linear, normed_conv2d, NoiseLinearLayer
99

1010
batch_size = 2
1111
in_channels = 2
@@ -238,3 +238,27 @@ def test_flatten(self):
238238
model3 = NaiveFlatten(1, 3)
239239
output3 = model2(inputs)
240240
assert output1.shape == (4, 3 * 8 * 8)
241+
242+
def test_noise_linear_layer(self):
    """
    Overview:
        Unit test for ``NoiseLinearLayer``. It checks that noise is disabled by default, that the layer \
        still produces correctly shaped outputs once ``enable_noise`` is switched on, that resetting the \
        noise changes the noisy output, and that ``reset_parameters`` keeps the parameter shapes.
    """
    # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
    inputs = torch.rand(batch_size, in_channels).requires_grad_(True)
    layer = NoiseLinearLayer(in_channels, out_channels, sigma0=0.5)

    # No noise by default
    assert not layer.enable_noise
    output = self.run_model(inputs, layer)
    assert output.shape == (batch_size, out_channels)
    # With noise disabled, resetting the noise must not change the output
    with torch.no_grad():
        layer.reset_noise()
        plain1 = layer(inputs)
        layer.reset_noise()
        plain2 = layer(inputs)
    assert torch.allclose(plain1, plain2)

    # Enable noise
    layer.enable_noise = True
    layer.reset_noise()
    output_noise = self.run_model(inputs, layer)
    assert output_noise.shape == (batch_size, out_channels)

    # Check that outputs are different after resetting noise
    with torch.no_grad():
        layer.reset_noise()
        out1 = layer(inputs)
        layer.reset_noise()
        out2 = layer(inputs)
    # The outputs should be different (very likely)
    assert not torch.allclose(out1, out2)

    # Check reset_parameters
    layer.reset_parameters()
    assert layer.weight_mu.shape == (out_channels, in_channels)
    assert layer.bias_mu.shape == (out_channels, )
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from easydict import EasyDict
2+
3+
demon_attack_dqn_config = dict(
4+
exp_name='DemonAttack_dqn_seed0',
5+
env=dict(
6+
collector_env_num=8,
7+
evaluator_env_num=8,
8+
n_evaluator_episode=8,
9+
stop_value=1e6,
10+
env_id='DemonAttackNoFrameskip-v4',
11+
frame_stack=4,
12+
),
13+
policy=dict(
14+
cuda=True,
15+
priority=False,
16+
model=dict(
17+
obs_shape=[4, 84, 84],
18+
action_shape=6,
19+
encoder_hidden_size_list=[128, 128, 512],
20+
noise=True,
21+
),
22+
nstep=3,
23+
discount_factor=0.99,
24+
learn=dict(
25+
update_per_collect=10,
26+
batch_size=32,
27+
learning_rate=0.0001,
28+
target_update_freq=500,
29+
),
30+
noisy_net=True,
31+
collect=dict(n_sample=96),
32+
eval=dict(evaluator=dict(eval_freq=4000, )),
33+
other=dict(
34+
eps=dict(
35+
type='exp',
36+
start=1.,
37+
end=0.05,
38+
decay=250000,
39+
),
40+
replay_buffer=dict(replay_buffer_size=100000, ),
41+
),
42+
),
43+
)
44+
demon_attack_dqn_config = EasyDict(demon_attack_dqn_config)
45+
main_config = demon_attack_dqn_config
46+
demon_attack_dqn_create_config = dict(
47+
env=dict(
48+
type='atari',
49+
import_names=['dizoo.atari.envs.atari_env'],
50+
),
51+
env_manager=dict(type='subprocess'),
52+
policy=dict(type='dqn'),
53+
)
54+
demon_attack_dqn_create_config = EasyDict(demon_attack_dqn_create_config)
55+
create_config = demon_attack_dqn_create_config
56+
57+
if __name__ == '__main__':
58+
# or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0`
59+
from ding.entry import serial_pipeline
60+
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6))

0 commit comments

Comments
 (0)