Skip to content

Commit 454334c

Browse files
committed
polish(pu): polish comments
1 parent 5a01fde commit 454334c

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

ding/model/template/q_learning.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __init__(
5858
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
5959
If ``None``, the dropout layer is disabled.
6060
- init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network. \
61+
- noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in the Q network's MLP. \
62+
Default ``False``.
6163
"""
6264
super(DQN, self).__init__()
6365
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4

ding/policy/common_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
def set_noise_mode(module: nn.Module, noise_enabled: bool):
1111
"""
12-
Recursively set the 'force_noise' flag on all NoiseLinearLayer modules within the given module.
12+
Overview:
13+
Recursively set the 'force_noise' flag on all NoiseLinearLayer modules within the given module.
1314
"""
1415
for m in module.modules():
1516
if isinstance(m, NoiseLinearLayer):

ding/policy/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
386386
data = default_collate(list(data.values()))
387387
if self._cuda:
388388
data = to_device(data, self._device)
389-
# Use the new config parameter to decide noise mode.
389+
# Use the ``collect.add_noise`` config parameter to decide the noise mode.
390390
# Default to True if the parameter is not provided.
391391
if self._cfg.collect.get("add_noise", True):
392392
set_noise_mode(self._collect_model, True)

dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from easydict import EasyDict
22

3-
pong_dqn_config = dict(
3+
demon_attack_dqn_config = dict(
44
exp_name='DemonAttack_dqn_collect-not-noise_seed0',
55
env=dict(
66
collector_env_num=8,
@@ -41,20 +41,20 @@
4141
),
4242
),
4343
)
44-
pong_dqn_config = EasyDict(pong_dqn_config)
45-
main_config = pong_dqn_config
46-
pong_dqn_create_config = dict(
44+
demon_attack_dqn_config = EasyDict(demon_attack_dqn_config)
45+
main_config = demon_attack_dqn_config
46+
demon_attack_dqn_create_config = dict(
4747
env=dict(
4848
type='atari',
4949
import_names=['dizoo.atari.envs.atari_env'],
5050
),
5151
env_manager=dict(type='subprocess'),
5252
policy=dict(type='dqn'),
5353
)
54-
pong_dqn_create_config = EasyDict(pong_dqn_create_config)
55-
create_config = pong_dqn_create_config
54+
demon_attack_dqn_create_config = EasyDict(demon_attack_dqn_create_config)
55+
create_config = demon_attack_dqn_create_config
5656

5757
if __name__ == '__main__':
58-
# or you can enter `ding -m serial -c pong_dqn_config.py -s 0`
58+
# or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0`
5959
from ding.entry import serial_pipeline
6060
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6))

dizoo/atari/config/serial/pong/pong_dqn_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
learning_rate=0.0001,
2828
target_update_freq=500,
2929
),
30-
collect=dict(n_sample=96,),
30+
collect=dict(n_sample=96, ),
3131
eval=dict(evaluator=dict(eval_freq=4000, )),
3232
other=dict(
3333
eps=dict(

0 commit comments

Comments
 (0)