diff --git a/robo_manip_baselines/bin/Train.py b/robo_manip_baselines/bin/Train.py
index 2ac0f6b9..ae774ace 100644
--- a/robo_manip_baselines/bin/Train.py
+++ b/robo_manip_baselines/bin/Train.py
@@ -3,6 +3,8 @@
 import re
 import sys
 
+import wandb
+
 
 def camel_to_snake(name):
     """Converts camelCase or PascalCase to snake_case (also converts the first letter to lowercase)"""
@@ -34,6 +36,12 @@ def main():
     parser.add_argument(
         "-h", "--help", action="store_true", help="Show this help message and continue"
     )
+    parser.add_argument(
+        "--sweep", action="store_true", help="Run sweep instead of normal training"
+    )
+    parser.add_argument(
+        "--sweep_count", type=int, default=10, help="Number of sweep runs"
+    )
     args, remaining_argv = parser.parse_known_args()
     sys.argv = [sys.argv[0]] + remaining_argv
 
@@ -50,9 +58,33 @@ def main():
     )
     TrainPolicyClass = getattr(policy_module, f"Train{args.policy}")
 
-    train = TrainPolicyClass()
-    train.run()
-    train.close()
+    if args.sweep:
+        print(f"[INFO] Running sweep for policy {args.policy}")
+
+        sweep_config = {
+            "method": "bayes",
+            "metric": {"name": "val_loss", "goal": "minimize"},
+            "parameters": {
+                "lr": {"min": 1e-6, "max": 5e-4},
+                "kl_weight": {"values": [1, 5, 10, 20]},
+                "chunk_size": {"values": [50, 100, 200]},
+                "hidden_dim": {"values": [256, 512, 1024]},
+            },
+        }
+
+        def sweep_train():
+            train = TrainPolicyClass()
+            train.run()
+            train.close()
+
+        sweep_id = wandb.sweep(sweep_config, project="robomanip-act")
+        wandb.agent(sweep_id, function=sweep_train, count=args.sweep_count)
+
+    else:
+        print(f"[INFO] Running normal training for policy {args.policy}")
+        train = TrainPolicyClass()
+        train.run()
+        train.close()
 
 
 if __name__ == "__main__":
diff --git a/robo_manip_baselines/common/base/TrainBase.py b/robo_manip_baselines/common/base/TrainBase.py
index 65455349..1c986eec 100644
--- a/robo_manip_baselines/common/base/TrainBase.py
+++ b/robo_manip_baselines/common/base/TrainBase.py
@@ -481,6 +481,7 @@ def save_current_ckpt(self, ckpt_suffix, policy=None):
 
         ckpt_path = os.path.join(self.args.checkpoint_dir, f"policy_{ckpt_suffix}.ckpt")
         torch.save(policy.state_dict(), ckpt_path)
+        return ckpt_path
 
     def save_best_ckpt(self):
         ckpt_path = os.path.join(self.args.checkpoint_dir, "policy_best.ckpt")
@@ -488,6 +489,7 @@ def save_best_ckpt(self):
         print(
             f"[{self.__class__.__name__}] Best val loss is {self.best_ckpt_info['loss']:.3f} at epoch {self.best_ckpt_info['epoch']}"
         )
+        return ckpt_path
 
     def get_total_memory_usage(self):
         process = psutil.Process()
diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py
index 78655d4c..5fc3c2fc 100644
--- a/robo_manip_baselines/policy/act/TrainAct.py
+++ b/robo_manip_baselines/policy/act/TrainAct.py
@@ -2,6 +2,7 @@
 import sys
 
 import torch
+import wandb
 from tqdm import tqdm
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../third_party/act"))
@@ -22,7 +23,7 @@ def set_additional_args(self, parser):
         parser.set_defaults(image_aug_std=0.1)
 
         parser.set_defaults(batch_size=8)
