diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py
old mode 100644
new mode 100755
index 11cccf0e13..bd4b6baa09
--- a/ding/entry/__init__.py
+++ b/ding/entry/__init__.py
@@ -26,3 +26,4 @@ from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
 from .serial_entry_bco import serial_pipeline_bco
 from .serial_entry_pc import serial_pipeline_pc
+from .serial_entry_meta_offline import serial_pipeline_meta_offline
diff --git a/ding/entry/serial_entry_meta_offline.py b/ding/entry/serial_entry_meta_offline.py
new file mode 100755
index 0000000000..46a59e0845
--- /dev/null
+++ b/ding/entry/serial_entry_meta_offline.py
@@ -0,0 +1,121 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialMetaEvaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed, get_world_size, get_rank
+from ding.utils.data import create_dataset
+
+def serial_pipeline_meta_offline(
+        input_cfg: Union[str, Tuple[dict, dict]],
+        seed: int = 0,
+        env_setting: Optional[List[Any]] = None,
+        model: Optional[torch.nn.Module] = None,
+        max_train_iter: Optional[int] = int(1e10),
+) -> 'Policy':  # noqa
+    """
+    Overview:
+        Serial pipeline entry for meta offline RL. In this pipeline, the policy is trained on multiple tasks \
+        and evaluated on the specified tasks; the evaluation value is the mean over all evaluated tasks.
+    Arguments:
+        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+            ``str`` type means config file path. \
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed.
+        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+            ``BaseEnv`` subclass, collector env config, and evaluator env config.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+    Returns:
+        - policy (:obj:`Policy`): Converged policy.
+        - stop (:obj:`bool`): Whether the evaluation stop condition was reached.
+    """
+    if isinstance(input_cfg, str):
+        cfg, create_cfg = read_config(input_cfg)
+    else:
+        cfg, create_cfg = deepcopy(input_cfg)
+    create_cfg.policy.type = create_cfg.policy.type + '_command'
+    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+    cfg.env['seed'] = seed
+
+    # Dataset
+    dataset = create_dataset(cfg)
+
+    sampler, shuffle = None, True
+    if get_world_size() > 1:
+        sampler, shuffle = DistributedSampler(dataset), False
+    dataloader = DataLoader(
+        dataset,
+        # Dividing by get_world_size() here simply makes the multi-GPU
+        # setting mathematically equivalent to the single-GPU setting.
+        # If training efficiency is the bottleneck, feel free to
+        # use the original batch size per GPU and increase the learning rate
+        # correspondingly.
+ cfg.policy.learn.batch_size // get_world_size(), + shuffle=shuffle, + sampler=sampler, + collate_fn=lambda x: x, + pin_memory=cfg.policy.cuda, + ) + + # Env, policy + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) + + if hasattr(policy, 'set_statistic'): + # useful for setting action bounds for ibc + policy.set_statistic(dataset.statistics) + + if cfg.policy.need_init_dataprocess: + policy.init_dataprocess_func(dataset) + + if get_rank() == 0: + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + else: + tb_logger = None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + evaluator = InteractionSerialMetaEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator.init_params(dataset.params) + + learner.call_hook('before_run') + stop = False + + for epoch in range(cfg.policy.learn.train_epoch): + if get_world_size() > 1: + dataloader.sampler.set_epoch(epoch) + # for every train task, train policy with its dataset + for i in range(cfg.policy.train_num): + dataset.set_task_id(i) + for train_data in dataloader: + learner.train(train_data) + + # Evaluate policy at most once per epoch. + if evaluator.should_eval(learner.train_iter): + if hasattr(policy, 'warm_train'): + # if algorithm need warm train + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, + policy_warm_func=policy.warm_train, need_reward=cfg.policy.need_reward) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, + need_reward=cfg.policy.need_reward) + + if stop or learner.train_iter >= max_train_iter: + stop = True + break + + learner.call_hook('after_run') + print('final reward is: {}'.format(reward)) + return policy, stop \ No newline at end of file diff --git a/ding/envs/env_manager/__init__.py b/ding/envs/env_manager/__init__.py old mode 100644 new mode 100755 index 62d45baf27..415f63bc3d --- a/ding/envs/env_manager/__init__.py +++ b/ding/envs/env_manager/__init__.py @@ -1,5 +1,6 @@ from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls -from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2 +from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2,\ + MetaSyncSubprocessEnvManager from .gym_vector_env_manager import GymVectorEnvManager # Do not import PoolEnvManager here, because it depends on installation of `envpool` from .env_supervisor import EnvSupervisor diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py old mode 100644 new mode 100755 index 5a391f3932..2bf2219bad --- a/ding/envs/env_manager/subprocess_env_manager.py +++ b/ding/envs/env_manager/subprocess_env_manager.py @@ -832,3 +832,23 @@ def step(self, actions: Union[List[tnp.ndarray], tnp.ndarray]) -> List[tnp.ndarr info = remove_illegal_item(info) new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) return new_data + +@ENV_MANAGER_REGISTRY.register('meta_subprocess') +class MetaSyncSubprocessEnvManager(SyncSubprocessEnvManager): + 
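+    # MetaSyncSubprocessEnvManager extends SyncSubprocessEnvManager with two extra
+    # remote methods, ``set_all_goals`` and ``reset_task``, which broadcast meta-task
+    # goals to all subprocess environments and reset them to a given task id.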
+ @property + def method_name_list(self) -> list: + return [ + 'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure', + 'set_all_goals', 'reset_task' + ] + + def set_all_goals(self, params): + for p in self._pipe_parents.values(): + p.send(['set_all_goals', [params], {}]) + data = {i: p.recv() for i, p in self._pipe_parents.items()} + + def reset_task(self, id): + for p in self._pipe_parents.values(): + p.send(['reset_task', [id], {}]) + data = {i: p.recv() for i, p in self._pipe_parents.items()} diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py old mode 100644 new mode 100755 index 3d35497383..d1cb9133a1 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ding.utils import SequenceType +from ding.utils import SequenceType, MODEL_REGISTRY class MaskedCausalAttention(nn.Module): @@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # x = x + self.mlp(self.ln2(x)) return x - +@MODEL_REGISTRY.register('dt') class DecisionTransformer(nn.Module): """ Overview: @@ -176,7 +176,8 @@ def __init__( drop_p: float, max_timestep: int = 4096, state_encoder: Optional[nn.Module] = None, - continuous: bool = False + continuous: bool = False, + use_prompt: bool = False, ): """ Overview: @@ -206,6 +207,9 @@ def __init__( # projection heads (project to embedding) self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) + if use_prompt: + self.prompt_embed_timestep = nn.Embedding(max_timestep, h_dim) + input_seq_len *= 2 self.drop = nn.Dropout(drop_p) self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) @@ -218,14 +222,21 @@ def __init__( self.embed_state = torch.nn.Linear(state_dim, h_dim) self.predict_rtg = torch.nn.Linear(h_dim, 1) self.predict_state = torch.nn.Linear(h_dim, state_dim) + if use_prompt: + self.prompt_embed_state = torch.nn.Linear(state_dim, h_dim) + self.prompt_embed_rtg = torch.nn.Linear(1, h_dim) if continuous: # continuous actions self.embed_action = torch.nn.Linear(act_dim, h_dim) use_action_tanh = True # True for continuous actions + if use_prompt: + self.prompt_embed_action = torch.nn.Linear(act_dim, h_dim) else: # discrete actions self.embed_action = torch.nn.Embedding(act_dim, h_dim) use_action_tanh = False # False for discrete actions + if use_prompt: + self.prompt_embed_action = torch.nn.Embedding(act_dim, h_dim) self.predict_action = nn.Sequential( *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) ) @@ -243,7 +254,8 @@ def forward( states: torch.Tensor, actions: torch.Tensor, returns_to_go: torch.Tensor, - tar: Optional[int] = None + tar: Optional[int] = None, + prompt: dict = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Overview: @@ -299,7 +311,30 @@ def forward( t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) h = self.embed_ln(t_p) + + if prompt is not None: + prompt_states, prompt_actions, prompt_returns_to_go,\ + prompt_timesteps, prompt_attention_mask = prompt + prompt_seq_length = prompt_states.shape[1] + prompt_state_embeddings = self.prompt_embed_state(prompt_states) + prompt_action_embeddings = self.prompt_embed_action(prompt_actions) + if prompt_returns_to_go.shape[1] % 10 == 1: + prompt_returns_embeddings = 
self.prompt_embed_rtg(prompt_returns_to_go[:,:-1]) + else: + prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go) + prompt_time_embeddings = self.prompt_embed_timestep(prompt_timesteps) + + prompt_state_embeddings = prompt_state_embeddings + prompt_time_embeddings + prompt_action_embeddings = prompt_action_embeddings + prompt_time_embeddings + prompt_returns_embeddings = prompt_returns_embeddings + prompt_time_embeddings + prompt_stacked_inputs = torch.stack( + (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 + ).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) + + h = torch.cat((prompt_stacked_inputs, h), dim=1) + # transformer and prediction + h = self.transformer(h) # get h reshaped such that its size = (B x 3 x T x h_dim) and # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t @@ -308,11 +343,15 @@ def forward( # that is, for each timestep (t) we have 3 output embeddings from the transformer, # each conditioned on all previous timesteps plus # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. - h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + if prompt is None: + h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + else: + h = h.reshape(B, -1, 3, self.h_dim).permute(0, 2, 1, 3) + + return_preds = self.predict_rtg(h[:, 2])[:, -T:, :] # predict next rtg given r, s, a + state_preds = self.predict_state(h[:, 2])[:, -T:, :] # predict next state given r, s, a + action_preds = self.predict_action(h[:, 1])[:, -T:, :] # predict action given r, s - return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a - action_preds = self.predict_action(h[:, 1]) # predict action given r, s else: state_embeddings = self.state_encoder( states.reshape(-1, *self.state_dim).type(torch.float32).contiguous() diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index f8b48f3061..0d06f94796 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -26,9 +26,20 @@ def default_sample_fn(model, x, cond, t): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, values -def get_guide_output(guide, x, cond, t): +def get_guide_output(guide, x, cond, t, returns=None, is_dynamic=False, act_dim=6): x.requires_grad_() - y = guide(x, cond, t).squeeze(dim=-1) + if returns is not None: + if not is_dynamic: + y = guide(x, cond, t, returns).squeeze(dim=-1) + else: + returns = returns.unsqueeze(1).repeat_interleave(x.shape[1],dim=1) + input = torch.cat([x, returns], dim=-1) + input = input.reshape(-1, input.shape[-1]) + y = guide(input) + y = y.reshape(x.shape[0], x.shape[1], -1) + y = F.mse_loss(x[:, 1:, act_dim:], y[:, :-1], reduction='none') + else: + y = guide(x, cond, t).squeeze(dim=-1) grad = torch.autograd.grad([y.sum()], [x])[0] x.detach() return y, grad @@ -45,6 +56,15 @@ def n_step_guided_p_sample( n_guide_steps=1, scale_grad_by_std=True, ): + """ + Overview: + Guidance fn for Diffusion + Arguments: + - model (obj: 'class') diffusion model + - x (obj: 'tensor') input for guidance + - cond (obj: 'tensor') cond of input + - guide (obj: 'class') guide function + """ model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) model_std = torch.exp(0.5 * model_log_variance) model_var = torch.exp(model_log_variance) @@ -69,6 +89,64 @@ def n_step_guided_p_sample( return model_mean + model_std * 
noise, y +def free_guidance_sample( + model, + x, + cond, + t, + guide1, + guide2, + returns=None, + scale=1, + t_stopgrad=0, + n_guide_steps=1, + scale_grad_by_std=True, + +): + """ + Overview: + Guidance fn for MetaDiffusion + Arguments: + - model (obj: 'class') diffusion model + - x (obj: 'tensor') input for guidance + - cond (obj: 'tensor') cond of input + - guide1 (obj: 'class') guide function. In MetaDiffusion is reward function + - guide2 (obj: 'class') guide function. In MetaDiffusion is dynamic function + - returns (obj: 'tensor') for MetaDiffusion, it is id for task. + + """ + weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) + model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) + model_std = torch.exp(0.5 * model_log_variance) + model_var = torch.exp(model_log_variance) + + for _ in range(n_guide_steps): + with torch.enable_grad(): + y1, grad1 = get_guide_output(guide1, x, cond, t, returns) # get reward + y2, grad2 = get_guide_output(guide2, x, cond, t, returns, is_dynamic=True, + act_dim=model.action_dim) # get state + grad = grad1 + scale * grad2 + + if scale_grad_by_std: + grad = model_var * grad + + grad[t < t_stopgrad] = 0 + + if model.returns_condition: + # epsilon could be epsilon or x0 itself + epsilon_cond = model.model(x, cond, t, returns, use_dropout=False) + epsilon_uncond = model.model(x, cond, t, returns, force_dropout=True) + epsilon = epsilon_uncond + model.condition_guidance_w * (epsilon_cond - epsilon_uncond) + else: + epsilon = model.model(x, cond, t) + epsilon -= weight * grad + + model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t, epsilon=epsilon) + # model_std = torch.exp(0.5 * model_log_variance) + noise = torch.randn_like(x) + noise[t == 0] = 0 + + return model_mean + model_std * noise, y1 class GaussianDiffusion(nn.Module): """ @@ -299,23 +377,27 @@ class ValueDiffusion(GaussianDiffusion): Gaussian diffusion model for value function. """ - def p_losses(self, x_start, cond, target, t): + def p_losses(self, x_start, cond, target, t, returns=None): noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) - pred = self.model(x_noisy, cond, t) + pred = self.model(x_noisy, cond, t, returns) loss = F.mse_loss(pred, target, reduction='none').mean() + with torch.no_grad(): + r0_loss = F.mse_loss(pred[:, 0], target[:,0]) log = { 'mean_pred': pred.mean().item(), 'max_pred': pred.max().item(), 'min_pred': pred.min().item(), + 'r0_loss': r0_loss.mean().item(), } + return loss, log - def forward(self, x, cond, t): - return self.model(x, cond, t) + def forward(self, x, cond, t, returns=None): + return self.model(x, cond, t, returns) @MODEL_REGISTRY.register('pd') @@ -567,7 +649,7 @@ def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusio batch_size = shape[0] x = 0.5 * torch.randn(shape, device=device) # In this model, init state must be given by the env and without noise. 
- x = apply_conditioning(x, cond, 0) + x = apply_conditioning(x, cond, self.action_dim) if return_diffusion: diffusion = [x] @@ -575,7 +657,7 @@ def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusio for i in reversed(range(0, self.n_timesteps)): timesteps = torch.full((batch_size, ), i, device=device, dtype=torch.long) x = self.p_sample(x, cond, timesteps, returns) - x = apply_conditioning(x, cond, 0) + x = apply_conditioning(x, cond, self.action_dim) if return_diffusion: diffusion.append(x) @@ -623,12 +705,12 @@ def p_losses(self, x_start, cond, t, returns=None): noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_noisy = apply_conditioning(x_noisy, cond, 0) + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) x_recon = self.model(x_noisy, cond, t, returns) if not self.predict_epsilon: - x_recon = apply_conditioning(x_recon, cond, 0) + x_recon = apply_conditioning(x_recon, cond, self.action_dim) assert noise.shape == x_recon.shape @@ -643,3 +725,235 @@ def p_losses(self, x_start, cond, t, returns=None): def forward(self, cond, *args, **kwargs): return self.conditional_sample(cond=cond, *args, **kwargs) + +class GuidenceFreeDifffuser(GaussianDiffusion): + """ + Overview: + Gaussian diffusion model with guidence + Arguments: + - model (:obj:`str`): type of model + - model_cfg (:obj:'dict') config of model + - horizon (:obj:`int`): horizon of trajectory + - obs_dim (:obj:`int`): Dim of the ovservation + - action_dim (:obj:`int`): Dim of the ation + - n_timesteps (:obj:`int`): Number of timesteps + - predict_epsilon (:obj:'bool'): Whether predict epsilon + - loss_discount (:obj:'float'): discount of loss + - clip_denoised (:obj:'bool'): Whether use clip_denoised + - action_weight (:obj:'float'): weight of action + - loss_weights (:obj:'dict'): weight of loss + - returns_condition (:obj:'bool') whether use additional condition + - condition_guidance_w (:obj:'float') guidance weight + """ + + def __init__( + self, + model: str, + model_cfg: dict, + horizon: int, + obs_dim: Union[int, SequenceType], + action_dim: Union[int, SequenceType], + n_timesteps: int = 1000, + predict_epsilon: bool = True, + loss_discount: float = 1.0, + clip_denoised: bool = False, + action_weight: float = 1.0, + loss_weights: dict = None, + returns_condition: bool = False, + condition_guidance_w: float = 0.1, + ): + super().__init__(model, model_cfg, horizon, obs_dim, action_dim, n_timesteps, predict_epsilon, + loss_discount, clip_denoised, action_weight, loss_weights,) + self.returns_condition = returns_condition + self.condition_guidance_w = condition_guidance_w + + def p_mean_variance(self, x, cond, t, epsilon): + x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon) + + if self.clip_denoised: + x_recon.clamp_(-1., 1.) 
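+        # NOTE: the ``assert RuntimeError()`` in the branch below always passes (the exception
+        # instance is truthy), so no clipping is applied when ``clip_denoised`` is False.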
+        else:
+            assert RuntimeError()
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    def p_sample_loop(self, shape, cond, sample_fn=None, plan_size=1, **sample_kwargs):
+        device = self.betas.device
+
+        batch_size = shape[0]
+        x = torch.randn(shape, device=device)
+        x = apply_conditioning(x, cond, self.action_dim)
+
+        assert sample_fn is not None
+        for i in reversed(range(0, self.n_timesteps)):
+            t = torch.full((batch_size, ), i, device=device, dtype=torch.long)
+            x, values = sample_fn(self, x, cond, t, **sample_kwargs)
+            x = apply_conditioning(x, cond, self.action_dim)
+
+        values = values.reshape(-1, plan_size, *values.shape[1:])
+        x = x.reshape(-1, plan_size, *x.shape[1:])
+        if plan_size > 1:
+            inds = torch.argsort(values, dim=1, descending=True)
+            inds = inds.unsqueeze(-1).expand_as(x)
+            x = x.gather(1, inds)
+        return x[:, 0]
+
+    @torch.no_grad()
+    def conditional_sample(self, cond, horizon=None, **sample_kwargs):
+        device = self.betas.device
+        batch_size = len(cond[0])
+        horizon = horizon or self.horizon
+        shape = (batch_size, horizon, self.obs_dim + self.action_dim)
+        return self.p_sample_loop(shape, cond, **sample_kwargs)
+
+    def p_losses(self, x_start, cond, t, returns=None):
+        noise = torch.randn_like(x_start)
+
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
+
+        x_recon = self.model(x_noisy, cond, t, returns)
+
+        if not self.predict_epsilon:
+            x_recon = apply_conditioning(x_recon, cond, self.action_dim)
+
+        assert noise.shape == x_recon.shape
+
+        if self.predict_epsilon:
+            loss = F.mse_loss(x_recon, noise, reduction='none')
+            a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean()
+            loss = (loss * self.loss_weights.to(loss.device)).mean()
+        else:
+            loss = F.mse_loss(x_recon, x_start, reduction='none')
+            a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean()
+            loss = (loss * self.loss_weights.to(loss.device)).mean()
+        return loss, a0_loss
+
+
+@MODEL_REGISTRY.register('metadiffuser')
+class MetaDiffuser(nn.Module):
+    """
+    Overview:
+        MetaDiffuser model.
+    Arguments:
+        - dim (:obj:`int`): dim of the task embedding and dynamic model
+        - obs_dim (:obj:`int`): dim of the observation
+        - action_dim (:obj:`int`): dim of the action
+        - reward_cfg (:obj:`dict`): config of the reward model
+        - diffuser_model_cfg (:obj:`dict`): config of the diffuser model
+        - horizon (:obj:`int`): horizon of the trajectory
+        - encoder_horizon (:obj:`int`): horizon (context length) used by the task-embedding model
+        - sample_kwargs: config of the sample function
+    """
+    def __init__(
+        self,
+        dim: int,
+        obs_dim: Union[int, SequenceType],
+        action_dim: Union[int, SequenceType],
+        reward_cfg: dict,
+        diffuser_model_cfg: dict,
+        horizon: int,
+        encoder_horizon: int,
+        **sample_kwargs,
+    ):
+        super().__init__()
+
+        self.obs_dim = obs_dim
+        self.action_dim = action_dim
+        self.horizon = horizon
+        self.sample_kwargs = sample_kwargs
+        self.encoder_horizon = encoder_horizon
+
+        self.embed = nn.Sequential(
+            nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4),
+            nn.Mish(),
+            nn.Linear(dim * 4, dim * 4),
+            nn.Mish(),
+            nn.Linear(dim * 4, dim * 4),
+            nn.Mish(),
+            nn.Linear(dim * 4, dim)
+        )
+
+        self.reward_model = ValueDiffusion(**reward_cfg)
+
+        self.dynamic_model = nn.Sequential(
+            nn.Linear(obs_dim + action_dim + dim, 200),
+            nn.ReLU(),
+            nn.Linear(200, 200),
+            nn.ReLU(),
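+            # final layer: maps the hidden features of the (action, observation, task embedding)
+            # input to a next-observation prediction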
+            nn.Linear(200, obs_dim),
+        )
+
+        self.diffuser = GuidenceFreeDifffuser(**diffuser_model_cfg)
+
+    def get_task_id(self, traj):
+        """
+        Overview:
+            Get the task embedding for a trajectory.
+        Arguments:
+            - traj (:obj:`tensor`): trajectory collected from the env
+        """
+        input_emb = traj.reshape(traj.shape[0], -1)
+        task_idx = self.embed(input_emb)
+        return task_idx
+
+    def diffuser_loss(self, x_start, cond, t, returns=None):
+        return self.diffuser.p_losses(x_start, cond, t, returns)
+
+    def pre_train_loss(self, traj, target, t, cond):
+        """
+        Overview:
+            Train the dynamic, reward and embedding models.
+        Arguments:
+            - traj (:obj:`tensor`): trajectory batch from the dataset, including obs, next_obs, action and reward
+            - target (:obj:`tensor`): target obs and reward
+            - t (:obj:`int`): diffusion timestep
+            - cond (:obj:`tensor`): condition of the input
+        """
+        encoder_traj = traj[:, :self.encoder_horizon]
+        input_emb = encoder_traj.reshape(target.shape[0], -1)
+        task_idx = self.embed(input_emb)
+
+        states = traj[:, :, self.action_dim:self.action_dim + self.obs_dim]
+        actions = traj[:, :, :self.action_dim]
+        input = torch.cat([actions, states], dim=-1)
+        target_reward = target[:, :, -1]
+
+        target_next_state = target[:, :, :self.obs_dim].reshape(-1, self.obs_dim)
+
+        reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idx)
+
+        task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.horizon, dim=1)
+
+        input = torch.cat([input, task_idxs], dim=-1)
+        input = input.reshape(-1, input.shape[-1])
+
+        next_state = self.dynamic_model(input)
+        state_loss = F.mse_loss(next_state, target_next_state, reduction='none').mean()
+
+        return state_loss, reward_loss, reward_log
+
+    def get_eval(self, cond, id=None, batch_size=1):
+        """
+        Overview:
+            Get the action for evaluation by sampling a plan from the diffuser.
+        Arguments:
+            - cond (:obj:`tensor`): condition for sampling
+            - id (:obj:`tensor`): task embedding (task id) of the task
+        """
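+        # ``id`` is expected to be a list of per-env task embeddings (produced by the policy's
+        # ``warm_train`` via ``get_task_id``); stack them into a batch before planning.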
+ """ + id = torch.stack(id, dim=0) + if batch_size > 1: + cond = self.repeat_cond(cond, batch_size) + id = id.unsqueeze(1).repeat_interleave(batch_size, dim=1) + id = id.reshape(-1, id.shape[-1]) + + samples = self.diffuser(cond, returns=id, sample_fn=free_guidance_sample, plan_size=batch_size, + guide1=self.reward_model, guide2=self.dynamic_model, **self.sample_kwargs) + return samples[:, 0, :self.action_dim] + + def repeat_cond(self, cond, batch_size): + for k, v in cond.items(): + cond[k] = v.repeat_interleave(batch_size, dim=0) + return cond diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2d5e3271dd..305c757901 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -51,6 +51,8 @@ from .edac import EDACPolicy from .prompt_pg import PromptPGPolicy from .plan_diffuser import PDPolicy +from .meta_diffuser import MDPolicy +from .prompt_dt import PDTPolicy class EpsCommandModePolicy(CommandModePolicy): @@ -449,3 +451,11 @@ def _get_setting_eval(self, command_info: dict) -> dict: @POLICY_REGISTRY.register('prompt_pg_command') class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy): pass + +@POLICY_REGISTRY.register('metadiffuser_command') +class MDCommandModePolicy(MDPolicy, DummyCommandModePolicy): + pass + +@POLICY_REGISTRY.register('promptdt_command') +class PDTCommandModePolicy(PDTPolicy, DummyCommandModePolicy): + pass diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py new file mode 100755 index 0000000000..cabeaa2f5d --- /dev/null +++ b/ding/policy/meta_diffuser.py @@ -0,0 +1,446 @@ +from typing import List, Dict, Any, Optional, Tuple, Union +from collections import namedtuple, defaultdict +import copy +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Normal, Independent + +from ding.torch_utils import Adam, to_device +from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ + qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data +from ding.policy import Policy +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, DatasetNormalizer +from ding.utils.data import default_collate, default_decollate +from .common_utils import default_preprocess_learn + +@POLICY_REGISTRY.register('metadiffuser') +class MDPolicy(Policy): + r""" + Overview: + Implicit Meta Diffuser + https://arxiv.org/pdf/2305.19923.pdf + + """ + config = dict( + type='pd', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool type) priority: Determine whether to use priority in buffer sample. + # Default False in SAC. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of training samples(randomly collected) in replay buffer when training starts. + # Default 10000 in SAC. + random_collect_size=10000, + nstep=1, + # normalizer type + normalizer='GaussianNormalizer', + model=dict( + dim=64, + obs_dim=17, + action_dim=6, + diffuser_model_cfg=dict( + # the type of model + # config of model + model_cfg=dict( + # model dim, In GaussianInvDynDiffusion, it is obs_dim. 
In others, it is obs_dim + action_dim + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + # whether use return as a condition + returns_condition=True, + condition_dropout=0.1, + # whether use calc energy + calc_energy=False, + kernel_size=5, + # whether use attention + attention=False, + ), + # horizon of tarjectory which generated by model + horizon=80, + # timesteps of diffusion + n_timesteps=1000, + # hidden dim of action model + # Whether predict epsilon + predict_epsilon=True, + # discount of loss + loss_discount=1.0, + # whether clip denoise + clip_denoised=False, + ), + reward_cfg=dict( + # the type of model + model='TemporalValue', + # config of model + model_cfg=dict( + horizon=4, + # model dim, In GaussianInvDynDiffusion, it is obs_dim. In others, it is obs_dim + action_dim + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + # whether use calc energy + kernel_size=5, + ), + # horizon of tarjectory which generated by model + horizon=80, + # timesteps of diffusion + n_timesteps=1000, + # hidden dim of action model + predict_epsilon=True, + # discount of loss + loss_discount=1.0, + # whether clip denoise + clip_denoised=False, + action_weight=1.0, + ), + horizon=80, + # guide_steps for p sample + n_guide_steps=2, + # scale of grad for p sample + scale=1, + # t of stopgrad for p sample + t_stopgrad=2, + # whether use std as a scale for grad + scale_grad_by_std=True, + ), + learn=dict( + + # How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + update_per_collect=1, + # (int) Minibatch size for gradient descent. + batch_size=100, + + # (float type) learning_rate_q: Learning rate for model. + # Default to 3e-4. + # Please set to 1e-3, when model.value_network is True. + learning_rate=3e-4, + # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + + # (float type) target_theta: Used for soft update of the target network, + # aka. Interpolation factor in polyak averaging for target networks. + # Default to 0.005. + target_theta=0.005, + # (float) discount factor for the discounted sum of rewards, aka. gamma. + discount_factor=0.99, + gradient_accumulate_every=2, + # train_epoch = train_epoch * gradient_accumulate_every + train_epoch=60000, + + # step start update target model and frequence + step_start_update_target=2000, + update_target_freq=10, + # update weight of target net + target_weight=0.995, + value_step=2e3, + + # dataset weight include returns + include_returns=True, + + # (float) Weight uniform initialization range in the last output layer + init_w=3e-3, + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'metadiffuser', ['ding.model.template.diffusion'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. + Init q, value and policy's optimizers, algorithm config, main and target models. 
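+            For this policy, it builds a plan optimizer for the diffuser, a pre-train optimizer for the \
+            reward, embedding and dynamic models, and an EMA-updated target model used for evaluation.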
+ """ + # Init + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self.action_dim = self._cfg.model.diffuser_model_cfg.action_dim + self.obs_dim = self._cfg.model.diffuser_model_cfg.obs_dim + self.n_timesteps = self._cfg.model.diffuser_model_cfg.n_timesteps + self.gradient_accumulate_every = self._cfg.learn.gradient_accumulate_every + self.gradient_steps = 1 + self.update_target_freq = self._cfg.learn.update_target_freq + self.step_start_update_target = self._cfg.learn.step_start_update_target + self.target_weight = self._cfg.learn.target_weight + self.value_step = self._cfg.learn.value_step + self.horizon = self._cfg.model.diffuser_model_cfg.horizon + self.include_returns = self._cfg.learn.include_returns + self.eval_batch_size = self._cfg.learn.eval_batch_size + self.warm_batch_size = self._cfg.learn.warm_batch_size + self.test_num = self._cfg.learn.test_num + self.have_train = False + self._forward_learn_cnt = 0 + self.encoder_len = self._cfg.learn.encoder_len + + self._plan_optimizer = Adam( + self._model.diffuser.model.parameters(), + lr=self._cfg.learn.learning_rate, + ) + + self._pre_train_optimizer = Adam( + list(self._model.reward_model.model.parameters()) + list(self._model.embed.parameters()) \ + + list(self._model.dynamic_model.parameters()), + lr=self._cfg.learn.learning_rate, + ) + + self._gamma = self._cfg.learn.discount_factor + + self._target_model = copy.deepcopy(self._model) + + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + self.have_train = True + loss_dict = {} + + if self._cuda: + data = to_device(data, self._device) + + obs, acts, rewards, cond_ids, cond_vals = [], [], [], [], [] + for d in data: + timesteps, ob, act, reward, rtg, masks, cond_id, cond_val = d + obs.append(ob) + acts.append(act) + rewards.append(reward) + cond_ids.append(cond_id) + cond_vals.append(cond_val) + + obs = torch.stack(obs, dim=0) + acts = torch.stack(acts, dim=0) + rewards = torch.stack(rewards, dim=0) + cond_vals = torch.stack(cond_vals, dim=0) + + obs, next_obs = obs[:,:-1], obs[:,1:] + acts = acts[:,:-1] + rewards = rewards[:,:-1] + conds = {cond_ids[0]: cond_vals} + + + self._learn_model.train() + pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=-1).to(self._device) + target = torch.cat([next_obs, rewards], dim=-1).to(self._device) + traj = torch.cat([acts, obs], dim=-1).to(self._device) + + batch_size = len(traj) + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=traj.device).long() + if self._forward_learn_cnt < self.value_step: + state_loss, reward_loss, reward_log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + loss_dict = {'dynamic_loss': state_loss, 'reward_loss': reward_loss} + loss_dict.update(reward_log) + total_loss = (state_loss + reward_loss) / self.gradient_accumulate_every + total_loss.backward() + + task_id = self._learn_model.get_task_id(pre_traj[:, :self.encoder_len]) + + diffuser_loss, a0_loss = self._learn_model.diffuser_loss(traj, conds, t, task_id) + loss_dict['diffuser_loss'] = diffuser_loss + loss_dict['a0_loss'] = a0_loss + diffuser_loss = diffuser_loss / self.gradient_accumulate_every + diffuser_loss.backward() + loss_dict['max_return'] = reward.max().item() + loss_dict['min_return'] = reward.min().item() + loss_dict['mean_return'] = reward.mean().item() + if self.gradient_steps >= self.gradient_accumulate_every: + self._plan_optimizer.step() + 
self._plan_optimizer.zero_grad() + if self._forward_learn_cnt < self.value_step: + self._pre_train_optimizer.step() + self._pre_train_optimizer.zero_grad() + self.gradient_steps = 1 + else: + self.gradient_steps += 1 + + self._forward_learn_cnt += 1 + if self._forward_learn_cnt % self.update_target_freq == 0: + if self._forward_learn_cnt < self.step_start_update_target: + self._target_model.load_state_dict(self._model.state_dict()) + else: + self.update_model_average(self._target_model, self._learn_model) + + return loss_dict + + + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + if old_weight is None: + ma_params.data = up_weight + else: + old_weight * self.target_weight + (1 - self.target_weight) * up_weight + + def init_dataprocess_func(self, dataloader: torch.utils.data.Dataset): + self.dataloader = dataloader + + def _monitor_vars_learn(self) -> List[str]: + return [ + 'diffuser_loss', + 'reward_loss', + 'dynamic_loss', + 'a0_loss', + 'max_return', + 'min_return', + 'mean_return', + 'mean_pred', + 'max_pred', + 'min_pred', + 'r0_loss', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'plan_optimizer': self._plan_optimizer.state_dict(), + 'pre_train_optimizer': self._pre_train_optimizer.state_dict(), + } + + def _init_eval(self): + self._eval_model = model_wrap(self._target_model, wrapper_name='base') + self._eval_model.reset() + self.task_id = None + self.test_task_id = [[] for _ in range(self.eval_batch_size)] + # self.task_id = [0] * self.eval_batch_size + + + # obs, acts, rewards, cond_ids, cond_vals = \ + # self.dataloader.get_pretrain_data(self.task_id[0], self.warm_batch_size * self.eval_batch_size) + # obs = to_device(obs, self._device) + # acts = to_device(acts, self._device) + # rewards = to_device(rewards, self._device) + # cond_vals = to_device(cond_vals, self._device) + + # obs, next_obs = obs[:-1], obs[1:] + # acts = acts[:-1] + # rewards = rewards[:-1] + # pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) + # target = torch.cat([next_obs, rewards], dim=1) + # batch_size = len(pre_traj) + # conds = {cond_ids: cond_vals} + + # t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + # state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + # total_loss = state_loss + reward_loss + # self._pre_train_optimizer.zero() + # total_loss.backward() + # self._pre_train_optimizer.step() + # self.update_model_average(self._target_model, self._learn_model) + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + + self._eval_model.eval() + obs = [] + for i in range(len(data)): + if not self._cfg.no_state_normalize: + obs.append(self.dataloader.normalize(data[i], 'obs', self.task_id[i])) + + with torch.no_grad(): + obs = torch.stack(obs, dim=0) + if self._cuda: + obs = to_device(obs, self._device) + conditions = {0: obs} + action = self._eval_model.get_eval(conditions, self.test_task_id[:len(data)], self._cfg.learn.plan_batch_size) + if self._cuda: + action = to_device(action, 'cpu') + for i in range(len(data)): + if not self._cfg.no_action_normalize: + action[i] = self.dataloader.unnormalize(action[i], 'actions', self.task_id[i]) + action = 
torch.tensor(action).to('cpu') + output = {'action': action} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def warm_train(self, id: int): + self.task_id = [id] * self.eval_batch_size + obs, acts, rewards, cond_ids, cond_vals = \ + self.dataloader.get_pretrain_data(id, self.warm_batch_size) + obs = to_device(obs, self._device) + acts = to_device(acts, self._device) + rewards = to_device(rewards, self._device) + cond_vals = to_device(cond_vals, self._device) + + obs, next_obs = obs[:, :-1], obs[:, 1:] + acts = acts[:, :-1] + rewards = rewards[:, :-1] + + pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) + target = torch.cat([next_obs, rewards], dim=-1) + batch_size = len(pre_traj) + conds = {cond_ids: cond_vals} + + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + total_loss = state_loss + reward_loss + self._pre_train_optimizer.zero_grad() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + + self.test_task_id = [self._target_model.get_task_id(pre_traj[:, :self.encoder_len])[0]] * self.eval_batch_size + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + if self.have_train: + if self.task_id is None: + self.task_id = [0] * self.eval_batch_size + # if data_id is None: + # data_id = list(range(self.eval_batch_size)) + # if self.task_id is not None: + # for id in data_id: + # self.task_id[id] = (self.task_id[id] + 1) % self.test_num + # else: + # self.task_id = [0] * self.eval_batch_size + + # for id in data_id: + # obs, acts, rewards, cond_ids, cond_vals = \ + # self.dataloader.get_pretrain_data(self.task_id[id], self.warm_batch_size) + # obs = to_device(obs, self._device) + # acts = to_device(acts, self._device) + # rewards = to_device(rewards, self._device) + # cond_vals = to_device(cond_vals, self._device) + + # obs, next_obs = obs[:, :-1], obs[:, 1:] + # acts = acts[:, :-1] + # rewards = rewards[:, :-1] + + # pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) + # target = torch.cat([next_obs, rewards], dim=-1) + # batch_size = len(pre_traj) + # conds = {cond_ids: cond_vals} + # pre_traj = pre_traj[:, :self.encoder_len] + # target = pre_traj[:, :self.encoder_len] + + # t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + # state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + # total_loss = state_loss + reward_loss + # self._pre_train_optimizer.zero_grad() + # total_loss.backward() + # self._pre_train_optimizer.step() + # self.update_model_average(self._target_model, self._learn_model) + + # self.test_task_id[id] = self._target_model.get_task_id(pre_traj)[0] + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: dict, **kwargs) -> dict: + pass + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + pass + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + pass diff --git a/ding/policy/prompt_dt.py b/ding/policy/prompt_dt.py new file mode 100755 index 0000000000..4306b01b4b --- /dev/null +++ b/ding/policy/prompt_dt.py @@ -0,0 +1,301 @@ +from typing import List, Dict, Any, Tuple, Optional +from collections import namedtuple +import torch.nn.functional as F +import torch +import numpy as np +from ding.torch_utils import to_device +from 
ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_decollate
+from ding.policy.dt import DTPolicy
+from ding.model import model_wrap
+
+@POLICY_REGISTRY.register('promptdt')
+class PDTPolicy(DTPolicy):
+    """
+    Overview:
+        Policy class of the Prompt Decision Transformer (Prompt-DT) algorithm.
+        Paper link: https://arxiv.org/pdf/2206.13499.
+    """
+    def default_model(self) -> Tuple[str, List[str]]:
+        return 'dt', ['ding.model.template.decision_transformer']
+
+    def _init_learn(self) -> None:
+        super()._init_learn()
+        self.need_prompt = self._cfg.need_prompt
+
+    def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
+        """
+        Overview:
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the offline dataset and then returns the output \
+            result, including various training information such as loss and current learning rate.
+        Arguments:
+            - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \
+                processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+                detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+            For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+ + """ + self._learn_model.train() + self.have_train = True + + if self._cuda: + data = to_device(data, self._device) + + p_s, p_a, p_rtg, p_t, p_mask, timesteps, states, actions, rewards, returns_to_go, \ + traj_mask = [], [], [], [], [], [], [], [], [], [], [] + + for d in data: + if self.need_prompt: + p, timestep, s, a, r, rtg, mask = d + timesteps.append(timestep) + states.append(s) + actions.append(a) + rewards.append(r) + returns_to_go.append(rtg) + traj_mask.append(mask) + ps, pa, prtg, pt, pm = p + p_s.append(ps) + p_a.append(pa) + p_rtg.append(prtg) + p_mask.append(pm) + p_t.append(pt) + else: + timestep, s, a, r, rtg, mask = d + timesteps.append(timestep) + states.append(s) + actions.append(a) + rewards.append(r) + returns_to_go.append(rtg) + traj_mask.append(mask) + + timesteps = torch.stack(timesteps, dim=0) + states = torch.stack(states, dim=0) + actions = torch.stack(actions, dim=0) + rewards = torch.stack(rewards, dim=0) + returns_to_go = torch.stack(returns_to_go, dim=0) + traj_mask = torch.stack(traj_mask, dim=0) + if self.need_prompt: + p_s = torch.stack(p_s, dim=0) + p_a = torch.stack(p_a, dim=0) + p_rtg = torch.stack(p_rtg, dim=0) + p_mask = torch.stack(p_mask, dim=0) + p_t = torch.stack(p_t, dim=0) + prompt = (p_s, p_a, p_rtg, p_t, p_mask) + else: + prompt = None + + # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), + # and we need a 3-dim tensor + if len(returns_to_go.shape) == 2: + returns_to_go = returns_to_go.unsqueeze(-1) + + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, prompt=prompt + ) + + traj_mask = traj_mask.view(-1, ) + + # only consider non padded elements + action_preds = action_preds.reshape(-1, self.act_dim)[traj_mask > 0] + + action_target = actions.reshape(-1, self.act_dim)[traj_mask > 0] + action_loss = F.mse_loss(action_preds, action_target) + + self._optimizer.zero_grad() + action_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) + self._optimizer.step() + self._scheduler.step() + + return { + 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], + 'action_loss': action_loss.detach().cpu().item(), + 'total_loss': action_loss.detach().cpu().item(), + } + + def init_dataprocess_func(self, dataloader): + self.dataloader = dataloader + + def _init_eval(self) -> None: + self.test_num = self._cfg.learn.test_num + self._eval_model = self._model + self.eval_batch_size = self._cfg.evaluator_env_num + self.rtg_target = self._cfg.rtg_target + self.task_id = None + self.test_task_id = [[] for _ in range(self.eval_batch_size)] + self.have_train =False + if self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), 
dtype=torch.float32, device=self._device + ) + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + if self.need_prompt: + p_s, p_a, p_rtg, p_t, p_mask = [], [], [], [], [] + for i in range(self.eval_batch_size): + ps, pa, prtg, pt, pm = self.dataloader.get_prompt(is_test=True, id=self.task_id[i]) + p_s.append(ps) + p_a.append(pa) + p_rtg.append(prtg) + p_mask.append(pm) + p_t.append(pt) + p_s = torch.stack(p_s, dim=0).to(self._device) + p_a = torch.stack(p_a, dim=0).to(self._device) + p_rtg = torch.stack(p_rtg, dim=0).to(self._device) + p_mask = torch.stack(p_mask, dim=0).to(self._device) + p_t = torch.stack(p_t, dim=0).to(self._device) + prompt = (p_s, p_a, p_rtg, p_t, p_mask) + else: + prompt = None + + data_id = list(data.keys()) + + self._eval_model.eval() + with torch.no_grad(): + states = torch.zeros( + (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) + if not self._cfg.model.continuous: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device + ) + else: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device + ) + rewards_to_go = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device + ) + for i in data_id: + self.states[i, self.t[i]] = self.dataloader.normalize(data[i]['obs'], 'obs', self.task_id[i]) + self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device) + self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] + + if self.t[i] <= self.context_len: + if self._atari_env: + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) + else: + timesteps[i] = self.timesteps[i, :self.context_len] + states[i] = self.states[i, :self.context_len] + actions[i] = self.actions[i, :self.context_len] + rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] + else: + timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + if self._basic_discrete_env: + actions = actions.squeeze(-1) + + _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go, prompt=prompt) + del timesteps, states, actions, rewards_to_go + + logits = act_preds[:, -1, :] + if not self._cfg.model.continuous: + act = torch.argmax(logits, axis=1).unsqueeze(1) + else: + act = logits + for i in data_id: + self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t + self.t[i] += 1 + + if self._cuda: + act = to_device(act, 'cpu') + output = {'action': act} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def warm_train(self, id: int): + self.task_id = [id] * self.eval_batch_size + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + if self.have_train: + if self.task_id is None: + self.task_id = [0] * self.eval_batch_size + + if data_id is None: + self.t = [0 for _ in range(self.eval_batch_size)] + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 
1).to(self._device) + if not self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), + dtype=torch.float32, + device=self._device + ) + if self._atari_env: + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + else: + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) + else: + for i in data_id: + self.t[i] = 0 + if not self._cfg.model.continuous: + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + else: + self.actions[i] = torch.zeros( + (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + if self._atari_env: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target + else: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target / self.rtg_scale + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index 8dfa9d3a14..c26e3f240d 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -35,7 +35,7 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32): return torch.tensor(betas_clipped, dtype=dtype) -def apply_conditioning(x, conditions, action_dim): +def apply_conditioning(x, conditions, action_dim, mask = None): """ Overview: add condition into x @@ -44,7 +44,6 @@ def apply_conditioning(x, conditions, action_dim): x[:, t, action_dim:] = val.clone() return x - class DiffusionConv1d(nn.Module): def __init__( @@ -132,7 +131,6 @@ def __init__(self, dim, eps=1e-5) -> None: self.b = nn.Parameter(torch.zeros(1, dim, 1)) def forward(self, x): - print('x.shape:', x.shape) var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b @@ -214,7 +212,8 @@ def __init__( if in_channels != out_channels else nn.Identity() def forward(self, x, t): - out = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1) + out = self.blocks[0](x) + out += self.time_mlp(t).unsqueeze(-1) out = self.blocks[1](out) return out + self.residual_conv(x) @@ -225,6 +224,7 @@ def __init__( self, transition_dim: int, dim: int = 32, + returns_dim: int = 1, dim_mults: SequenceType = [1, 2, 4, 8], returns_condition: bool = False, condition_dropout: float = 0.1, @@ -257,7 +257,7 @@ def __init__( act = nn.Mish() self.time_dim = dim - self.returns_dim = dim + self.returns_dim = returns_dim self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), @@ -272,8 +272,6 @@ def 
__init__( if self.returns_condition: self.returns_mlp = nn.Sequential( - nn.Linear(1, dim), - act, nn.Linear(dim, dim * 4), act, nn.Linear(dim * 4, dim), @@ -323,7 +321,7 @@ def __init__( nn.Conv1d(dim, transition_dim, 1), ) - def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False): + def forward(self, x, cond, time, returns = None, use_dropout: bool = True, force_dropout: bool = False): """ Arguments: x (:obj:'tensor'): noise trajectory @@ -382,7 +380,7 @@ def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_d else: return x - def get_pred(self, x, cond, time, returns: bool = None, use_dropout: bool = True, force_dropout: bool = False): + def get_pred(self, x, cond, time, returns = None, use_dropout: bool = True, force_dropout: bool = False): # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) @@ -431,6 +429,7 @@ class TemporalValue(nn.Module): - time_dim (:obj:'): dim of time - dim_mults (:obj:'SequenceType'): mults of dim - kernel_size (:obj:'int'): kernel_size of conv1d + - returns_condition (:obj:'bool'): whether use an additionly condition """ def __init__( @@ -442,6 +441,8 @@ def __init__( out_dim: int = 1, kernel_size: int = 5, dim_mults: SequenceType = [1, 2, 4, 8], + returns_condition: bool = False, + no_need_ret_sin: bool =False, ) -> None: super().__init__() dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] @@ -454,6 +455,21 @@ def __init__( nn.Mish(), nn.Linear(dim * 4, dim), ) + if returns_condition: + time_dim += time_dim + if not no_need_ret_sin: + self.returns_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim), + ) + else: + self.returns_mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim), + ) self.blocks = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): @@ -488,14 +504,23 @@ def __init__( nn.Linear(fc_dim // 2, out_dim), ) - def forward(self, x, cond, time, *args): + def forward(self, x, cond, time, returns=None, *args): # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) + + if returns is not None: + returns_embed = self.returns_mlp(returns) + t = torch.cat([t, returns_embed], dim=-1) + for resnet, resnet2, downsample in self.blocks: + # print('x:',x) x = resnet(x, t) + # print('after res1 x:',x) x = resnet2(x, t) + # print('after res2 x:',x) x = downsample(x) + # print('after down x:',x) x = self.mid_block1(x, t) x = self.mid_down1(x) diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 23db0fcdf9..aa6a16368f 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1101,6 +1101,332 @@ def __getitem__(self, idx, eps=1e-4): batch.update(self.get_conditions(observations)) return batch + +@DATASET_REGISTRY.register('meta_traj') +class MetaTraj(Dataset): + """ + Overview: + Dataset for Meta policy + Arguments: + - cfg (:obj:'dict'): cfg of policy + Key: + - dataset.data_dir_prefix (:obj:'str'): dataset path + - dataset.env_param_path (:obj:'str'): environment params path + - dataset.rtg_scale (:obj:'float'): return to go scale + - dataset.context_len (:obj:'int'): context len + - no_state_normalize (:obj:'bool'): whether normalize state + - no_action_normalize (:obj:'bool'): whether normalize action + - task_num (:obj:'int'): nums of meta tasks + - policy.max_len (:obj:'int'): max len of trajectory + - 
dataset.stochastic_prompt (:obj:'bool'): select max return prompt or random prompt + - dataset.need_prompt (:obj:'bool'): whether need prompt + - dataset.need_prompt (:obj:'list'): id of test evnironment + - dataset.need_next_obs (:obj:'bool'): whether need next_obs, if need, traj len = max_len + 1 + - dataset.cond (:obj:'bool'): whether add condition + Returns: + return trajectory dataset for Meta Policy. + """ + def __init__(self, cfg): + dataset_path = cfg.dataset.data_dir_prefix + env_param_path = cfg.dataset.env_param_path + self.rtg_scale = cfg.dataset.rtg_scale + self.context_len = cfg.dataset.context_len + self.no_state_normalize = cfg.policy.no_state_normalize + self.no_action_normalize = cfg.policy.no_action_normalize + self.task_num = cfg.policy.task_num + self.state_dim = cfg.policy.obs_dim + self.act_dim = cfg.policy.act_dim + self.max_len = cfg.policy.max_len + self.max_ep_len = cfg.policy.max_ep_len + self.batch_size = cfg.policy.learn.batch_size + self.stochastic_prompt = cfg.dataset.stochastic_prompt + self.need_prompt = cfg.dataset.need_prompt + self.test_id = cfg.dataset.test_id + self.need_next_obs = cfg.dataset.need_next_obs + self.cond = None + if 'cond' in cfg.dataset: + self.cond = cfg.dataset.cond + + try: + import h5py + import collections + except ImportError: + import sys + logging.warning("not found h5py package, please install it trough `pip install h5py ") + sys.exit(1) + data_ = collections.defaultdict(list) + + file_paths = [dataset_path + '_{}_sub_task_0.hdf5'.format(i) for i in range(0, self.task_num)] + param_paths = [env_param_path + '{}.pkl'.format(i) for i in self.test_id] + # train_env_dataset + self.traj = [] + self.state_means = [] + self.state_stds = [] + + # test_env_dataset + self.test_traj = [] + self.test_state_means = [] + self.test_state_stds = [] + + # for MetaDiffuser + if not self.no_action_normalize: + self.action_means = [] + self.action_stds = [] + self.test_action_means = [] + self.test_action_stds = [] + + # for prompt-DT + if self.need_prompt: + self.returns = [] + self.test_returns = [] + + id = 0 + for file_path in file_paths: + with h5py.File(file_path, 'r') as hf: + N = hf['rewards'].shape[0] + path = [] + for i in range(N): + for k in ['obs', 'actions', 'rewards', 'terminals', 'mask']: + data_[k].append(hf[k][i]) + path.append(data_) + data_ = collections.defaultdict(list) + + if self.need_prompt: + returns = np.sum(np.array(hf['rewards']), axis=1) + + state_mean, state_std = hf['state_mean'][:], hf['state_std'][:] + if not self.no_action_normalize: + action_mean, action_std = hf['action_mean'][:], hf['action_std'][:] + + # if id not in self.test_id: + # self.traj.append(path) + # self.state_means.append(state_mean) + # self.state_stds.append(state_std) + # if not self.no_action_normalize: + # self.action_means.append(action_mean) + # self.action_stds.append(action_std) + # if self.need_prompt: + # self.returns.append(returns) + # else: + # self.test_traj.append(path) + # self.test_state_means.append(state_mean) + # self.test_state_stds.append(state_std) + # if not self.no_action_normalize: + # self.test_action_means.append(action_mean) + # self.test_action_stds.append(action_std) + # if self.need_prompt: + # self.test_returns.append(returns) + + self.traj.append(path) + self.state_means.append(state_mean) + self.state_stds.append(state_std) + if not self.no_action_normalize: + self.action_means.append(action_mean) + self.action_stds.append(action_std) + if self.need_prompt: + self.returns.append(returns) + + 
self.test_traj.append(path) + self.test_state_means.append(state_mean) + self.test_state_stds.append(state_std) + if not self.no_action_normalize: + self.test_action_means.append(action_mean) + self.test_action_stds.append(action_std) + if self.need_prompt: + self.test_returns.append(returns) + id += 1 + + self.params = [] + for file in param_paths: + with open(file, 'rb') as f: + self.params.append(pickle.load(f)[0]) + + if self.need_prompt: + self.prompt_trajectories = [] + for i in range(len(self.traj)): + idx = np.argsort(self.returns[i]) # lowest to highest + # set 10% highest traj as prompt + idx = idx[-(len(self.traj[i]) // 20) : ] + self.prompt_trajectories.append(np.array(self.traj[i])[idx]) + + self.test_prompt_trajectories = [] + for i in range(len(self.test_traj)): + idx = np.argsort(self.test_returns[i]) + idx = idx[-(len(self.test_traj[i]) // 20) : ] + self.test_prompt_trajectories.append(np.array(self.test_traj[i])[idx]) + + self.set_task_id(0) + + def __len__(self): + return len(self.traj[self.task_id]) + + def get_prompt(self, sample_size=1, is_test=False, id=0): + if not is_test: + batch_inds = np.random.choice( + np.arange(len(self.prompt_trajectories[id])), + size=sample_size, + replace=True, + # p=p_sample, # reweights so we sample according to timesteps + ) + prompt_trajectories = self.prompt_trajectories[id] + sorted_inds = np.argsort(self.returns[id]) + else: + batch_inds = np.random.choice( + np.arange(len(self.test_prompt_trajectories[id])), + size=sample_size, + replace=True, + # p=p_sample, # reweights so we sample according to timesteps + ) + prompt_trajectories = self.test_prompt_trajectories[id] + sorted_inds = np.argsort(self.test_returns[id]) + + if self.stochastic_prompt: + traj = prompt_trajectories[batch_inds[sample_size]][0,0] # random select traj + else: + traj = prompt_trajectories[sorted_inds[-sample_size]][0,0] # select the best traj with highest rewards + # traj = prompt_trajectories[i] + si = max(0, len(traj['rewards'][0]) - self.max_len -1) # select the last traj with length max_len + + # get sequences from dataset + + s = traj['obs'][0][si:si + self.max_len] + a = traj['actions'][0][si:si + self.max_len] + mask = traj['mask'][0][si:si + self.max_len] + + timesteps = np.arange(si, si + np.array(mask).sum()) + rtg = discount_cumsum(traj['rewards'][0][si:], gamma=1.)[:s.shape[0]] + if rtg.shape[0] < s.shape[0]: + rtg = np.concatenate([rtg, np.zeros(((s.shape[0] - rtg.shape[0]), 1))], axis=1) + + # padding and state + reward normalization + # if tlen !=args.K: + # print('tlen not equal to k') + if not self.no_state_normalize: + s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] + rtg = rtg/ self.rtg_scale + + t_len = int(np.array(mask).sum()) + + timesteps = np.concatenate([timesteps, np.zeros((self.max_len - t_len))], axis=0) + + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + rtg = torch.from_numpy(rtg).to(dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.long) + mask = torch.from_numpy(mask).to(dtype=torch.long) + + return s, a, rtg, timesteps, mask + + # set task id + def set_task_id(self, id: int): + self.task_id = id + + def normalize(self, data: np.array, type: str, task_id: int): + if type == 'obs': + return (data - self.test_state_means[task_id]) / self.test_state_stds[task_id] + else: + return (data - self.test_action_means[task_id]) / self.test_action_stds[task_id] + + def unnormalize(self, data: np.array, type: str, task_id: int): + if 
type == 'obs': + return data * self.test_state_stds[task_id] + self.test_state_means[task_id] + else: + return data * self.test_action_stds[task_id] + self.test_action_means[task_id] + + # get warm start data + def get_pretrain_data(self, task_id: int, batch_size: int): + # get warm data + trajs = self.test_traj[task_id] + batch_idx = np.random.choice( + np.arange(len(trajs)), + size=batch_size, + ) + + max_len = self.max_len + if self.need_next_obs: + max_len += 1 + + s, a, r = [], [], [] + + for idx in batch_idx: + traj = trajs[idx] + si = np.random.randint(0, len(traj['obs'][0]) - max_len) + + state = traj['obs'][0][si:si + max_len] + action = traj['actions'][0][si:si + max_len] + state = np.array(state).squeeze() + action = np.array(action).squeeze() + if not self.no_state_normalize: + state = (state - self.test_state_means[task_id]) / self.test_state_stds[task_id] + if not self.no_action_normalize: + action = (action - self.test_action_means[task_id]) / self.test_action_stds[task_id] + s.append(state) + a.append(action) + r.append(traj['rewards'][0][si:si + max_len]) + + s = np.array(s) + a = np.array(a) + r = np.array(r) + + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + r = torch.from_numpy(r).to(dtype=torch.float32) + + cond_id = 0 + cond_val = s[:,0] + return s, a, r, cond_id, cond_val + + def __getitem__(self, index): + traj = self.traj[self.task_id][index] + si = np.random.randint(0, len(traj['rewards'][0]) - self.max_len) + + max_len = self.max_len + if self.need_next_obs: + max_len += 1 + + s = traj['obs'][0][si:si + max_len] + a = traj['actions'][0][si:si + max_len] + r = traj['rewards'][0][si:si + max_len] + mask = np.array(traj['mask'][0][si:si + max_len]) + # mask = np.ones((s.shape[0])) + timesteps = np.arange(si, si + mask.sum()) + rtg = discount_cumsum(traj['rewards'][0][si:], gamma=1.)[:s.shape[0]] / self.rtg_scale + if rtg.shape[0] < s.shape[0]: + rtg = np.concatenate([rtg, np.zeros(((s.shape[0] - rtg.shape[0]), 1))], axis=1) + + + if not self.no_state_normalize: + s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] + if not self.no_action_normalize: + a = (a - self.action_means[self.task_id]) / self.action_stds[self.task_id] + + s = np.array(s) + a = np.array(a) + r = np.array(r) + + tlen = int(mask.sum()) + + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + r = torch.from_numpy(r).to(dtype=torch.float32) + + rtg = rtg / self.rtg_scale + timesteps = np.concatenate([timesteps, np.zeros((max_len - tlen))], axis=0) + + + rtg = torch.from_numpy(rtg).to(dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.long) + mask = torch.from_numpy(mask).to(dtype=torch.long) + + if self.need_prompt: + prompt = self.get_prompt(id=self.task_id) + return prompt, timesteps, s, a, r, rtg, mask + elif self.cond: + cond_id = 0 + cond_val = s[0] + return timesteps, s, a, r, rtg, mask, cond_id, cond_val + else: + return timesteps, s, a, r, rtg, mask def hdf5_save(exp_data, expert_data_path): diff --git a/ding/worker/collector/__init__.py b/ding/worker/collector/__init__.py index 8ccfb17260..1b06f20aa8 100644 --- a/ding/worker/collector/__init__.py +++ b/ding/worker/collector/__init__.py @@ -16,3 +16,4 @@ from .zergling_parallel_collector import ZerglingParallelCollector from .marine_parallel_collector import MarineParallelCollector from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector +from 
.interaction_serial_meta_evaluator import InteractionSerialMetaEvaluator diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py new file mode 100755 index 0000000000..0cc3beb738 --- /dev/null +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -0,0 +1,235 @@ +from typing import Optional, Callable, Tuple, Dict, List +from collections import namedtuple, defaultdict +import numpy as np +import torch + +from ...envs import BaseEnvManager +from ...envs import BaseEnvManager + +from ding.envs import BaseEnvManager +from ding.torch_utils import to_tensor, to_ndarray, to_item +from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY +from ding.utils import get_world_size, get_rank, broadcast_object_list +from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from .interaction_serial_evaluator import InteractionSerialEvaluator + +class InteractionSerialMetaEvaluator(InteractionSerialEvaluator): + """ + Overview: + Interaction serial evaluator class, policy interacts with env. This class evaluator algorithm + with test environment list. + Interfaces: + ``__init__``, reset, reset_policy, reset_env, close, should_eval, eval + Property: + env, policy + """ + config = dict( + # (int) Evaluate every "eval_freq" training iterations. + eval_freq=1000, + render=dict( + # Tensorboard video render is disabled by default. + render_freq=-1, + mode='train_iter', + ), + # (str) File path for visualize environment information. + figure_path=None, + # test env num + test_env_num=10, + ) + + def __init__( + self, + cfg: dict, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', + ) -> None: + super().__init__(cfg, env, policy, tb_logger, exp_name, instance_name) + self.test_env_num = cfg.test_env_num + + def init_params(self, params): + self.params = params + self._env.set_all_goals(params) + + def eval( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + force_render: bool = False, + policy_kwargs: Optional[Dict] = {}, + policy_warm_func: namedtuple = None, + need_reward: bool = False, + ) -> Tuple[bool, Dict[str, List]]: + infos = defaultdict(list) + for i in range(self.test_env_num): + print('-----------------------------start task ', i) + self._env.reset_task(i) + if policy_warm_func is not None: + policy_warm_func(i) + info = self.sub_eval(save_ckpt_fn, train_iter, envstep, n_episode, \ + force_render, policy_kwargs, i, need_reward) + for key, val in info.items(): + if i == 0: + info[key] = [] + infos[key].append(val) + + meta_infos = defaultdict(list) + for key, val in infos.items(): + meta_infos[key] = np.array(val).mean() + + episode_return = meta_infos['reward_mean'] + meta_infos['train_iter'] = train_iter + meta_infos['ckpt_name'] = 'iteration_{}.pth.tar'.format(train_iter) + + self._logger.info(self._logger.get_tabulate_vars_hor(meta_infos)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in meta_infos.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + + if episode_return > self._max_episode_return: + if 
save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_episode_return = episode_return + + stop_flag = episode_return >= self._stop_value and train_iter > 0 + if stop_flag: + self._logger.info( + "[DI-engine serial pipeline] " + "Current episode_return: {:.4f} is greater than stop_value: {}". + format(episode_return, self._stop_value) + ", so your RL agent is converged, you can refer to " + + "'log/evaluator/evaluator_logger.txt' for details." + ) + + return stop_flag, meta_infos + + def sub_eval( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + force_render: bool = False, + policy_kwargs: Optional[Dict] = {}, + task_id: int = 0, + need_reward: bool = False, + ) -> Tuple[bool, Dict[str, List]]: + ''' + Overview: + Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Arguments: + - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. + - train_iter (:obj:`int`): Current training iteration. + - envstep (:obj:`int`): Current env interaction step. + - n_episode (:obj:`int`): Number of evaluation episodes. + Returns: + - stop_flag (:obj:`bool`): Whether this training program can be ended. + - episode_info (:obj:`Dict[str, List]`): Current evaluation episode information. + ''' + # evaluator only work on rank0 + stop_flag = False + if get_rank() == 0: + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + info = {} + eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) + self._env.reset() + self._policy.reset() + + # force_render overwrite frequency constraint + render = force_render or self._should_render(envstep, train_iter) + + rewards = None + + with self._timer: + while not eval_monitor.is_finished(): + obs = self._env.ready_obs + obs = to_tensor(obs, dtype=torch.float32) + + if need_reward: + for id,val in obs.items(): + if rewards is None: + reward = torch.zeros((1)) + else: + reward = torch.tensor(rewards[id], dtype=torch.float32) + obs[id] = {'obs':val, 'reward':reward} + + + # update videos + if render: + eval_monitor.update_video(self._env.ready_imgs) + + if self._policy_cfg.type == 'dreamer_command': + policy_output = self._policy.forward( + obs, **policy_kwargs, reset=self._resets, state=self._states + ) + #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} + self._states = [output['state'] for output in policy_output.values()] + else: + policy_output = self._policy.forward(obs, **policy_kwargs) + actions = {i: a['action'] for i, a in policy_output.items()} + actions = to_ndarray(actions) + timesteps = self._env.step(actions) + timesteps = to_tensor(timesteps, dtype=torch.float32) + rewards = [] + for env_id, t in timesteps.items(): + rewards.append(t.reward) + if t.info.get('abnormal', False): + # If there is an abnormal timestep, reset all the related variables(including this env). + self._policy.reset([env_id]) + continue + if self._policy_cfg.type == 'dreamer_command': + self._resets[env_id] = t.done + if t.done: + # Env reset is done by env_manager automatically. 
+ if 'figure_path' in self._cfg and self._cfg.figure_path is not None: + self._env.enable_save_figure(env_id, self._cfg.figure_path) + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'] + saved_info = {'eval_episode_return': t.info['eval_episode_return']} + if 'episode_info' in t.info: + saved_info.update(t.info['episode_info']) + eval_monitor.update_info(env_id, saved_info) + eval_monitor.update_reward(env_id, reward) + self._logger.info( + "[EVALUATOR]env {} finish task {} episode, final reward: {:.4f}, current episode: {}".format( + env_id, task_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + info = { + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + # 'each_reward': episode_return, + } + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + + if render: + video_title = '{}_{}/'.format(self._instance_name, self._render.mode) + videos = eval_monitor.get_video() + render_iter = envstep if self._render.mode == 'envstep' else train_iter + from ding.utils import fps + self._tb_logger.add_video(video_title, videos, render_iter, fps(self._env)) + + return info diff --git a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py new file mode 100755 index 0000000000..6e0b5a7701 --- /dev/null +++ b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py @@ -0,0 +1,137 @@ +from easydict import EasyDict + +main_config = dict( + exp_name="walker_params_md_seed0", + env=dict( + env_id='walker_params', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + returns_scale=1.0, + termination_penalty=-100, + max_path_length=1000, + use_padding=True, + include_returns=True, + normed=False, + stop_value=8000, + horizon=32, + obs_dim=17, + action_dim=6, + test_num=1,#10, + ), + policy=dict( + cuda=True, + max_len=32, + max_ep_len=200, + task_num=1,#40, + train_num=1,#30, + obs_dim=17, + act_dim=6, + no_state_normalize=False, + no_action_normalize=False, + need_init_dataprocess=True, + need_reward=False, + model=dict( + diffuser_model_cfg=dict( + model='DiffusionUNet1d', + model_cfg=dict( + transition_dim=23, + dim=64, + returns_dim=1, + dim_mults=[1, 4, 8], + returns_condition=True, + condition_dropout=0.3, + kernel_size=5, + attention=False, + ), + horizon=32, + obs_dim=17, + action_dim=6, + n_timesteps=20, + predict_epsilon=False, + condition_guidance_w=1.6, + action_weight=10, + loss_discount=1, + returns_condition=True, + ), + reward_cfg=dict( + model='TemporalValue', + model_cfg=dict( + horizon=32, + transition_dim=23, + dim=64, + out_dim=32, + dim_mults=[1, 4, 8], + kernel_size=5, + returns_condition=True, + no_need_ret_sin=True, + ), + horizon=32, + obs_dim=17, + action_dim=6, + n_timesteps=20, + predict_epsilon=True, + loss_discount=1, + ), + horizon=32, + encoder_horizon=20, + n_guide_steps=2, + scale=2, + t_stopgrad=2, + scale_grad_by_std=True, + ), + normalizer='GaussianNormalizer', + learn=dict( + data_path=None, + 
train_epoch=60000, + gradient_accumulate_every=2, + batch_size=32, + encoder_len=20, + learning_rate=2e-4, + discount_factor=0.99, + learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), + eval_batch_size=8, + warm_batch_size=32, + test_num=1,#10, + plan_batch_size=1, + ), + collect=dict(data_type='meta_traj', ), + eval=dict( + evaluator=dict( + eval_freq=500, + test_env_num=1, + ), + test_ret=0.9, + ), + other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ), + ), + dataset=dict( + data_dir_prefix='/mnt/nfs/share/meta/walker_traj/buffers_walker_param_train', + rtg_scale=1, + context_len=1, + stochastic_prompt=False, + need_prompt=False, + test_id=[0],#[2,12,22,28,31,4,15,10,18,38], + cond=True, + env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', + need_next_obs=True, + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type='meta', + import_names=['dizoo.meta_mujoco.envs.meta_env'], + ), + env_manager=dict(type='meta_subprocess'), + policy=dict( + type='metadiffuser', + ), + replay_buffer=dict(type='naive', ), +) +create_config = EasyDict(create_config) +create_config = create_config \ No newline at end of file diff --git a/dizoo/meta_mujoco/config/walker2d_promptdt_config.py b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py new file mode 100755 index 0000000000..24ba95bc7c --- /dev/null +++ b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py @@ -0,0 +1,109 @@ +from easydict import EasyDict +from copy import deepcopy + +main_config = dict( + exp_name='walker_params_promptdt_seed0', + env=dict( + env_id='walker_params', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + returns_scale=1.0, + termination_penalty=-100, + max_path_length=1000, + use_padding=True, + include_returns=True, + normed=False, + stop_value=8000, + horizon=32, + obs_dim=17, + action_dim=6, + test_num=1, + ), + dataset=dict( + data_dir_prefix='/mnt/nfs/share/meta/walker_traj/buffers_walker_param_train', + rtg_scale=1, + context_len=1, + stochastic_prompt=False, + need_prompt=False, + test_id=[0],#[2,12,22,28,31,4,15,10,18,38], + cond=False, + env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', + need_next_obs=False, + ), + policy=dict( + cuda=True, + stop_value=5000, + max_len=20, + max_ep_len=200, + task_num=1,#40, + train_num=1,#30, + obs_dim=17, + act_dim=6, + need_prompt=False, + state_mean=None, + state_std=None, + no_state_normalize=False, + no_action_normalize=True, + need_init_dataprocess=True, + need_reward=True, + evaluator_env_num=8, + rtg_target=400, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=3, + h_dim=128, + context_len=20, + n_heads=1, + drop_p=0.1, + continuous=True, + use_prompt=False, + ), + batch_size=32, + learning_rate=1e-4, + collect=dict(data_type='meta_traj', ), + learn=dict( + data_path=None, + train_epoch=60000, + gradient_accumulate_every=2, + batch_size=32, + learning_rate=1e-4, + discount_factor=0.99, + learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), + eval_batch_size=8, + test_num=1, + ), + eval=dict( + evaluator=dict( + eval_freq=500, + test_env_num=1, + ), + test_ret=0.9, + ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type='meta', + 
        import_names=['dizoo.meta_mujoco.envs.meta_env'],
+    ),
+    env_manager=dict(type='meta_subprocess'),
+    policy=dict(
+        type='promptdt',
+    ),
+    replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/dizoo/meta_mujoco/entry/meta_entry.py b/dizoo/meta_mujoco/entry/meta_entry.py
new file mode 100755
index 0000000000..18ede61c16
--- /dev/null
+++ b/dizoo/meta_mujoco/entry/meta_entry.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_meta_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+    # launch from anywhere
+    config = Path(__file__).absolute().parent.parent / 'config' / args.config
+    config = read_config(str(config))
+    config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+    serial_pipeline_meta_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', '-s', type=int, default=10)
+    parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py')
+    args = parser.parse_args()
+    train(args)
\ No newline at end of file
diff --git a/dizoo/meta_mujoco/envs/meta_env.py b/dizoo/meta_mujoco/envs/meta_env.py
new file mode 100755
index 0000000000..32bb3c25b4
--- /dev/null
+++ b/dizoo/meta_mujoco/envs/meta_env.py
@@ -0,0 +1,94 @@
+from typing import Any, Union, List
+import copy
+import gym
+import numpy as np
+from easydict import EasyDict
+
+#from CORRO.environments.make_env import make_env
+
+from rand_param_envs.make_env import make_env
+
+from ding.torch_utils import to_ndarray, to_list
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.utils import ENV_REGISTRY
+
+@ENV_REGISTRY.register('meta')
+class MujocoEnv(BaseEnv):
+
+    def __init__(self, cfg: dict) -> None:
+        self._init_flag = False
+        self._use_act_scale = cfg.use_act_scale
+        self._cfg = cfg
+
+    def reset(self) -> Any:
+        if not self._init_flag:
+            self._env = make_env(self._cfg.env_id, 1, seed=self._cfg.seed, n_tasks=self._cfg.test_num)
+            self._env.observation_space.dtype = np.float32
+            self._observation_space = self._env.observation_space
+            self._action_space = self._env.action_space
+            self._reward_space = gym.spaces.Box(
+                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+            )
+            self._init_flag = True
+        obs = self._env.reset()
+        obs = to_ndarray(obs).astype('float32')
+        self._eval_episode_return = 0.
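+        # NOTE: the lazy init above builds a multi-task env via rand_param_envs.make_env
+        # (n_tasks comes from cfg.test_num); the meta evaluator later switches tasks with
+        # reset_task() and injects per-task parameters with set_all_goals(), both defined below.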
+        return obs
+
+    def close(self) -> None:
+        if self._init_flag:
+            self._env.close()
+        self._init_flag = False
+
+    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+        action = to_ndarray(action)
+        if self._use_act_scale:
+            action_range = {'min': self.action_space.low[0], 'max': self.action_space.high[0], 'dtype': np.float32}
+            action = affine_transform(action, min_val=action_range['min'], max_val=action_range['max'])
+        obs, rew, done, info = self._env.step(action)
+        self._eval_episode_return += rew
+        obs = to_ndarray(obs).astype('float32')
+        rew = to_ndarray([rew])
+        if done:
+            info['eval_episode_return'] = self._eval_episode_return
+        return BaseEnvTimestep(obs, rew, done, info)
+
+    def __repr__(self) -> str:
+        return "DI-engine Meta MuJoCo Env({})".format(self._cfg.env_id)
+
+    def set_all_goals(self, params):
+        self._env.set_all_goals(params)
+
+    def reset_task(self, id):
+        self._env.reset_task(id)
+
+    @staticmethod
+    def create_collector_env_cfg(cfg: dict) -> List[dict]:
+        collector_cfg = copy.deepcopy(cfg)
+        collector_env_num = collector_cfg.pop('collector_env_num', 1)
+        return [collector_cfg for _ in range(collector_env_num)]
+
+    @staticmethod
+    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+        evaluator_cfg = copy.deepcopy(cfg)
+        evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+        evaluator_cfg.get('norm_reward', EasyDict(use_norm=False, )).use_norm = False
+        return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        self._seed = seed
+        self._dynamic_seed = dynamic_seed
+        np.random.seed(self._seed)
+
+    @property
+    def observation_space(self) -> gym.spaces.Space:
+        return self._observation_space
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space
+
+    @property
+    def reward_space(self) -> gym.spaces.Space:
+        return self._reward_space
\ No newline at end of file
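Usage sketch for the new entry point: the snippet below mirrors dizoo/meta_mujoco/entry/meta_entry.py and launches the meta offline pipeline with the new walker config. It assumes the dataset and env-param paths hard-coded in walker2d_metadiffuser_config.py (under /mnt/nfs/share/meta/) exist on the local machine; the seed value is only an example.

# Minimal launch sketch (see assumptions above).
from pathlib import Path

from ding.config import read_config
from ding.entry import serial_pipeline_meta_offline

seed = 0
config_path = Path('dizoo/meta_mujoco/config') / 'walker2d_metadiffuser_config.py'
config = read_config(str(config_path))  # returns (main_config, create_config)
config[0].exp_name = config[0].exp_name.replace('0', str(seed))  # same exp_name convention as meta_entry.py
serial_pipeline_meta_offline(config, seed=seed)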
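Reference note on the return-to-go targets: MetaTraj.__getitem__ and get_prompt build rtg with discount_cumsum(rewards, gamma=1.) and divide by dataset.rtg_scale (note that __getitem__ currently applies the rtg_scale division twice, which is harmless while rtg_scale=1 as in both walker configs). A self-contained sketch of the gamma=1 case follows; the local helper is assumed to match discount_cumsum in ding/utils/data/dataset.py.

import numpy as np

def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    # stand-in assumed to match the dataset.py helper: reverse discounted cumulative sum
    out = np.zeros_like(x)
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

rewards = np.array([1.0, 2.0, 3.0])
rtg = discount_cumsum(rewards, gamma=1.0)  # -> [6., 5., 3.]: remaining return from each step
rtg = rtg / 1.0  # rtg_scale is 1 in both walker configs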