Skip to content

Commit 9c689d2

Browse files
authored
feature(cy): add BDQ algorithm (#558)
* add BDQ algrithm * after run reformat * update mujoco_env * add unittest; extend n-step TD; polished; * fix one error * fixed one error * fixed one error * add test_bdq.py * add readme * add pendulum_bdq test
1 parent 0a25e46 commit 9c689d2

File tree

17 files changed

+965
-6
lines changed

17 files changed

+965
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
240240
| 48 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
241241
| 49 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
242242
| 50 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
243+
| 51 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/bdq.py) | python3 -u hopper_bdq_config.py |
243244
</details>
244245

245246

ding/entry/tests/test_serial_entry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
5252
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
5353
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
54+
from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa
5455

5556

5657
@pytest.mark.platformtest
@@ -67,6 +68,20 @@ def test_dqn():
6768
os.popen('rm -rf cartpole_dqn_unittest')
6869

6970

71+
@pytest.mark.platformtest
@pytest.mark.unittest
def test_bdq():
    """
    Overview:
        Smoke test for the BDQ serial entry: run ``serial_pipeline`` on the pendulum BDQ config \
        for a single training iteration and fail if the pipeline raises.
    """
    # Work on private copies so the imported config objects are not modified by this test.
    main_config = deepcopy(pendulum_bdq_config)
    create_config = deepcopy(pendulum_bdq_create_config)
    main_config.policy.learn.update_per_collect = 1
    main_config.exp_name = 'pendulum_bdq_unittest'
    try:
        serial_pipeline([main_config, create_config], seed=0, max_train_iter=1)
    except Exception:
        assert False, "pipeline fail"
    finally:
        # Remove the experiment directory named by ``exp_name`` whether or not the run succeeded.
        os.popen('rm -rf pendulum_bdq_unittest')
83+
84+
7085
@pytest.mark.platformtest
7186
@pytest.mark.unittest
7287
def test_ddpg():

ding/model/common/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
2-
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map, \
2+
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
33
independent_normal_dist
44
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
55
from .utils import create_model

ding/model/common/head.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,116 @@ def forward(self, x: torch.Tensor) -> Dict:
174174
return {'logit': q, 'distribution': dist}
175175

176176