-        parser.set_defaults(num_epochs=1000)
+        parser.set_defaults(num_epochs=10)
         parser.set_defaults(lr=1e-5)
 
         parser.add_argument("--kl_weight", type=int, default=10, help="KL weight")
@@ -71,6 +72,39 @@ def setup_policy(self):
         print(f"  - chunk size: {self.args.chunk_size}")
 
     def train_loop(self):
+        # Update arguments from wandb.config when using W&B sweeps
+        if wandb.run is not None and hasattr(wandb, "config"):
+            config = wandb.config
+            self.args.lr = config.lr
+            self.args.kl_weight = config.kl_weight
+            self.args.chunk_size = config.chunk_size
+            self.args.hidden_dim = config.hidden_dim
+
+        # Experiment Tracking & Visualization
+        wandb.init(
+            project="robomanip-act",
+            config={
+                "learning_rate": self.args.lr,
+                "epochs": self.args.num_epochs,
+                "batch_size": self.args.batch_size,
+                "kl_weight": self.args.kl_weight,
+                "chunk_size": self.args.chunk_size,
+                "hidden_dim": self.args.hidden_dim,
+                "dim_feedforward": self.args.dim_feedforward,
+                "model": "ACTPolicy",
+                "dataset": getattr(self.args, "dataset_name", "Unknown"),
+                "device": torch.cuda.get_device_name(0)
+                if torch.cuda.is_available()
+                else "CPU",
+            },
+            tags=["experiment-tracking", "sweep-ready"],
+            notes="Training with wandb integration",
+        )
+
+        # Log gradients and parameters
+        wandb.watch(self.policy, log="all")
+
+        step = 0
         for epoch in tqdm(range(self.args.num_epochs)):
             # Run train step
             self.policy.train()
@@ -81,8 +115,22 @@ def train_loop(self):
                 loss = batch_result["loss"]
                 loss.backward()
                 self.optimizer.step()
-                batch_result_list.append(self.detach_batch_result(batch_result))
-            self.log_epoch_summary(batch_result_list, "train", epoch)
+                batch_result_detached = self.detach_batch_result(batch_result)
+                batch_result_list.append(batch_result_detached)
+
+                step += 1
+                wandb.log(
+                    {
+                        "step": step,
+                        "train_loss_step": loss.item(),
+                        "train_l1_step": batch_result_detached.get("l1", 0),
+                        "train_kl_step": batch_result_detached.get("kl", 0),
+                    }
+                )
+
+            train_epoch_summary = self.log_epoch_summary(
+                batch_result_list, "train", epoch
+            )
 
             # Run validation step
             with torch.inference_mode():
@@ -96,12 +144,31 @@ def train_loop(self):
                 # Update best checkpoint
                 self.update_best_ckpt(epoch_summary)
 
-            # Save current checkpoint
+            wandb.log(
+                {
+                    "epoch": epoch,
+                    "train_loss": train_epoch_summary.get("loss", 0),
+                    "train_l1": train_epoch_summary.get("l1", 0),
+                    "train_kl": train_epoch_summary.get("kl", 0),
+                    "val_loss": epoch_summary.get("loss", 0),
+                    "val_l1": epoch_summary.get("l1", 0),
+                    "val_kl": epoch_summary.get("kl", 0),
+                }
+            )
+
             if epoch % max(self.args.num_epochs // 10, 1) == 0:
-                self.save_current_ckpt(f"epoch{epoch:0>3}")
+                model_path = self.save_current_ckpt(f"epoch{epoch:0>3}")
+                if model_path:
+                    wandb.save(model_path)
+
+        # Save last model
+        last_ckpt_path = self.save_current_ckpt("last")
+        if last_ckpt_path:
+            wandb.save(last_ckpt_path)
 
-        # Save last checkpoint
-        self.save_current_ckpt("last")
+        # Save best model
+        best_ckpt_path = self.save_best_ckpt()
+        if best_ckpt_path:
+            wandb.save(best_ckpt_path)
 
-        # Save best checkpoint
-        self.save_best_ckpt()
+        wandb.finish()