diff --git a/libero/configs/config_rwla.yaml b/libero/configs/config_rwla.yaml
new file mode 100644
index 00000000..047f6e19
--- /dev/null
+++ b/libero/configs/config_rwla.yaml
@@ -0,0 +1,82 @@
+# @package _global_
+
+
+defaults:
+  - _self_
+  - data: default
+  - policy: bc_rwla_policy
+  - train: default
+  - eval: default
+  - lifelong: base
+  - test: null
+
+seed: 10000
+use_wandb: true
+wandb_project: "lifelong learning"
+folder: null # "/mnt/data/adityas/CS395TRobotManipulationSp25/LIBERO/libero/libero"
+bddl_folder: "/mnt/data/adityas/CS395TRobotManipulationSp25/LIBERO/libero/libero/bddl_files"
+init_states_folder: "/mnt/data/adityas/CS395TRobotManipulationSp25/LIBERO/libero/libero/init_files" # use default path
+load_previous_model: false
+device: "cuda"
+device_id: 0
+task_embedding_format: "sentence-similarity"
+task_embedding_one_hot_offset: 1
+pretrain: false
+pretrain_model_path: ""
+benchmark_name: "LIBERO_OBJECT"
+# checkpoint_path: "/mnt/data/adityas/CS395TRobotManipulationSp25/experiments/LIBERO_OBJECT/Sequential/BCRWLAPolicy_seed10000/run_044/task0_model.pth"
+# ---------------------------------
+
+data:
+  task_order_index: 0
+  seq_len: 10 # Sequence length for temporal processing
+  obs:
+    modality:
+      rgb: ["agentview_rgb", "eye_in_hand_rgb"]
+      low_dim: ["joint_states", "gripper_states"]
+
+train:
+  batch_size: 32
+  num_workers: 2
+  persistent_workers: true
+  n_epochs: 50
+  grad_clip: 100.0
+  checkpoint_dir: ./checkpoints
+  checkpoint_interval: 5
+  resume: true
+  # resume_path: "" # Path to checkpoint to resume from
+  optimizer:
+    name: AdamW
+    kwargs:
+      lr: 0.0001
+      weight_decay: 0.0001
+      betas: [0.9, 0.999]
+  scheduler:
+    name: CosineAnnealingLR
+    kwargs:
+      eta_min: 0.00001
+
+lifelong:
+  memory_capacity: 500
+  demos_per_task: 30
+  epochs_per_task: 50
+  eval_interval: 10
+  replay_batch_size: 32
+  replay_coef: 1.0
+  resume_task_idx: 0 # Task index to resume from if resuming
+
+rwla:
+  # Reviewing phase settings
+  num_rollouts: 30
+  max_steps: 50
+
+  # Retrieval settings
+  num_retrieved_demos: 20
+
+  # Selective weighting settings
+  weight_increment: 0.3
+  max_weight: 2.0
+
+  # Local adaptation settings
+  adaptation_epochs: 30
+  adaptation_lr: 0.001
\ No newline at end of file
diff --git a/libero/configs/config_rwla_tiejean.yaml b/libero/configs/config_rwla_tiejean.yaml
new file mode 100644
index 00000000..28b98391
--- /dev/null
+++ b/libero/configs/config_rwla_tiejean.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+defaults:
+  - _self_
+  - data: default
+  - policy: bc_rwla_policy
+  - train: default
+  - eval: default
+  - lifelong: base
+  - test: null
+
+train:
+  checkpoint_interval: 5
+
+seed: 10000
+use_wandb: true
+wandb_project: "rwla"
+folder: null # use default path
+bddl_folder: "/home/tiejean/Workspace/TaskAgonosticManipulation/LIBERO/libero/libero/bddl_files"
+init_states_folder: "/home/tiejean/Workspace/TaskAgonosticManipulation/LIBERO/libero/libero/init_files" # use default path
+load_previous_model: false
+device: "cuda"
+task_embedding_format: "sentence-similarity"
+task_embedding_one_hot_offset: 1
+pretrain: false
+pretrain_model_path: ""
+benchmark_name: "LIBERO_SPATIAL"
\ No newline at end of file
diff --git a/libero/configs/eval/default.yaml b/libero/configs/eval/default.yaml
index 0c926638..2c5c584f 100644
--- a/libero/configs/eval/default.yaml
+++ b/libero/configs/eval/default.yaml
@@ -1,10 +1,10 @@
 load_path: "" # only used when separately evaluating a pretrained model
 eval: true
-batch_size: 64
+batch_size: 128
 num_workers: 4
-n_eval: 20
+n_eval: 10 eval_every: 5 max_steps: 600 use_mp: true -num_procs: 20 +num_procs: 5 save_sim_states: false diff --git a/libero/configs/policy/bc_rwla_policy.yaml b/libero/configs/policy/bc_rwla_policy.yaml new file mode 100644 index 00000000..d5dd1ac9 --- /dev/null +++ b/libero/configs/policy/bc_rwla_policy.yaml @@ -0,0 +1,20 @@ +policy_type: BCRWLAPolicy +extra_num_layers: 0 +extra_hidden_size: 128 +embed_size: 64 + +transformer_input_size: null +transformer_num_layers: 4 +transformer_num_heads: 6 +transformer_head_output_size: 64 +transformer_mlp_hidden_size: 256 +transformer_dropout: 0.1 +transformer_max_seq_len: 10 + +defaults: + - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml + - data_augmentation@translation_aug: translation_aug.yaml + - image_encoder: resnet_encoder.yaml + - language_encoder: mlp_sentence_similarity_encoder.yaml + - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml + - policy_head: gmm_head.yaml diff --git a/libero/configs/policy/language_encoder/mlp_sentence_similarity_encoder.yaml b/libero/configs/policy/language_encoder/mlp_sentence_similarity_encoder.yaml new file mode 100644 index 00000000..53873632 --- /dev/null +++ b/libero/configs/policy/language_encoder/mlp_sentence_similarity_encoder.yaml @@ -0,0 +1,6 @@ +network: MLPEncoder +network_kwargs: + input_size: 368 + hidden_size: 128 + output_size: 64 + num_layers: 1 diff --git a/libero/configs/train/default.yaml b/libero/configs/train/default.yaml index a30c2f7d..f1501f87 100644 --- a/libero/configs/train/default.yaml +++ b/libero/configs/train/default.yaml @@ -1,7 +1,7 @@ # training n_epochs: 50 -batch_size: 32 -num_workers: 4 +batch_size: 64 # default 32 +num_workers: 8 grad_clip: 100. loss_scale: 1.0 diff --git a/libero/lifelong/algos/base.py b/libero/lifelong/algos/base.py index 90d1cff8..7758d3b9 100644 --- a/libero/lifelong/algos/base.py +++ b/libero/lifelong/algos/base.py @@ -126,6 +126,22 @@ def eval_observe(self, data): return loss.item() def learn_one_task(self, dataset, task_id, benchmark, result_summary): + # task = benchmark.get_task(task_id) + # task_emb = benchmark.get_task_emb(task_id) + # sim_states = ( + # result_summary[task_str] if self.cfg.eval.save_sim_states else None + # ) + + # success_rate = evaluate_one_task_success( + # cfg=self.cfg, + # algo=self, + # task=task, + # task_emb=task_emb, + # task_id=task_id, + # sim_states=sim_states, + # task_str="", + # ) + # import pdb; pdb.set_trace() self.start_task(task_id) @@ -143,6 +159,7 @@ def learn_one_task(self, dataset, task_id, benchmark, result_summary): num_workers=self.cfg.train.num_workers, sampler=RandomSampler(dataset), persistent_workers=True, + pin_memory=True ) prev_success_rate = -1.0 @@ -159,6 +176,8 @@ def learn_one_task(self, dataset, task_id, benchmark, result_summary): task_emb = benchmark.get_task_emb(task_id) # start training + import math + best_training_loss = math.inf for epoch in range(0, self.cfg.train.n_epochs + 1): t0 = time.time() @@ -178,6 +197,7 @@ def learn_one_task(self, dataset, task_id, benchmark, result_summary): training_loss /= len(train_dataloader) t1 = time.time() + best_training_loss = min(best_training_loss, training_loss) print( f"[info] Epoch: {epoch:3d} | train loss: {training_loss:5.2f} | time: {(t1-t0)/60:4.2f}" ) @@ -224,6 +244,15 @@ def learn_one_task(self, dataset, task_id, benchmark, result_summary): + f"| succ. 
AoC {tmp_successes.sum()/cumulated_counter:4.2f} | time: {(t1-t0)/60:4.2f}", flush=True, ) + + if self.cfg.use_wandb: + import wandb + wandb.log({"epoch": epoch, + "loss": training_loss, + "best loss": best_training_loss, + "success rates": success_rate, + "best success rates": prev_success_rate, + "learning_rate": self.scheduler.get_last_lr()[0]}) if self.scheduler is not None and epoch > 0: self.scheduler.step() @@ -253,5 +282,243 @@ def learn_one_task(self, dataset, task_id, benchmark, result_summary): successes[idx_at_best_succ:] = successes[idx_at_best_succ] return successes.sum() / cumulated_counter, losses.sum() / cumulated_counter + def learn_one_task_with_memory_replay(self, dataset, task_id, benchmark, result_summary, memory, replay_batch_size=32, replay_coef=0.5): + """ + Learn one task with experience replay from memory + This extends the learn_one_task method with memory replay mechanism for RWLA algorithm. + """ + self.start_task(task_id) + + # recover the corresponding manipulation task ids + gsz = self.cfg.data.task_group_size + manip_task_ids = list(range(task_id * gsz, (task_id + 1) * gsz)) + + model_checkpoint_name = os.path.join( + self.experiment_dir, f"task{task_id}_model.pth" + ) + + train_dataloader = DataLoader( + dataset, + batch_size=self.cfg.train.batch_size, + num_workers=0, #self.cfg.train.num_workers, + sampler=RandomSampler(dataset), + persistent_workers=False, + pin_memory=True + ) + + prev_success_rate = -1.0 + best_state_dict = self.policy.state_dict() # currently save the best model + + # for evaluate how fast the agent learns on current task, this corresponds + # to the area under success rate curve on the new task. + cumulated_counter = 0.0 + idx_at_best_succ = 0 + successes = [] + losses = [] + + task = benchmark.get_task(task_id) + task_emb = benchmark.get_task_emb(task_id) + + # start training + import math + best_training_loss = math.inf + for epoch in range(0, self.cfg.train.n_epochs + 1): + # for epoch in range(0, 2): + t0 = time.time() + + if epoch > 0: # update + self.policy.train() + training_loss = 0.0 + replay_loss = 0.0 + + for (idx, data) in enumerate(train_dataloader): + # Move current task data to device + data = self.map_tensor_to_device(data) + + self.optimizer.zero_grad() + task_loss = self.policy.compute_loss(data) + + # Add experience replay if not first task and memory has data + if task_id > 0 and memory.get_memory_size() > 0: + try: + replay_batch = memory.get_replay_batch( + batch_size=min(replay_batch_size, memory.get_memory_size()) + ) + # import pdb; pdb.set_trace() + if len(replay_batch) > 0: + replay_data = {} + + replay_obs = {} + for key in ["agentview_rgb", "eye_in_hand_rgb", "gripper_states", "joint_states"]: + items = [] + for demo in replay_batch: + if "obs" in demo and key in demo["obs"]: + items.append(demo["obs"][key]) + + if items: + replay_obs[key] = torch.stack(items).to(self.cfg.device) + + replay_data["obs"] = replay_obs + + actions = [] + task_embs = [] + language = [] + for demo in replay_batch: + if "actions" in demo: + actions.append(demo["actions"]) + if "language_description" in demo: + language.append(demo["language_description"]) + + task_embs = [demo["task_emb"] for demo in replay_batch if "task_emb" in demo] + if task_embs: + replay_data["task_emb"] = torch.stack(task_embs).to(self.cfg.device) + if actions: + replay_data["actions"] = torch.stack(actions).to(self.cfg.device) + if language: + replay_data["language"] = language + + # Compute loss on replay data if we have valid data + if "actions" in 
replay_data and replay_data["obs"] and len(replay_data["obs"]) > 0: + memory_loss = self.policy.compute_loss(replay_data) + + weighted_memory_loss = replay_coef * memory_loss + combined_task_loss = task_loss + weighted_memory_loss + + replay_loss += memory_loss.item() + + (self.loss_scale * combined_task_loss).backward() + else: + # If replay data processing failed, just use current task loss + (self.loss_scale * task_loss).backward() + else: + # If no replay batch, just use current task loss + (self.loss_scale * task_loss).backward() + + except Exception as e: + # If anything fails in replay, just use current task loss + print(f"[warning] Experience replay failed: {e}") + import traceback + traceback.print_exc() + (self.loss_scale * task_loss).backward() + else: + # For first task, just use current task loss + (self.loss_scale * task_loss).backward() + + if self.cfg.train.grad_clip is not None: + grad_norm = nn.utils.clip_grad_norm_( + self.policy.parameters(), self.cfg.train.grad_clip + ) + + self.optimizer.step() + training_loss += task_loss.item() + + training_loss /= len(train_dataloader) + if task_id > 0 and memory.get_memory_size() > 0: + replay_loss /= max(1, len(train_dataloader)) + print(f"[info] Task loss: {training_loss:.4f}, Replay loss: {replay_loss:.4f}") + + else: # just evaluate the zero-shot performance on 0-th epoch + training_loss = 0.0 + for (idx, data) in enumerate(train_dataloader): + loss = self.eval_observe(data) + training_loss += loss + training_loss /= len(train_dataloader) + + t1 = time.time() + + best_training_loss = min(best_training_loss, training_loss) + print( + f"[info] Epoch: {epoch:3d} | train loss: {training_loss:5.2f} | time: {(t1-t0)/60:4.2f}" + ) + + # Evaluation code - same as in original method + if epoch % self.cfg.eval.eval_every == 0: # evaluate BC loss + losses.append(training_loss) + + t0 = time.time() + + task_str = f"k{task_id}_e{epoch//self.cfg.eval.eval_every}" + sim_states = ( + result_summary[task_str] if self.cfg.eval.save_sim_states else None + ) + success_rate = evaluate_one_task_success( + cfg=self.cfg, + algo=self, + task=task, + task_emb=task_emb, + task_id=task_id, + sim_states=sim_states, + task_str="", + ) + successes.append(success_rate) + + if prev_success_rate < success_rate: + torch_save_model(self.policy, model_checkpoint_name, cfg=self.cfg) + prev_success_rate = success_rate + idx_at_best_succ = len(losses) - 1 + + t1 = time.time() + + cumulated_counter += 1.0 + ci = confidence_interval(success_rate, self.cfg.eval.n_eval) + tmp_successes = np.array(successes) + tmp_successes[idx_at_best_succ:] = successes[idx_at_best_succ] + print( + f"[info] Epoch: {epoch:3d} | succ: {success_rate:4.2f} ± {ci:4.2f} | best succ: {prev_success_rate} " + + f"| succ. 
AoC {tmp_successes.sum()/cumulated_counter:4.2f} | time: {(t1-t0)/60:4.2f}", + flush=True, + ) + + # Log to wandb if enabled + if self.cfg.use_wandb: + import wandb + wandb_log_dict = { + "epoch": epoch, + "loss": training_loss, + "best loss": best_training_loss, + "success rates": success_rate, + "best success rates": prev_success_rate + } + + # Add replay loss if applicable + # if task_id > 0 and memory.get_memory_size() > 0: + # wandb_log_dict["replay_loss"] = replay_loss + + # Add learning rate if scheduler exists + if self.scheduler is not None: + wandb_log_dict["learning_rate"] = self.scheduler.get_last_lr()[0] + + wandb.log(wandb_log_dict) + + # Step scheduler + if self.scheduler is not None and epoch > 0: + self.scheduler.step() + + # load the best performance agent on the current task + self.policy.load_state_dict(torch_load_model(model_checkpoint_name)[0]) + + # end learning the current task, some algorithms need post-processing + self.end_task(dataset, task_id, benchmark) + + # return the metrics regarding forward transfer + losses = np.array(losses) + successes = np.array(successes) + auc_checkpoint_name = os.path.join( + self.experiment_dir, f"task{task_id}_auc.log" + ) + torch.save( + { + "success": successes, + "loss": losses, + }, + auc_checkpoint_name, + ) + + # pretend that the agent stops learning once it reaches the peak performance + losses[idx_at_best_succ:] = losses[idx_at_best_succ] + successes[idx_at_best_succ:] = successes[idx_at_best_succ] + return successes.sum() / cumulated_counter, losses.sum() / cumulated_counter + + def reset(self): self.policy.reset() diff --git a/libero/lifelong/datasets.py b/libero/lifelong/datasets.py index f0b5a1c1..06e3ed36 100644 --- a/libero/lifelong/datasets.py +++ b/libero/lifelong/datasets.py @@ -71,6 +71,10 @@ def __len__(self): def __getitem__(self, idx): return_dict = self.sequence_dataset.__getitem__(idx) + return_dict["obs"]["agentview_rgb"] = return_dict["obs"]["agentview_rgb"].astype(np.float32) + return_dict["obs"]["eye_in_hand_rgb"] = return_dict["obs"]["eye_in_hand_rgb"].astype(np.float32) + return_dict["obs"]["gripper_states"] = return_dict["obs"]["gripper_states"].astype(np.float32) + return_dict["obs"]["joint_states"] = return_dict["obs"]["joint_states"].astype(np.float32) return_dict["task_emb"] = self.task_emb return return_dict diff --git a/libero/lifelong/evaluate.py b/libero/lifelong/evaluate.py index d4db4dd9..ee8ab73e 100644 --- a/libero/lifelong/evaluate.py +++ b/libero/lifelong/evaluate.py @@ -113,7 +113,6 @@ def parse_args(): ), "[error] load_task should be in [0, ..., 9]" return args - def main(): args = parse_args() # e.g., experiments/LIBERO_SPATIAL/Multitask/BCRNNPolicy_seed100/ diff --git a/libero/lifelong/main.py b/libero/lifelong/main.py index 7ea86853..08b4d17e 100644 --- a/libero/lifelong/main.py +++ b/libero/lifelong/main.py @@ -34,7 +34,7 @@ ) -@hydra.main(config_path="../configs", config_name="config", version_base=None) +@hydra.main(config_path="../configs", config_name="config_rwla", version_base=None) def main(hydra_cfg): # preprocessing yaml_config = OmegaConf.to_yaml(hydra_cfg) @@ -167,8 +167,8 @@ def main(hydra_cfg): sys.exit(0) print(f"[info] start lifelong learning with algo {cfg.lifelong.algo}") - GFLOPs, MParams = compute_flops(algo, datasets[0], cfg) - print(f"[info] policy has {GFLOPs:.1f} GFLOPs and {MParams:.1f} MParams\n") + # GFLOPs, MParams = compute_flops(algo, datasets[0], cfg) + # print(f"[info] policy has {GFLOPs:.1f} GFLOPs and {MParams:.1f} MParams\n") # save the experiment 
config file, so we can resume or replay later with open(os.path.join(cfg.experiment_dir, "config.json"), "w") as f: @@ -259,6 +259,7 @@ def main(hydra_cfg): torch.save( result_summary, os.path.join(cfg.experiment_dir, f"result.pt") ) + torch.save(algo.policy.state_dict(), os.path.join(cfg.experiment_dir, f"checkpoint_task{i}.pt")) print("[info] finished learning\n") if cfg.use_wandb: diff --git a/libero/lifelong/metric.py b/libero/lifelong/metric.py index 9777577d..dbaee5e2 100644 --- a/libero/lifelong/metric.py +++ b/libero/lifelong/metric.py @@ -104,7 +104,7 @@ def evaluate_one_task_success( init_states_path = os.path.join( cfg.init_states_folder, task.problem_folder, task.init_states_file ) - init_states = torch.load(init_states_path) + init_states = torch.load(init_states_path, weights_only=False) num_success = 0 for i in range(eval_loop_num): env.reset() @@ -131,8 +131,11 @@ def evaluate_one_task_success( steps += 1 data = raw_obs_to_tensor_obs(obs, task_emb, cfg) - actions = algo.policy.get_action(data) - + if hasattr(algo, "policy"): + actions = algo.policy.get_action(data) + else: + actions = algo.get_action(data) + obs, reward, done, info = env.step(actions) # record the sim states for replay purpose diff --git a/libero/lifelong/models/__init__.py b/libero/lifelong/models/__init__.py index 4d7d69f1..b0e8954b 100644 --- a/libero/lifelong/models/__init__.py +++ b/libero/lifelong/models/__init__.py @@ -1,5 +1,6 @@ from libero.lifelong.models.bc_rnn_policy import BCRNNPolicy from libero.lifelong.models.bc_transformer_policy import BCTransformerPolicy from libero.lifelong.models.bc_vilt_policy import BCViLTPolicy +from libero.lifelong.models.bc_rwla_policy import BCRWLAPolicy from libero.lifelong.models.base_policy import get_policy_class, get_policy_list diff --git a/libero/lifelong/models/base_policy.py b/libero/lifelong/models/base_policy.py index 16bdcc3e..cd68b80f 100644 --- a/libero/lifelong/models/base_policy.py +++ b/libero/lifelong/models/base_policy.py @@ -100,6 +100,7 @@ def _get_aug_output_dict(self, out): def preprocess_input(self, data, train_mode=True): if train_mode: # apply augmentation if self.cfg.train.use_augmentation: + # import pdb; pdb.set_trace() img_tuple = self._get_img_tuple(data) aug_out = self._get_aug_output_dict(self.img_aug(img_tuple)) for img_name in self.image_encoders.keys(): diff --git a/libero/lifelong/models/bc_rwla_policy.py b/libero/lifelong/models/bc_rwla_policy.py new file mode 100644 index 00000000..9a626430 --- /dev/null +++ b/libero/lifelong/models/bc_rwla_policy.py @@ -0,0 +1,144 @@ +import robomimic.utils.tensor_utils as TensorUtils +import torch +import torch.nn as nn + +from libero.lifelong.models.modules.rgb_modules import * +from libero.lifelong.models.modules.language_modules import * +from libero.lifelong.models.modules.transformer_modules import * +from libero.lifelong.models.base_policy import BasePolicy +from libero.lifelong.models.policy_head import * + +############################################################################### +# +# A Transformer Policy +# +############################################################################### + +from libero.lifelong.models.bc_transformer_policy import ExtraModalityTokens +class BCRWLAPolicy(BasePolicy): + """ + Input: (o_{t-H}, ... , o_t) + Output: a_t or distribution of a_t + """ + + def __init__(self, cfg, shape_meta): + super().__init__(cfg, shape_meta) + policy_cfg = cfg.policy + + ### 1. 
encode image + embed_size = policy_cfg.embed_size + transformer_input_sizes = [] + self.image_encoders = {} + for name in shape_meta["all_shapes"].keys(): + if "rgb" in name or "depth" in name: + kwargs = policy_cfg.image_encoder.network_kwargs + kwargs.input_shape = shape_meta["all_shapes"][name] + kwargs.output_size = embed_size + kwargs.language_dim = ( + policy_cfg.language_encoder.network_kwargs.input_size + ) + from rwla.models import VisualEncoder + self.image_encoders[name] = { + "input_shape": shape_meta["all_shapes"][name], + "encoder": VisualEncoder(device="cuda:0"), + } + + ### 2. encode language + policy_cfg.language_encoder.network_kwargs.output_size = embed_size + self.language_encoder = eval(policy_cfg.language_encoder.network)( + **policy_cfg.language_encoder.network_kwargs + ) + + ### 3. encode extra information (e.g. gripper, joint_state) + self.extra_encoder = ExtraModalityTokens( + use_joint=cfg.data.use_joint, + use_gripper=cfg.data.use_gripper, + use_ee=cfg.data.use_ee, + extra_num_layers=policy_cfg.extra_num_layers, + extra_hidden_size=policy_cfg.extra_hidden_size, + extra_embedding_size=embed_size, + ) + + ### 4. define temporal transformer + policy_cfg.temporal_position_encoding.network_kwargs.input_size = embed_size + self.temporal_position_encoding_fn = eval( + policy_cfg.temporal_position_encoding.network + )(**policy_cfg.temporal_position_encoding.network_kwargs) + + self.temporal_transformer = TransformerDecoder( + input_size=embed_size, + num_layers=policy_cfg.transformer_num_layers, + num_heads=policy_cfg.transformer_num_heads, + head_output_size=policy_cfg.transformer_head_output_size, + mlp_hidden_size=policy_cfg.transformer_mlp_hidden_size, + dropout=policy_cfg.transformer_dropout, + ) + + policy_head_kwargs = policy_cfg.policy_head.network_kwargs + policy_head_kwargs.input_size = embed_size + policy_head_kwargs.output_size = shape_meta["ac_dim"] + + self.policy_head = eval(policy_cfg.policy_head.network)( + **policy_cfg.policy_head.loss_kwargs, + **policy_cfg.policy_head.network_kwargs + ) + + self.latent_queue = [] + self.max_seq_len = policy_cfg.transformer_max_seq_len + + def temporal_encode(self, x): + pos_emb = self.temporal_position_encoding_fn(x) + x = x + pos_emb.unsqueeze(1) # (B, T, num_modality, E) + sh = x.shape + self.temporal_transformer.compute_mask(x.shape) + + x = TensorUtils.join_dimensions(x, 1, 2) # (B, T*num_modality, E) + x = self.temporal_transformer(x) + x = x.reshape(*sh) + return x[:, :, 0] # (B, T, E) + + def spatial_encode(self, data): + # 1. encode extra + extra = self.extra_encoder(data["obs"]) # (B, T, num_extra, E) + + # 2. encode language, treat it as action token + B, T = extra.shape[:2] + text_encoded = self.language_encoder(data) # (B, E) + text_encoded = text_encoded.view(B, 1, 1, -1).expand( + -1, T, -1, -1 + ) # (B, T, 1, E) + encoded = [text_encoded, extra] + + # 3. 
encode image + for img_name in self.image_encoders.keys(): + x = data["obs"][img_name] + if x.shape[-1] == 3: + x = x.permute(0, 1, 4, 2, 3) + B, T, C, H, W = x.shape + img_encoded = self.image_encoders[img_name]["encoder"](x).view(B, T, 1, -1) + encoded.append(img_encoded) + encoded = torch.cat(encoded, -2) # (B, T, num_modalities, E) + return encoded + + def forward(self, data): + x = self.spatial_encode(data) + x = self.temporal_encode(x) + dist = self.policy_head(x) + return dist + + def get_action(self, data): + self.eval() + with torch.no_grad(): + data = self.preprocess_input(data, train_mode=False) + x = self.spatial_encode(data) + self.latent_queue.append(x) + if len(self.latent_queue) > self.max_seq_len: + self.latent_queue.pop(0) + x = torch.cat(self.latent_queue, dim=1) # (B, T, H_all) + x = self.temporal_encode(x) + dist = self.policy_head(x[:, -1]) + action = dist.sample().detach().cpu() + return action.view(action.shape[0], -1).numpy() + + def reset(self): + self.latent_queue = [] diff --git a/libero/lifelong/models/bc_transformer_policy.py b/libero/lifelong/models/bc_transformer_policy.py index e9c8c3db..8f6e2de9 100644 --- a/libero/lifelong/models/bc_transformer_policy.py +++ b/libero/lifelong/models/bc_transformer_policy.py @@ -255,6 +255,8 @@ def temporal_encode(self, x): def spatial_encode(self, data): # 1. encode extra + data["obs"]["gripper_states"] = data["obs"]["gripper_states"].float() + data["obs"]["joint_states"] = data["obs"]["joint_states"].float() extra = self.extra_encoder(data["obs"]) # (B, T, num_extra, E) # 2. encode language, treat it as action token @@ -268,7 +270,10 @@ def spatial_encode(self, data): # 3. encode image for img_name in self.image_encoders.keys(): x = data["obs"][img_name] + if x.shape[-1] == 3: + x.permute(0, 1, 4, 2, 3) B, T, C, H, W = x.shape + assert C == 3 img_encoded = self.image_encoders[img_name]["encoder"]( x.reshape(B * T, C, H, W), langs=data["task_emb"] diff --git a/libero/lifelong/models/modules/data_augmentation.py b/libero/lifelong/models/modules/data_augmentation.py index 667e08c9..61450e6c 100644 --- a/libero/lifelong/models/modules/data_augmentation.py +++ b/libero/lifelong/models/modules/data_augmentation.py @@ -171,6 +171,8 @@ def forward(self, x_groups): split_channels.append(x_groups[i].shape[1]) if self.training: x = torch.cat(x_groups, dim=1) + if x.shape[-1] == 3 and x.shape[-2] == 128: + x = x.permute(0, 1, 4, 2, 3) out = self.aug_layer(x) out = torch.split(out, split_channels, dim=1) return out diff --git a/libero/lifelong/utils.py b/libero/lifelong/utils.py index c0c3f10e..b1912f7a 100644 --- a/libero/lifelong/utils.py +++ b/libero/lifelong/utils.py @@ -56,7 +56,7 @@ def torch_save_model(model, model_path, cfg=None, previous_masks=None): def torch_load_model(model_path, map_location=None): - model_dict = torch.load(model_path, map_location=map_location) + model_dict = torch.load(model_path, map_location=map_location, weights_only=False) cfg = None if "cfg" in model_dict: cfg = model_dict["cfg"] @@ -216,5 +216,9 @@ def get_task_embs(cfg, descriptions): return_tensors="pt", # ask the function to return PyTorch tensors ) task_embs = model(**tokens)["pooler_output"].detach() + elif cfg.task_embedding_format == "sentence-similarity": + from rwla.models import LanguageEncoder + encoder = LanguageEncoder() + task_embs = encoder(descriptions).detach() cfg.policy.language_encoder.network_kwargs.input_size = task_embs.shape[-1] return task_embs diff --git a/test_data.py b/test_data.py new file mode 100644 index 
00000000..43b483bc --- /dev/null +++ b/test_data.py @@ -0,0 +1,14 @@ +import h5py +import os + +# Print dataset path +dataset_path = "/home/tiejean/Workspace/TaskAgonosticManipulation/LIBERO/libero/datasets/libero_object/pick_up_the_alphabet_soup_and_place_it_in_the_basket_demo.hdf5" +print(f"Inspecting dataset at: {dataset_path}") + +# Open the HDF5 file and list keys +with h5py.File(dataset_path, "r") as d: + import pdb; pdb.set_trace() + print("Keys in dataset:", list(d.keys())) + print("All demos: ", list(d["data"].keys())) + print("Data format: ", list(d["data"].keys())) + d["data"]["demo_0"].keys() \ No newline at end of file diff --git a/test_sim.py b/test_sim.py new file mode 100644 index 00000000..e1ef44eb --- /dev/null +++ b/test_sim.py @@ -0,0 +1,38 @@ +from libero.libero import benchmark +from libero.libero.envs import OffScreenRenderEnv + + +benchmark_dict = benchmark.get_benchmark_dict() +task_suite_name = "libero_object" # can also choose libero_spatial, libero_object, etc. +task_suite = benchmark_dict[task_suite_name]() + +# retrieve a specific task +task_id = 0 +task = task_suite.get_task(task_id) +task_name = task.name +task_description = task.language +import os +from libero.libero import get_libero_path +task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) +print(f"[info] retrieving task {task_id} from suite {task_suite_name}, the " + \ + f"language instruction is {task_description}, and the bddl file is {task_bddl_file}") + +# step over the environment +env_args = { + "bddl_file_name": task_bddl_file, + "camera_heights": 128, + "camera_widths": 128 +} +env = OffScreenRenderEnv(**env_args) +env.seed(0) +env.reset() +init_states = task_suite.get_task_init_states(task_id) # for benchmarking purpose, we fix the a set of initial states +init_state_id = 0 +env.set_init_state(init_states[init_state_id]) + +dummy_action = [0.] * 7 +import pdb; pdb.set_trace() +for step in range(10): + obs, reward, done, info = env.step(dummy_action) + import pdb; pdb.set_trace() +env.close() \ No newline at end of file
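
Note on the memory interface assumed by this patch: the learn_one_task_with_memory_replay method added to libero/lifelong/algos/base.py only touches the passed-in memory object through get_memory_size() and get_replay_batch(batch_size=...), and expects each returned demo to be a dict with "obs" (per-modality tensors), "actions", "task_emb", and optionally "language_description". The memory class itself is not included in the diff; the following is a minimal sketch (not the actual RWLA implementation) of a buffer satisfying that interface, assuming a reservoir-style cap corresponding to lifelong.memory_capacity and a hypothetical add_demo method for populating it:

import random


class ReplayMemory:
    """Minimal demo buffer matching the calls made in
    learn_one_task_with_memory_replay (sketch only; the real class is not in this diff)."""

    def __init__(self, capacity=500):
        self.capacity = capacity
        self.demos = []      # each demo: {"obs": {...}, "actions": ..., "task_emb": ..., ...}
        self.num_seen = 0

    def add_demo(self, demo):
        # Name assumed for illustration. Reservoir sampling keeps a bounded,
        # roughly uniform sample over all demos seen so far.
        self.num_seen += 1
        if len(self.demos) < self.capacity:
            self.demos.append(demo)
        else:
            j = random.randrange(self.num_seen)
            if j < self.capacity:
                self.demos[j] = demo

    def get_memory_size(self):
        return len(self.demos)

    def get_replay_batch(self, batch_size):
        batch_size = min(batch_size, len(self.demos))
        return random.sample(self.demos, batch_size) if batch_size > 0 else []

Reservoir sampling is only one plausible eviction policy; the actual RWLA memory may instead keep a fixed per-task quota (e.g., lifelong.demos_per_task).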