Commit aefddac

feature(zc): add EDAC and modify config of td3bc (#639)
* add EDAC and modify config of td3bc
* modify edac
* add conv1d
* add test_ensemble
* add encoder
* add encoder
* add encoder
* modify policy_init
* modify edac
* add init
* modify td3_bc and readme
* remove head in qac
* modify edac comment
* modify edac comment
* modify edac
* modify edac
* modify head overview
* modify example
* format

1 parent 93e4fa5 commit aefddac

44 files changed: +944 −98 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -252,6 +252,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 50 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
 | 51 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
 | 52 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+| 53 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
 </details>

ding/example/edac.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import gym
+from ditk import logging
+from ding.model import QACEnsemble
+from ding.policy import EDACPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
+from ding.utils import set_pkg_seed
+from dizoo.d4rl.envs import D4RLEnv
+from dizoo.d4rl.config.halfcheetah_medium_edac_config import main_config, create_config
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+    # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
+    logging.getLogger().setLevel(logging.INFO)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    ding_init(cfg)
+    with task.start(async_mode=False, ctx=OfflineRLContext()):
+        evaluator_env = BaseEnvManagerV2(
+            env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+        )
+
+        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+        dataset = create_dataset(cfg)
+        model = QACEnsemble(**cfg.policy.model)
+        policy = EDACPolicy(cfg.policy, model=model)
+
+        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+        task.use(offline_data_fetcher(cfg, dataset))
+        task.use(trainer(cfg, policy.learn_mode))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1e4))
+        task.use(offline_logger())
+        task.run()
+
+
+if __name__ == "__main__":
+    main()
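For reference, `QACEnsemble(**cfg.policy.model)` above expects the ensemble hyper-parameters to live under `cfg.policy.model`. Below is a minimal sketch of that sub-config; the values are illustrative only (assuming HalfCheetah dimensions), while the shipped values live in dizoo/d4rl/config/halfcheetah_medium_edac_config.py and may differ:

    # Illustrative sketch, not the shipped config.
    policy = dict(
        model=dict(
            obs_shape=17,       # HalfCheetah observation dim
            action_shape=6,     # HalfCheetah action dim
            ensemble_num=10,    # number of Q networks in the ensemble (EDAC's N)
            actor_head_hidden_size=256,
            critic_head_hidden_size=256,
        ),
    )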

ding/model/common/__init__.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
     QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
-    independent_normal_dist, AttentionPolicyHead, PopArtVHead
+    independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
 from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
 from .utils import create_model

ding/model/common/head.py

File mode changed: 100644 → 100755
Lines changed: 75 additions & 1 deletion

@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 from torch.distributions import Normal, Independent

-from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt
+from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt, conv1d_block
 from ding.rl_utils import beta_function_map
 from ding.utils import lists_to_dicts, SequenceType

@@ -1316,6 +1316,79 @@ def forward(self, x: torch.Tensor) -> Dict:
         return lists_to_dicts([m(x) for m in self.pred])


+class EnsembleHead(nn.Module):
+    """
+    Overview:
+        The ``EnsembleHead`` is used to output the action Q-values of a Q-ensemble. \
+        Input is a (:obj:`torch.Tensor`) of shape ``(B, N * ensemble_num, 1)`` and the return is a \
+        (:obj:`Dict`) containing output ``pred``.
+    Interfaces:
+        ``__init__``, ``forward``.
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int,
+            hidden_size: int,
+            layer_num: int,
+            ensemble_num: int,
+            activation: Optional[nn.Module] = nn.ReLU(),
+            norm_type: Optional[str] = None
+    ) -> None:
+        super(EnsembleHead, self).__init__()
+        d = input_size
+        layers = []
+        for _ in range(layer_num):
+            layers.append(
+                conv1d_block(
+                    d * ensemble_num,
+                    hidden_size * ensemble_num,
+                    kernel_size=1,
+                    stride=1,
+                    groups=ensemble_num,
+                    activation=activation,
+                    norm_type=norm_type
+                )
+            )
+            d = hidden_size
+
+        # Adding an activation to the last layer makes training fail, so the final block has none.
+        layers.append(
+            conv1d_block(
+                hidden_size * ensemble_num,
+                output_size * ensemble_num,
+                kernel_size=1,
+                stride=1,
+                groups=ensemble_num,
+                activation=None,
+                norm_type=None
+            )
+        )
+        self.pred = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> Dict:
+        """
+        Overview:
+            Use the encoded embedding tensor to run the ``EnsembleHead`` MLP and return the prediction dictionary.
+        Arguments:
+            - x (:obj:`torch.Tensor`): Tensor containing the input embedding.
+        Returns:
+            - outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`).
+        Shapes:
+            - x: :math:`(B, N * ensemble_num, 1)`, where ``B = batch_size`` and ``N = input_size``.
+            - pred: :math:`(B, M * ensemble_num)`, where ``M = output_size``.
+        Examples:
+            >>> head = EnsembleHead(64, 64, 128, 2, 10)
+            >>> inputs = torch.randn(4, 64 * 10, 1)
+            >>> outputs = head(inputs)
+            >>> assert isinstance(outputs, dict)
+            >>> assert outputs['pred'].shape == torch.Size([4, 64 * 10])
+        """
+        x = self.pred(x).squeeze(-1)
+        return {'pred': x}
+
+
 def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution:
     if isinstance(logits, (list, tuple)):
         return Independent(Normal(*logits), 1)
@@ -1341,4 +1414,5 @@ def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Di
     'popart': PopArtVHead,
     # multi
     'multi': MultiHead,
+    'ensemble': EnsembleHead,
 }
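The head relies on the fact that a kernel-size-1 grouped ``nn.Conv1d`` applies an independent linear map per group, so a single call evaluates every ensemble member at once. A self-contained sketch of that equivalence (all names and sizes below are illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    B, d_in, d_out, ensemble_num = 4, 8, 16, 3
    # One grouped conv over stacked channels vs. ensemble_num separate Linear layers.
    conv = nn.Conv1d(d_in * ensemble_num, d_out * ensemble_num, kernel_size=1, groups=ensemble_num)
    linears = nn.ModuleList(nn.Linear(d_in, d_out) for _ in range(ensemble_num))
    # Copy the conv weights of group i into the i-th linear layer so both paths match.
    with torch.no_grad():
        for i, lin in enumerate(linears):
            lin.weight.copy_(conv.weight[i * d_out:(i + 1) * d_out, :, 0])
            lin.bias.copy_(conv.bias[i * d_out:(i + 1) * d_out])

    x = torch.randn(B, d_in * ensemble_num, 1)
    y_conv = conv(x).squeeze(-1)                               # (B, d_out * ensemble_num)
    x_split = x.squeeze(-1).reshape(B, ensemble_num, d_in)     # group i owns channels [i*d_in, (i+1)*d_in)
    y_lin = torch.cat([lin(x_split[:, i]) for i, lin in enumerate(linears)], dim=1)
    assert torch.allclose(y_conv, y_lin, atol=1e-6)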

ding/model/common/tests/test_head.py

Lines changed: 8 additions & 1 deletion

@@ -2,7 +2,7 @@
 import numpy as np
 import pytest

-from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead
+from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
 from ding.torch_utils import is_differentiable

 B = 4
@@ -84,3 +84,10 @@ def test_stochastic_dueling(self):
         assert isinstance(sigma.grad, torch.Tensor)
         assert outputs['q_value'].shape == (B, 1)
         assert outputs['v_value'].shape == (B, 1)
+
+    def test_ensemble(self):
+        inputs = torch.randn(B, embedding_dim * 3, 1)
+        model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
+        outputs = model(inputs)['pred']
+        self.output_check(model, outputs)
+        # forward squeezes the trailing dim, so pred is (B, action_shape * ensemble_num)
+        assert outputs.shape == (B, action_shape * 3)
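Assuming a standard pytest setup, the new test can be run in isolation with:

    pytest -sv ding/model/common/tests/test_head.py -k test_ensemble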

ding/model/template/__init__.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions

@@ -23,3 +23,4 @@
 from .vae import VanillaVAE
 from .decision_transformer import DecisionTransformer
 from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
+from .edac import QACEnsemble

ding/model/template/edac.py

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
+from typing import Union, Optional, Dict
+from easydict import EasyDict
+
+import torch
+import torch.nn as nn
+from ding.model.common import ReparameterizationHead, EnsembleHead
+from ding.utils import SequenceType, squeeze
+
+from ding.utils import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register('edac')
+class QACEnsemble(nn.Module):
+    r"""
+    Overview:
+        The QAC network with a Q-ensemble, which is used in EDAC.
+    Interfaces:
+        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+    """
+    mode = ['compute_actor', 'compute_critic']
+
+    def __init__(
+            self,
+            obs_shape: Union[int, SequenceType],
+            action_shape: Union[int, SequenceType, EasyDict],
+            ensemble_num: int = 2,
+            actor_head_hidden_size: int = 64,
+            actor_head_layer_num: int = 1,
+            critic_head_hidden_size: int = 64,
+            critic_head_layer_num: int = 1,
+            activation: Optional[nn.Module] = nn.ReLU(),
+            norm_type: Optional[str] = None,
+            **kwargs
+    ) -> None:
+        """
+        Overview:
+            Initialize the EDAC model according to the input arguments.
+        Arguments:
+            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
+                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
+            - ensemble_num (:obj:`int`): The number of Q networks in the ensemble.
+            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor head.
+            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head.
+            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic head.
+            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head to compute the \
+                Q-value output.
+            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+                after each FC layer, if ``None`` then default to ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): The type of normalization to apply after each network layer \
+                (FC, Conv), see ``ding.torch_utils.network`` for more details.
+        """
+        super(QACEnsemble, self).__init__()
+        obs_shape: int = squeeze(obs_shape)
+        action_shape = squeeze(action_shape)
+        self.action_shape = action_shape
+        self.ensemble_num = ensemble_num
+        self.actor = nn.Sequential(
+            nn.Linear(obs_shape, actor_head_hidden_size), activation,
+            ReparameterizationHead(
+                actor_head_hidden_size,
+                action_shape,
+                actor_head_layer_num,
+                sigma_type='conditioned',
+                activation=activation,
+                norm_type=norm_type
+            )
+        )
+
+        critic_input_size = obs_shape + action_shape
+        self.critic = EnsembleHead(
+            critic_input_size,
+            1,
+            critic_head_hidden_size,
+            critic_head_layer_num,
+            self.ensemble_num,
+            activation=activation,
+            norm_type=norm_type
+        )
+
+    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
+        """
+        Overview:
+            The unique execution (forward) method of EDAC. One can indicate different modes to implement \
+            different computation graphs, including ``compute_actor`` and ``compute_critic``.
+        Mode compute_actor:
+            Arguments:
+                - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
+            Returns:
+                - output (:obj:`Dict`): Output dict data, including different key-values among distinct \
+                  action spaces.
+        Mode compute_critic:
+            Arguments:
+                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+            Returns:
+                - output (:obj:`Dict`): Output dict data, including q_value tensor.
+        .. note::
+            For specific examples, one can refer to the API doc of ``compute_actor`` and ``compute_critic`` \
+            respectively.
+        """
+        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+        return getattr(self, mode)(inputs)
+
+    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
+        """
+        Overview:
+            The forward computation graph of the compute_actor mode, which uses the observation tensor to \
+            produce actor outputs such as ``action``, ``logit`` and so on.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
+                i.e. ``(B, obs_shape)``.
+        Returns:
+            - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \
+                with action_space: ``reparameterization``.
+        ReturnsKeys (either):
+            - logit (:obj:`Dict[str, torch.Tensor]`): Reparameterization logit, usually in SAC.
+                - mu (:obj:`torch.Tensor`): Mean of the parameterized Gaussian distribution.
+                - sigma (:obj:`torch.Tensor`): Standard deviation of the parameterized Gaussian distribution.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
+            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+            - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+            - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
+            - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
+                ``action_shape.action_type_shape``.
+            - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
+                ``action_shape.action_args_shape``.
+        Examples:
+            >>> model = QACEnsemble(64, 64)
+            >>> obs = torch.randn(4, 64)
+            >>> actor_outputs = model(obs, 'compute_actor')
+            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
+            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
+        """
+        x = self.actor(obs)
+        return {'logit': [x['mu'], x['sigma']]}
+
+    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Overview:
+            The forward computation graph of the compute_critic mode, which uses the observation and action \
+            tensors to produce critic outputs such as ``q_value``.
+        Arguments:
+            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
+                ``action`` tensor.
+        Returns:
+            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
+        ArgumentsKeys:
+            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
+            - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with the same size as ``action_shape``.
+        ReturnKeys:
+            - q_value (:obj:`torch.Tensor`): Q-value tensor with the same size as the batch size.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(ensemble\_num, B, N1)`, where B is batch \
+                size and N1 is ``obs_shape``.
+            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(ensemble\_num, B, N2)`, where B is batch \
+                size and N2 is ``action_shape``.
+            - q_value (:obj:`torch.Tensor`): :math:`(ensemble\_num, B)`, where B is batch size.
+        Examples:
+            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
+            >>> model = QACEnsemble(obs_shape=(8, ), action_shape=1)
+            >>> model(inputs, mode='compute_critic')['q_value']  # q value, shape (ensemble_num, B)
+            ... tensor([[0.0773, 0.1639, 0.0917, 0.0370],
+            ...         [0.0241, 0.0462, 0.0778, 0.0191]], grad_fn=<PermuteBackward0>)
+        """
+
+        obs, action = inputs['obs'], inputs['action']
+        if len(action.shape) == 1:  # (B, ) -> (B, 1)
+            action = action.unsqueeze(1)
+        x = torch.cat([obs, action], dim=-1)
+        if len(obs.shape) < 3:
+            # (batch_size, dim) -> (batch_size, ensemble_num * dim, 1)
+            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
+        else:
+            # (ensemble_num, batch_size, dim) -> (batch_size, ensemble_num, dim) -> (batch_size, ensemble_num * dim, 1)
+            x = x.transpose(0, 1)
+            batch_size = obs.shape[1]
+            x = x.reshape(batch_size, -1, 1)
+        # pred: (batch_size, ensemble_num)
+        x = self.critic(x)['pred']
+        # (batch_size, ensemble_num) -> (ensemble_num, batch_size)
+        x = x.permute(1, 0)
+        return {'q_value': x}
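A quick shape walk-through of the model above, as a minimal sketch (the HalfCheetah-sized dimensions are illustrative, not taken from the commit):

    import torch
    from ding.model import QACEnsemble

    model = QACEnsemble(obs_shape=17, action_shape=6, ensemble_num=10)
    obs = torch.randn(4, 17)
    # Actor: one shared policy head; returns [mu, sigma], each of shape (4, 6).
    mu, sigma = model(obs, mode='compute_actor')['logit']
    # Critic: obs-action pairs are tiled across the ensemble inside compute_critic,
    # so a (4, 17) obs and (4, 6) action yield one Q-value per member and sample.
    q = model({'obs': obs, 'action': torch.randn(4, 6)}, mode='compute_critic')['q_value']
    assert q.shape == (10, 4)  # (ensemble_num, batch_size)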

ding/model/template/qac.py

File mode changed: 100644 → 100755

ding/policy/__init__.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
 from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
 from .sac import SACPolicy, SACDiscretePolicy, SQILSACPolicy
 from .cql import CQLPolicy, CQLDiscretePolicy
+from .edac import EDACPolicy
 from .impala import IMPALAPolicy
 from .ngu import NGUPolicy
 from .r2d2 import R2D2Policy

ding/policy/command_mode_policy_instance.py

File mode changed: 100644 → 100755
Lines changed: 6 additions & 0 deletions

@@ -47,6 +47,7 @@
 from .sac import SQILSACPolicy
 from .madqn import MADQNPolicy
 from .bdq import BDQPolicy
+from .edac import EDACPolicy


 class EpsCommandModePolicy(CommandModePolicy):
@@ -381,6 +382,11 @@ class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
     pass


+@POLICY_REGISTRY.register('edac_command')
+class EDACCommandModelPolicy(EDACPolicy, DummyCommandModePolicy):
+    pass
+
+
 @POLICY_REGISTRY.register('bc_command')
 class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
