@@ -4,7 +4,7 @@
 import torch
 
 from ding.model import model_wrap
-from ding.rl_utils import vtrace_data, vtrace_error, get_train_sample
+from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
 from ding.torch_utils import Adam, RMSprop, to_device
 from ding.utils import POLICY_REGISTRY
 from ding.utils.data import default_collate, default_decollate
@@ -48,6 +48,8 @@ class IMPALAPolicy(Policy):
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
+        # (str) Which kind of action space is used in IMPALAPolicy, ['discrete', 'continuous']
+        action_space='discrete',
         # (int) the trajectory length to calculate v-trace target
         unroll_len=32,
         # (bool) Whether to need policy data in process transition
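
The new `action_space` field is what drives every branch added below. A minimal, hypothetical user config fragment that flips the policy into the continuous branch might look like the sketch below; apart from `action_space` itself, the keys and shape values are illustrative placeholders, not taken from this diff.

# Hypothetical config fragment; only `action_space` matters for the branches in this PR,
# the observation/action shapes are made-up example values.
impala_continuous_cfg = dict(
    policy=dict(
        action_space='continuous',  # selects reparam_sample / vtrace_error_continuous_action below
        model=dict(
            obs_shape=11,  # assumed env observation dim
            action_shape=3,  # assumed env action dim
            action_space='continuous',  # assumption: the model usually needs the same switch
        ),
    ),
)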
@@ -97,6 +99,8 @@ def _init_learn(self) -> None:
             Learn mode init method. Called by ``self.__init__``.
             Initialize the optimizer, algorithm config and main model.
         """
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
         # Optimizer
         grad_clip_type = self._cfg.learn.get("grad_clip_type", None)
         clip_value = self._cfg.learn.get("clip_value", None)
@@ -165,10 +169,21 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
         else:
             data['weight'] = data.get('weight', None)
         data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0)  # shape (T+1)*B,env_obs_shape
-        data['logit'] = torch.cat(
-            data['logit'], dim=0
-        ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
-        data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1)  # shape T,B
+        if self._action_space == 'continuous':
+            data['logit']['mu'] = torch.cat(
+                data['logit']['mu'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['logit']['sigma'] = torch.cat(
+                data['logit']['sigma'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['action'] = torch.cat(
+                data['action'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+        elif self._action_space == 'discrete':
+            data['logit'] = torch.cat(
+                data['logit'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1)  # shape T,B
         data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float()  # shape T,B
         data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1)  # shape T,B
         data['weight'] = torch.cat(
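
As a standalone illustration of the reshape logic in this hunk: the learner receives each field as a list of `unroll_len` per-timestep tensors, and the continuous branch turns, for example, `mu` into a single `(T, B, action_shape)` tensor. The shapes below are made up for the example and do not come from the diff.

import torch

T, B, A = 4, 2, 3  # unroll_len, number of envs in the batch, action dim (illustrative)
mu_per_step = [torch.randn(B, A) for _ in range(T)]  # one (B, A) tensor per collected timestep

# Same pattern as the continuous branch above: concatenate over time, then fold back to (T, B, A).
mu = torch.cat(mu_per_step, dim=0).reshape(T, -1, A)
assert mu.shape == (T, B, A)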
@@ -204,7 +219,11 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         # Calculate vtrace error
         data = vtrace_data(target_logit, behaviour_logit, actions, values, rewards, weights)
         g, l, r, c, rg = self._gamma, self._lambda, self._rho_clip_ratio, self._c_clip_ratio, self._rho_pg_clip_ratio
-        vtrace_loss = vtrace_error(data, g, l, r, c, rg)
+        if self._action_space == 'continuous':
+            vtrace_loss = vtrace_error_continuous_action(data, g, l, r, c, rg)
+        elif self._action_space == 'discrete':
+            vtrace_loss = vtrace_error_discrete_action(data, g, l, r, c, rg)
+
         wv, we = self._value_weight, self._entropy_weight
         total_loss = vtrace_loss.policy_loss + wv * vtrace_loss.value_loss - we * vtrace_loss.entropy_loss
         # ====================
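
Both `vtrace_error_continuous_action` and `vtrace_error_discrete_action` are assumed (per this hunk) to return a loss tuple with the same `policy_loss` / `value_loss` / `entropy_loss` fields, which is why the weighting that follows can stay branch-independent. A toy sketch of that combination, using a stand-in namedtuple and made-up numbers:

from collections import namedtuple

import torch

# Stand-in for the loss tuple interface assumed above; not the actual ding.rl_utils definition.
VTraceLoss = namedtuple('VTraceLoss', ['policy_loss', 'value_loss', 'entropy_loss'])

vtrace_loss = VTraceLoss(torch.tensor(0.8), torch.tensor(1.2), torch.tensor(0.05))
wv, we = 0.5, 0.01  # illustrative value / entropy weights
total_loss = vtrace_loss.policy_loss + wv * vtrace_loss.value_loss - we * vtrace_loss.entropy_loss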
@@ -244,10 +263,18 @@ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[A
             - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
             - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
         """
-        target_logit = output['logit'].reshape(self._unroll_len + 1, -1,
-                                               self._action_shape)[:-1]  # shape (T+1),B,env_obs_shape
+        if self._action_space == 'continuous':
+            target_logit = {}
+            target_logit['mu'] = output['logit']['mu'].reshape(self._unroll_len + 1, -1,
+                                                               self._action_shape)[:-1
+                                                                                   ]  # shape (T+1),B,env_action_shape
+            target_logit['sigma'] = output['logit']['sigma'].reshape(self._unroll_len + 1, -1, self._action_shape
+                                                                     )[:-1]  # shape (T+1),B,env_action_shape
+        elif self._action_space == 'discrete':
+            target_logit = output['logit'].reshape(self._unroll_len + 1, -1,
+                                                   self._action_shape)[:-1]  # shape (T+1),B,env_action_shape
         behaviour_logit = data['logit']  # shape T,B
-        actions = data['action']  # shape T,B
+        actions = data['action']  # shape T,B for discrete, T,B,env_action_shape for continuous
         values = output['value'].reshape(self._unroll_len + 1, -1)  # shape T+1,B,env_action_shape
         rewards = data['reward']  # shape T,B
         weights_ = 1 - data['done']  # shape T,B
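
The `[:-1]` slice in every branch drops the bootstrap step: the model was run on `obs_plus_1`, so `output['logit']` covers `T + 1` timesteps, while the v-trace target only needs logits for the first `T`. A standalone shape check with made-up sizes:

import torch

T, B, A = 4, 2, 3  # unroll_len, env batch, action dim (illustrative)
flat_logit = torch.randn((T + 1) * B, A)  # model output over obs_plus_1, flattened across time

# Fold back to (T + 1, B, A), then drop the final bootstrap step, as in the branches above.
target_logit = flat_logit.reshape(T + 1, -1, A)[:-1]
assert target_logit.shape == (T, B, A)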
@@ -289,7 +316,13 @@ def _init_collect(self) -> None:
             Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model.
             Use multinomial_sample to choose action.
         """
-        self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
+        if self._action_space == 'continuous':
+            self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+        elif self._action_space == 'discrete':
+            self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+
         self._collect_model.reset()
 
     def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]:
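
For context, `reparam_sample` draws collect-time actions from the Gaussian parameterized by the model's `mu`/`sigma` output. A rough torch-only equivalent is sketched below; it is an assumption about the wrapper's behaviour, not its actual implementation in `ding.model`.

import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(2, 3)  # (B, action_shape), made-up values
sigma = torch.ones(2, 3)

# Treat the action dimensions as one event and sample with the reparameterization trick,
# so the sample stays differentiable w.r.t. mu and sigma.
dist = Independent(Normal(mu, sigma), 1)
action = dist.rsample()  # shape (B, action_shape)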
@@ -364,7 +397,13 @@ def _init_eval(self) -> None:
             Evaluate mode init method. Called by ``self.__init__``, initialize eval_model,
             and use argmax_sample to choose action.
         """
-        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
+        if self._action_space == 'continuous':
+            self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+        elif self._action_space == 'discrete':
+            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+
         self._eval_model.reset()
 
     def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
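
`deterministic_sample` and `argmax_sample` are the evaluation-time counterparts: presumably the former acts at the Gaussian mean while the latter picks the highest-logit action. The sketch below states that assumption explicitly; it is not the wrappers' source.

import torch

logit = torch.randn(2, 5)  # (B, num_actions), made-up discrete logits
greedy_action = logit.argmax(dim=-1)  # assumed effect of argmax_sample

mu = torch.randn(2, 3)  # (B, action_shape), made-up Gaussian mean
deterministic_action = mu  # assumed effect of deterministic_sample: act at the mean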