
Commit 01b1a8b

feature(zjow): impala policy for continuous action space (#551)
* Add continuous impala
* Add config file.
* polish config
* rm matrix sigma
* polish
* polish
* add unittest
* add unittest
* polish
* polish config
* polish policy
1 parent 75d8644 commit 01b1a8b
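
The commit message mentions an added config file for the continuous-action variant, but that file is not among the diffs shown below. As a rough sketch only (the config name, environment fields, and hyperparameter values here are illustrative assumptions, not the committed file), a continuous IMPALA experiment would mainly need `action_space='continuous'` in both the policy and the model:

```python
# Hypothetical sketch of a continuous-action IMPALA config; names and values are
# illustrative assumptions, not the config file added by this commit.
from easydict import EasyDict

pendulum_impala_config = EasyDict(
    dict(
        exp_name='pendulum_impala_seed0',
        env=dict(collector_env_num=8, evaluator_env_num=5, act_scale=True),
        policy=dict(
            cuda=False,
            # Selects the new continuous-action code path in IMPALAPolicy.
            action_space='continuous',
            # (int) the trajectory length to calculate v-trace target
            unroll_len=32,
            model=dict(
                obs_shape=3,
                action_shape=1,
                # The model is expected to emit logit as a dict with 'mu' and 'sigma'.
                action_space='continuous',
            ),
            learn=dict(update_per_collect=4, batch_size=16, learning_rate=3e-4),
            collect=dict(n_sample=16),
        ),
    )
)
```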

10 files changed: +265 additions, -39 deletions

ding/hpc_rl/tests/test_vtrace.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 import time
 import torch
 import torch.nn.functional as F
-from hpc_rll.origin.vtrace import vtrace_error, vtrace_data
+from hpc_rll.origin.vtrace import vtrace_error_discrete_action, vtrace_data
 from hpc_rll.rl_utils.vtrace import VTrace
 from testbase import mean_relative_error, times
 
@@ -48,7 +48,7 @@ def vtrace_val():
 
     ori_target_output.requires_grad_(True)
     ori_value.requires_grad_(True)
-    ori_loss = vtrace_error(
+    ori_loss = vtrace_error_discrete_action(
         vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
     )
     ori_loss = sum(ori_loss)
@@ -114,7 +114,7 @@ def vtrace_perf():
     ori_value.requires_grad_(True)
     for i in range(times):
         t = time.time()
-        ori_loss = vtrace_error(
+        ori_loss = vtrace_error_discrete_action(
            vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
        )
        ori_loss = sum(ori_loss)
```

ding/hpc_rl/wrapper.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ def register_runtime_fn(fn_name, runtime_name, shape):
         'ScatterConnection': ['hpc_rll.torch_utils.network.scatter_connection', 'ScatterConnection'],
         'td_lambda_error': ['hpc_rll.rl_utils.td', 'TDLambda'],
         'upgo_loss': ['hpc_rll.rl_utils.upgo', 'UPGO'],
-        'vtrace_error': ['hpc_rll.rl_utils.vtrace', 'VTrace'],
+        'vtrace_error_discrete_action': ['hpc_rll.rl_utils.vtrace', 'VTrace'],
     }
     fn_str = fn_name_mapping[fn_name]
     cls = getattr(importlib.import_module(fn_str[0]), fn_str[1])
```

ding/model/wrapper/model_wrappers.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2,11 +2,11 @@
 from abc import ABC
 import numpy as np
 import torch
+import torch.nn.functional as F
+from torch.distributions import Categorical, Independent, Normal
 from ding.torch_utils import get_tensor_data
 from ding.rl_utils import create_noise_generator
-from torch.distributions import Categorical, Independent, Normal
 from ding.utils.data import default_collate
-import torch.nn.functional as F
 
 
 class IModelWrapper(ABC):
```

ding/policy/impala.py

Lines changed: 50 additions & 11 deletions
```diff
@@ -4,7 +4,7 @@
 import torch
 
 from ding.model import model_wrap
-from ding.rl_utils import vtrace_data, vtrace_error, get_train_sample
+from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
 from ding.torch_utils import Adam, RMSprop, to_device
 from ding.utils import POLICY_REGISTRY
 from ding.utils.data import default_collate, default_decollate
@@ -48,6 +48,8 @@ class IMPALAPolicy(Policy):
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
+        # (str) Which kind of action space used in IMPALAPolicy, ['discrete', 'continuous']
+        action_space='discrete',
         # (int) the trajectory length to calculate v-trace target
         unroll_len=32,
         # (bool) Whether to need policy data in process transition
@@ -97,6 +99,8 @@ def _init_learn(self) -> None:
             Learn mode init method. Called by ``self.__init__``.
             Initialize the optimizer, algorithm config and main model.
         """
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
         # Optimizer
         grad_clip_type = self._cfg.learn.get("grad_clip_type", None)
         clip_value = self._cfg.learn.get("clip_value", None)
@@ -165,10 +169,21 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
         else:
             data['weight'] = data.get('weight', None)
         data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0)  # shape (T+1)*B,env_obs_shape
-        data['logit'] = torch.cat(
-            data['logit'], dim=0
-        ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
-        data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1)  # shape T,B,
+        if self._action_space == 'continuous':
+            data['logit']['mu'] = torch.cat(
+                data['logit']['mu'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['logit']['sigma'] = torch.cat(
+                data['logit']['sigma'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['action'] = torch.cat(
+                data['action'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+        elif self._action_space == 'discrete':
+            data['logit'] = torch.cat(
+                data['logit'], dim=0
+            ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
+            data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1)  # shape T,B,
         data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float()  # shape T,B,
         data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1)  # shape T,B,
         data['weight'] = torch.cat(
@@ -204,7 +219,11 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         # Calculate vtrace error
         data = vtrace_data(target_logit, behaviour_logit, actions, values, rewards, weights)
         g, l, r, c, rg = self._gamma, self._lambda, self._rho_clip_ratio, self._c_clip_ratio, self._rho_pg_clip_ratio
-        vtrace_loss = vtrace_error(data, g, l, r, c, rg)
+        if self._action_space == 'continuous':
+            vtrace_loss = vtrace_error_continuous_action(data, g, l, r, c, rg)
+        elif self._action_space == 'discrete':
+            vtrace_loss = vtrace_error_discrete_action(data, g, l, r, c, rg)
+
         wv, we = self._value_weight, self._entropy_weight
         total_loss = vtrace_loss.policy_loss + wv * vtrace_loss.value_loss - we * vtrace_loss.entropy_loss
         # ====================
@@ -244,10 +263,18 @@ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[A
             - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
             - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
         """
-        target_logit = output['logit'].reshape(self._unroll_len + 1, -1,
-                                               self._action_shape)[:-1]  # shape (T+1),B,env_obs_shape
+        if self._action_space == 'continuous':
+            target_logit = {}
+            target_logit['mu'] = output['logit']['mu'].reshape(self._unroll_len + 1, -1,
+                                                               self._action_shape)[:-1
+                                                               ]  # shape (T+1),B,env_action_shape
+            target_logit['sigma'] = output['logit']['sigma'].reshape(self._unroll_len + 1, -1, self._action_shape
+                                                                     )[:-1]  # shape (T+1),B,env_action_shape
+        elif self._action_space == 'discrete':
+            target_logit = output['logit'].reshape(self._unroll_len + 1, -1,
+                                                   self._action_shape)[:-1]  # shape (T+1),B,env_action_shape
         behaviour_logit = data['logit']  # shape T,B
-        actions = data['action']  # shape T,B
+        actions = data['action']  # shape T,B for discrete # shape T,B,env_action_shape for continuous
         values = output['value'].reshape(self._unroll_len + 1, -1)  # shape T+1,B,env_action_shape
         rewards = data['reward']  # shape T,B
         weights_ = 1 - data['done']  # shape T,B
@@ -289,7 +316,13 @@ def _init_collect(self) -> None:
             Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model.
             Use multinomial_sample to choose action.
         """
-        self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
+        if self._action_space == 'continuous':
+            self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+        elif self._action_space == 'discrete':
+            self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+
         self._collect_model.reset()
 
     def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]:
@@ -364,7 +397,13 @@ def _init_eval(self) -> None:
             Evaluate mode init method. Called by ``self.__init__``, initialize eval_model,
             and use argmax_sample to choose action.
         """
-        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
+        if self._action_space == 'continuous':
+            self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+        elif self._action_space == 'discrete':
+            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+
         self._eval_model.reset()
 
     def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
```
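
For the continuous branch above, `_data_preprocess_learn` and `_reshape_data` expect `data['logit']` to be a dict of Gaussian parameters rather than a single logit tensor, and actions carry an extra action-dimension axis. A minimal sketch of the two layouts (placeholder tensors, illustrative sizes only):

```python
import torch

T, B, N = 32, 4, 6  # unroll length, batch size, action dim (illustrative sizes)

# Discrete action space: one logit tensor and integer action indices.
discrete_logit = torch.randn(T, B, N)           # shape T, B, env_action_shape
discrete_action = torch.randint(0, N, (T, B))   # shape T, B

# Continuous action space: Gaussian parameters per action dimension.
continuous_logit = {
    'mu': torch.randn(T, B, N),    # shape T, B, env_action_shape
    'sigma': torch.rand(T, B, N),  # positive std, shape T, B, env_action_shape
}
continuous_action = torch.randn(T, B, N)        # shape T, B, env_action_shape
```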

ding/policy/ppo.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -92,6 +92,7 @@ def _init_learn(self) -> None:
         self._priority_IS_weight = self._cfg.priority_IS_weight
         assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
 
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         if self._cfg.learn.ppo_param_init:
             for n, m in self._model.named_modules():
@@ -287,6 +288,7 @@ def _init_collect(self) -> None:
             Init traj and unroll length, collect model.
         """
         self._unroll_len = self._cfg.collect.unroll_len
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         if self._action_space == 'continuous':
             self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
@@ -399,6 +401,7 @@ def _init_eval(self) -> None:
             Evaluate mode init method. Called by ``self.__init__``.
             Init eval model with argmax strategy.
         """
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         if self._action_space == 'continuous':
             self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
@@ -511,6 +514,7 @@ def default_model(self) -> Tuple[str, List[str]]:
         return 'pg', ['ding.model.template.pg']
 
     def _init_learn(self) -> None:
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         if self._cfg.learn.ppo_param_init:
             for n, m in self._model.named_modules():
@@ -586,6 +590,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
         return return_infos
 
     def _init_collect(self) -> None:
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         self._unroll_len = self._cfg.collect.unroll_len
         if self._action_space == 'continuous':
@@ -632,6 +637,7 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
         return get_train_sample(data, self._unroll_len)
 
     def _init_eval(self) -> None:
+        assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
         self._action_space = self._cfg.action_space
         if self._action_space == 'continuous':
             self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
```

ding/rl_utils/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@
 from .upgo import upgo_loss
 from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
 from .value_rescale import value_transform, value_inv_transform
-from .vtrace import vtrace_data, vtrace_error
+from .vtrace import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
 from .beta_function import beta_function_map
 from .retrace import compute_q_retraces
 from .acer import acer_policy_error, acer_value_error, acer_trust_region_update
```

ding/rl_utils/isw.py

Lines changed: 34 additions & 13 deletions
```diff
@@ -1,33 +1,54 @@
+from typing import Union
 import torch
+from torch.distributions import Categorical, Independent, Normal
 
 
-def compute_importance_weights(target_output, behaviour_output, action, requires_grad=False):
+def compute_importance_weights(
+        target_output: Union[torch.Tensor, dict],
+        behaviour_output: Union[torch.Tensor, dict],
+        action: torch.Tensor,
+        action_space_type: str = 'discrete',
+        requires_grad: bool = False
+):
     """
     Overview:
         Computing importance sampling weight with given output and action
     Arguments:
-        - target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\
-            usually this output is network output logit
-        - behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\
-            usually this output is network output logit, which is used to produce the trajectory(collector)
+        - target_output (:obj:`Union[torch.Tensor,dict]`): the output taking the action \
+            by the current policy network, \
+            usually this output is network output logit if action space is discrete, \
+            or is a dict containing parameters of action distribution if action space is continuous.
+        - behaviour_output (:obj:`Union[torch.Tensor,dict]`): the output taking the action \
+            by the behaviour policy network,\
+            usually this output is network output logit, if action space is discrete, \
+            or is a dict containing parameters of action distribution if action space is continuous.
         - action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory,\
            i.e.: behaviour_action
+        - action_space_type (:obj:`str`): action space types in ['discrete', 'continuous']
         - requires_grad (:obj:`bool`): whether requires grad computation
     Returns:
         - rhos (:obj:`torch.Tensor`): Importance sampling weight
     Shapes:
-        - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\
-            N is action dim
-        - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+        - target_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`, \
+            where T is timestep, B is batch size and N is action dim
+        - behaviour_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`
         - action (:obj:`torch.LongTensor`): :math:`(T, B)`
         - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
     """
     grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
     assert isinstance(action, torch.Tensor)
+    assert action_space_type in ['discrete', 'continuous']
 
     with grad_context:
-        dist_target = torch.distributions.Categorical(logits=target_output)
-        dist_behaviour = torch.distributions.Categorical(logits=behaviour_output)
-        rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
-        rhos = torch.exp(rhos)
-        return rhos
+        if action_space_type == 'continuous':
+            dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
+            dist_behaviour = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
+            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
+            rhos = torch.exp(rhos)
+            return rhos
+        elif action_space_type == 'discrete':
+            dist_target = Categorical(logits=target_output)
+            dist_behaviour = Categorical(logits=behaviour_output)
+            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
+            rhos = torch.exp(rhos)
+            return rhos
```
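
A small self-contained check of the continuous branch of `compute_importance_weights` (a sketch with random tensors; shapes follow the docstring above, and the import path assumes the module location shown in this diff):

```python
import torch
from ding.rl_utils.isw import compute_importance_weights

T, B, N = 4, 8, 3  # timestep, batch size, action dim
target_output = {'mu': torch.randn(T, B, N), 'sigma': torch.rand(T, B, N) + 0.1}
behaviour_output = {'mu': torch.randn(T, B, N), 'sigma': torch.rand(T, B, N) + 0.1}
action = torch.randn(T, B, N)

# rho = pi_target(a|s) / pi_behaviour(a|s), computed under Independent(Normal(mu, sigma), 1),
# so the last dimension is treated as the event dimension and log-probs reduce over it.
rhos = compute_importance_weights(target_output, behaviour_output, action, action_space_type='continuous')
assert rhos.shape == (T, B)
```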

ding/rl_utils/tests/test_vtrace.py

Lines changed: 28 additions & 3 deletions
```diff
@@ -1,22 +1,47 @@
 import pytest
 import torch
-from ding.rl_utils import vtrace_data, vtrace_error
+from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
 
 
 @pytest.mark.unittest
-def test_vtrace():
+def test_vtrace_discrete_action():
     T, B, N = 4, 8, 16
     value = torch.randn(T + 1, B).requires_grad_(True)
     reward = torch.rand(T, B)
     target_output = torch.randn(T, B, N).requires_grad_(True)
     behaviour_output = torch.randn(T, B, N)
     action = torch.randint(0, N, size=(T, B))
     data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
-    loss = vtrace_error(data, rho_clip_ratio=1.1)
+    loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
     assert all([l.shape == tuple() for l in loss])
     assert target_output.grad is None
     assert value.grad is None
     loss = sum(loss)
     loss.backward()
     assert isinstance(target_output, torch.Tensor)
     assert isinstance(value, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_vtrace_continuous_action():
+    T, B, N = 4, 8, 16
+    value = torch.randn(T + 1, B).requires_grad_(True)
+    reward = torch.rand(T, B)
+    target_output = {}
+    target_output['mu'] = torch.randn(T, B, N).requires_grad_(True)
+    target_output['sigma'] = torch.exp(torch.randn(T, B, N).requires_grad_(True))
+    behaviour_output = {}
+    behaviour_output['mu'] = torch.randn(T, B, N)
+    behaviour_output['sigma'] = torch.exp(torch.randn(T, B, N))
+    action = torch.randn((T, B, N))
+    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
+    loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
+    assert all([l.shape == tuple() for l in loss])
+    assert target_output['mu'].grad is None
+    assert target_output['sigma'].grad is None
+    assert value.grad is None
+    loss = sum(loss)
+    loss.backward()
+    assert isinstance(target_output['mu'], torch.Tensor)
+    assert isinstance(target_output['sigma'], torch.Tensor)
+    assert isinstance(value, torch.Tensor)
```
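
The new unit tests can be exercised through pytest's Python API as well as from the command line (a sketch; the `-sv` flags and the repository-root working directory are assumptions):

```python
# Minimal runner for the updated v-trace tests; assumes execution from the
# repository root so the relative path resolves.
import pytest

if __name__ == '__main__':
    pytest.main(['-sv', '-m', 'unittest', 'ding/rl_utils/tests/test_vtrace.py'])
```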
