|
5 | 5 | from ding.torch_utils import get_lstm
|
6 | 6 | from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
|
7 | 7 | from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
|
8 |
| - QuantileHead, FQFHead, QRDQNHead, DistributionHead |
| 8 | + QuantileHead, FQFHead, QRDQNHead, DistributionHead, BranchingHead |
9 | 9 | from ding.torch_utils.network.gtrxl import GTrXL
|
10 | 10 |
|
11 | 11 |
|
@@ -98,6 +98,101 @@ def forward(self, x: torch.Tensor) -> Dict:
|
98 | 98 | return x
|
99 | 99 |
|
100 | 100 |
|
@MODEL_REGISTRY.register('bdq')
class BDQ(nn.Module):
    """
    Overview:
        BDQ model: an observation encoder (``FCEncoder`` or ``ConvEncoder``) followed by a ``BranchingHead``. \
        Referenced paper: Action Branching Architectures for Deep Reinforcement Learning \
        <https://arxiv.org/pdf/1711.08946>
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            num_branches: int = 0,
            action_bins_per_branch: int = 2,
            layer_num: int = 3,
            a_layer_num: Optional[int] = None,
            v_layer_num: Optional[int] = None,
            # NOTE: list default is shared across calls; it is only read (never mutated) in this constructor.
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            norm_type: Optional[str] = None,
            # NOTE: the default ``nn.ReLU()`` instance is created once and shared by every BDQ built with defaults.
            activation: Optional[nn.Module] = nn.ReLU(),
    ) -> None:
        """
        Overview:
            Init the BDQ (encoder + head) Model according to input arguments. \
            referenced paper Action Branching Architectures for Deep Reinforcement Learning \
            <https://arxiv.org/pdf/1711.08946>
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension, \
                such as 6 in mujoco's halfcheetah environment.
            - action_bins_per_branch (:obj:`int`): The number of actions in each dimension.
            - layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
            - a_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute Advantage \
                output, forwarded as-is to ``BranchingHead``.
            - v_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute Value \
                output, forwarded as-is to ``BranchingHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. If ``None``, the last \
                element of ``encoder_hidden_size_list`` is used.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - activation (:obj:`Optional[nn.Module]`): The activation function, forwarded unchanged to both the \
                encoder and the head.
        Raises:
            - RuntimeError: If ``obs_shape`` (after ``squeeze``) is neither an int, a length-1 sequence nor a \
                length-3 sequence, i.e. no pre-defined encoder is selected for it.
        """
        super(BDQ, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, num_branches = squeeze(obs_shape), squeeze(num_branches)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # backbone
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            # The message names BDQ (not DQN) so users are pointed at the right model to customize.
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BDQ".format(obs_shape)
            )

        self.num_branches = num_branches
        self.action_bins_per_branch = action_bins_per_branch

        # head
        self.head = BranchingHead(
            head_hidden_size,
            num_branches=self.num_branches,
            action_bins_per_branch=self.action_bins_per_branch,
            layer_num=layer_num,
            a_layer_num=a_layer_num,
            v_layer_num=v_layer_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, x: torch.Tensor) -> Dict:
        r"""
        Overview:
            BDQ forward computation graph, input observation tensor to predict q_value.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs
        Returns:
            - outputs (:obj:`Dict`): BDQ forward outputs, such as q_value.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, D, K)`, where B is batch size, D is ``num_branches`` \
                and K is ``action_bins_per_branch`` (see the example below).
        Examples:
            >>> model = BDQ(8, 5, 2)  # arguments: 'obs_shape', 'num_branches' and 'action_bins_per_branch'.
            >>> inputs = torch.randn(4, 8)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
        """
        # Corresponds to the "Gradient Rescaling" in the paper. Note that dividing the encoder output scales the
        # features fed to the head as well as the gradient flowing back into the encoder by 1 / (num_branches + 1).
        x = self.encoder(x) / (self.num_branches + 1)
        x = self.head(x)
        return x
| 194 | + |
| 195 | + |
101 | 196 | @MODEL_REGISTRY.register('c51dqn')
|
102 | 197 | class C51DQN(nn.Module):
|
103 | 198 |
|
|
0 commit comments