177+
class BranchingHead(nn.Module):
    """
    Overview:
        Dueling-style head with one shared state-value stream and one advantage stream per action \
        branch (action dimension), as used by BDQ.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            hidden_size: int,
            num_branches: int = 0,
            action_bins_per_branch: int = 2,
            layer_num: int = 1,
            a_layer_num: Optional[int] = None,
            v_layer_num: Optional[int] = None,
            norm_type: Optional[str] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
            noise: Optional[bool] = False,
    ) -> None:
        """
        Overview:
            Init the ``BranchingHead`` layers according to the provided arguments. \
            The number of network outputs grows linearly with the number of degrees of freedom, \
            because each individual action dimension is handled by its own, independent branch. \
            This makes the head suitable for high dimensional action spaces.
        Arguments:
            - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``BranchingHead``.
            - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension.
            - action_bins_per_branch (:obj:`int`): The number of action bins in each dimension.
            - layer_num (:obj:`int`): The number of MLP layers used for both the Advantage and the Value \
                streams when ``a_layer_num`` / ``v_layer_num`` are not given.
            - a_layer_num (:obj:`Optional[int]`): The number of MLP layers of each Advantage stream. \
                Falls back to ``layer_num`` if ``None``.
            - v_layer_num (:obj:`Optional[int]`): The number of MLP layers of the Value stream. \
                Falls back to ``layer_num`` if ``None``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use. \
                See ``ding.torch_utils.network.fc_block`` for more details. Default ``None``.
            - activation (:obj:`Optional[nn.Module]`): The activation passed to the MLPs. \
                Default ``nn.ReLU()``.
            - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` (and ``noise_block`` as \
                the output block) instead of ``nn.Linear`` / ``fc_block``. Default ``False``.
        """
        super(BranchingHead, self).__init__()
        a_layer_num = layer_num if a_layer_num is None else a_layer_num
        v_layer_num = layer_num if v_layer_num is None else v_layer_num
        self.num_branches = num_branches
        self.action_bins_per_branch = action_bins_per_branch

        if noise:
            layer, block = NoiseLinearLayer, noise_block
        else:
            layer, block = nn.Linear, fc_block

        def _build_stream(depth: int, out_dim: int) -> nn.Sequential:
            # One stream = hidden MLP of ``depth`` layers followed by a single output block.
            return nn.Sequential(
                MLP(
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    depth,
                    layer_fn=layer,
                    activation=activation,
                    norm_type=norm_type
                ), block(hidden_size, out_dim)
            )

        # value network: a single scalar output shared by all branches
        self.V = _build_stream(v_layer_num, 1)
        # action branching network: one advantage stream per branch, each with ``action_bins_per_branch`` outputs
        self.branches = nn.ModuleList(
            [_build_stream(a_layer_num, action_bins_per_branch) for _ in range(self.num_branches)]
        )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Use encoded embedding tensor to run the value stream and every advantage branch, then \
            combine them into per-branch Q-values and return the prediction dictionary.
        Arguments:
            - x (:obj:`torch.Tensor`): Tensor containing input embedding.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`).
        Shapes:
            - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
            - logit: :math:`(B, D, M)`, where ``D = num_branches`` and ``M = action_bins_per_branch``.

        Examples:
            >>> head = BranchingHead(64, 5, 2)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = head(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
        """
        # Extra axis on the value output so it broadcasts over the branch axis of the stacked advantages.
        state_value = self.V(x).unsqueeze(1)
        advantages = torch.stack([branch(x) for branch in self.branches], 1)
        # Subtract the per-branch mean advantage before adding the state value.
        # From the paper, this implementation performs better than both the naive alternative (Q = V + A)
        # and the local maximum reduction method (Q = V + max(A)).
        advantages = advantages - advantages.mean(2, keepdim=True)
        return {'logit': state_value + advantages}
285+
286+
177287
class RainbowHead(nn.Module):
178288
"""
179289
Overview:

ding/model/template/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# general
2-
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN
2+
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
33
from .qac import QAC, DiscreteQAC
44
from .pdqn import PDQN
55
from .vac import VAC

ding/model/template/q_learning.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ding.torch_utils import get_lstm
66
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
77
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
8-
QuantileHead, FQFHead, QRDQNHead, DistributionHead
8+
QuantileHead, FQFHead, QRDQNHead, DistributionHead, BranchingHead
99
from ding.torch_utils.network.gtrxl import GTrXL
1010

1111

@@ -98,6 +98,101 @@ def forward(self, x: torch.Tensor) -> Dict:
9898
return x
9999

100100

101+
@MODEL_REGISTRY.register('bdq')
class BDQ(nn.Module):
    """
    Overview:
        Branching Dueling Q-network (BDQ): an observation encoder followed by a ``BranchingHead``. \
        Referenced paper: Action Branching Architectures for Deep Reinforcement Learning \
        <https://arxiv.org/pdf/1711.08946>
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            num_branches: int = 0,
            action_bins_per_branch: int = 2,
            layer_num: int = 3,
            a_layer_num: Optional[int] = None,
            v_layer_num: Optional[int] = None,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            norm_type: Optional[str] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
    ) -> None:
        """
        Overview:
            Init the BDQ (encoder + head) Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension, \
                such as 6 in mujoco's halfcheetah environment.
            - action_bins_per_branch (:obj:`int`): The number of actions in each dimension.
            - layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
            - a_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute \
                Advantage output. Passed to ``BranchingHead`` unchanged.
            - v_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute \
                Value output. Passed to ``BranchingHead`` unchanged.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. \
                If ``None``, the last element of ``encoder_hidden_size_list`` is used.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - activation (:obj:`Optional[nn.Module]`): The activation passed to the encoder and the head. \
                Default ``nn.ReLU()``.
        Raises:
            - RuntimeError: If ``obs_shape`` (after ``squeeze``) is neither an int, a length-1 sequence, \
                nor a length-3 sequence.
        """
        super(BDQ, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, num_branches = squeeze(obs_shape), squeeze(num_branches)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # backbone
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            # FC Encoder
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        elif len(obs_shape) == 3:
            # Conv Encoder
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            # The message names BDQ (it previously said DQN, copied from the DQN template).
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BDQ".format(obs_shape)
            )

        self.num_branches = num_branches
        self.action_bins_per_branch = action_bins_per_branch

        # head
        self.head = BranchingHead(
            head_hidden_size,
            num_branches=self.num_branches,
            action_bins_per_branch=self.action_bins_per_branch,
            layer_num=layer_num,
            a_layer_num=a_layer_num,
            v_layer_num=v_layer_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, x: torch.Tensor) -> Dict:
        r"""
        Overview:
            BDQ forward computation graph, input observation tensor to predict q_value.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs
        Returns:
            - outputs (:obj:`Dict`): BDQ forward outputs, such as q_value.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, D, M)`, where B is batch size, D is \
                ``num_branches`` and M is ``action_bins_per_branch``
        Examples:
            >>> model = BDQ(8, 5, 2)  # arguments: 'obs_shape', 'num_branches' and 'action_bins_per_branch'.
            >>> inputs = torch.randn(4, 8)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
        """
        # corresponds to the "Gradient Rescaling" in the paper
        # NOTE(review): dividing the encoder output scales the forward activations fed to the head as well as
        # the gradient flowing back into the encoder; the paper describes rescaling the combined gradient
        # only. Confirm this is intended before changing it, since it affects trained checkpoints.
        x = self.encoder(x) / (self.num_branches + 1)
        x = self.head(x)
        return x
194+
195+
101196
@MODEL_REGISTRY.register('c51dqn')
102197
class C51DQN(nn.Module):
103198

ding/model/template/tests/test_q_learning.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from itertools import product
33
import torch
4-
from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN
4+
from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
55
from ding.torch_utils import is_differentiable
66

77
T, B = 3, 4
@@ -40,6 +40,25 @@ def test_dqn(self, obs_shape, act_shape):
4040
assert outputs['logit'][i].shape == (B, s)
4141
self.output_check(model, outputs['logit'])
4242

43+
@pytest.mark.parametrize('obs_shape, act_shape', args)
def test_bdq(self, obs_shape, act_shape):
    """
    Overview:
        Check that ``BDQ`` returns a dict whose ``logit`` has shape \
        ``(B, num_branches, action_bins_per_branch)`` for 1 to 9 bins per branch, and run \
        ``output_check`` on it. Multi-dimensional ``act_shape`` cases are skipped.
    """
    # Sample the input before the early return below, so RNG consumption matches every parametrized case.
    obs_dims = (obs_shape, ) if isinstance(obs_shape, int) else tuple(obs_shape)
    inputs = torch.randn(B, *obs_dims)
    if not (isinstance(act_shape, int) or len(act_shape) <= 1):
        return
    branch_dims = (act_shape, ) if isinstance(act_shape, int) else tuple(act_shape)
    for bins in range(1, 10):
        model = BDQ(obs_shape, act_shape, bins)
        outputs = model(inputs)
        assert isinstance(outputs, dict)
        assert outputs['logit'].shape == (B, *branch_dims, bins)
        self.output_check(model, outputs['logit'])
61+
4362
@pytest.mark.parametrize('obs_shape, act_shape', args)
4463
def test_rainbowdqn(self, obs_shape, act_shape):
4564
if isinstance(obs_shape, int):

ding/policy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,5 @@
4343

4444
from .bc import BehaviourCloningPolicy
4545
from .ibc import IBCPolicy
46+
47+
from .bdq import BDQPolicy

0 commit comments

Comments
 (0)