from typing import Union, Optional, Dict
from easydict import EasyDict

import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('edac')
class QACEnsemble(nn.Module):
    r"""
    Overview:
        The QAC (Q-value Actor-Critic) network with an ensemble of Q networks, which is used in EDAC.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
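    Examples:
        >>> # An illustrative usage sketch with made-up shapes: run both forward modes on random data.
        >>> model = QACEnsemble(obs_shape=8, action_shape=3, ensemble_num=10)
        >>> obs = torch.randn(4, 8)
        >>> actor_out = model(obs, mode='compute_actor')
        >>> assert actor_out['logit'][0].shape == torch.Size([4, 3])  # mu
        >>> critic_out = model({'obs': obs, 'action': torch.randn(4, 3)}, mode='compute_critic')
        >>> assert critic_out['q_value'].shape == torch.Size([10, 4])  # (ensemble_num, batch_size)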
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            ensemble_num: int = 2,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Initialize the EDAC model according to the input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - ensemble_num (:obj:`int`): The number of Q networks in the ensemble.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor head.
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head network to compute \
                the action output.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic head.
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head network to compute \
                the Q value output.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use after each network layer \
                (FC, Conv), see ``ding.torch_utils.network`` for more details.
        """
        super(QACEnsemble, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.ensemble_num = ensemble_num
        # Actor: a single stochastic policy network with a reparameterization (Gaussian) head.
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size), activation,
            ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        )

        # Critic: an ensemble of Q networks, each consuming the concatenated observation and action.
        critic_input_size = obs_shape + action_shape
        self.critic = EnsembleHead(
            critic_input_size,
            1,
            critic_head_hidden_size,
            critic_head_layer_num,
            self.ensemble_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of the EDAC model. One can indicate different modes to \
            implement different computation graphs, including ``compute_actor`` and ``compute_critic``.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        .. note::
            For specific examples, one can refer to the API doc of ``compute_actor`` and ``compute_critic`` \
            respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            The forward computation graph of the compute_actor mode, which uses the observation tensor to produce \
            actor output, such as ``action``, ``logit`` and so on.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
                i.e. ``(B, obs_shape)``.
        Returns:
            - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \
                from action_space: ``reparameterization``.
        ReturnsKeys (either):
            - logit (:obj:`Dict[str, torch.Tensor]`): Reparameterization logit, usually in SAC.
                - mu (:obj:`torch.Tensor`): Mean of the reparameterization Gaussian distribution.
                - sigma (:obj:`torch.Tensor`): Standard deviation of the reparameterization Gaussian distribution.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
            - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
                ``action_shape.action_type_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
                ``action_shape.action_args_shape``.
        Examples:
            >>> model = QACEnsemble(64, 64)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        """
        x = self.actor(obs)
        return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The forward computation graph of the compute_critic mode, which uses the observation and action \
            tensors to produce critic output, such as ``q_value``.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
                ``action`` tensor.
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
        ArgumentsKeys:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
            - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor of the whole ensemble, with shape \
                :math:`(Ensemble_num, B)`.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble_num, B, N1)`, where B is batch size \
                and N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble_num, B, N2)`, where N2 is \
                ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = QACEnsemble(obs_shape=(8, ), action_shape=1)
            >>> model(inputs, mode='compute_critic')['q_value'].shape  # (ensemble_num, batch_size)
            ... torch.Size([2, 4])
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if len(obs.shape) < 3:
            # (batch_size, dim) -> (batch_size, ensemble_num * dim, 1)
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
        else:
            # (ensemble_num, batch_size, dim) -> (batch_size, ensemble_num, dim) -> (batch_size, ensemble_num * dim, 1)
            x = x.transpose(0, 1)
            batch_size = obs.shape[1]
            x = x.reshape(batch_size, -1, 1)
        # The ensemble critic head returns 'pred' with shape (batch_size, ensemble_num).
        x = self.critic(x)['pred']
        # (batch_size, ensemble_num) -> (ensemble_num, batch_size)
        x = x.permute(1, 0)
        return {'q_value': x}
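

if __name__ == '__main__':
    # A minimal smoke-test sketch (illustrative only, not part of the library API): it runs both forward
    # modes on random data and checks the (ensemble_num, batch_size) critic output documented above.
    model = QACEnsemble(obs_shape=8, action_shape=2, ensemble_num=4)
    obs = torch.randn(16, 8)
    mu, sigma = model(obs, mode='compute_actor')['logit']
    assert mu.shape == sigma.shape == torch.Size([16, 2])
    q_value = model({'obs': obs, 'action': torch.randn(16, 2)}, mode='compute_critic')['q_value']
    assert q_value.shape == torch.Size([4, 16])
    print('QACEnsemble smoke test passed')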