From d310d2c43849bb497a77e44ca43ee2791bd22b53 Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Fri, 20 Jun 2025 16:18:39 +0900 Subject: [PATCH 1/7] integrate wandb --- robo_manip_baselines/bin/Train.py | 23 +++- robo_manip_baselines/common/base/TrainBase.py | 2 + robo_manip_baselines/policy/act/TrainAct.py | 102 ++++++++++++++++-- 3 files changed, 113 insertions(+), 14 deletions(-) diff --git a/robo_manip_baselines/bin/Train.py b/robo_manip_baselines/bin/Train.py index 2ac0f6b9..aad1ea0e 100644 --- a/robo_manip_baselines/bin/Train.py +++ b/robo_manip_baselines/bin/Train.py @@ -2,7 +2,7 @@ import importlib import re import sys - +import wandb def camel_to_snake(name): """Converts camelCase or PascalCase to snake_case (also converts the first letter to lowercase)""" @@ -34,6 +34,12 @@ def main(): parser.add_argument( "-h", "--help", action="store_true", help="Show this help message and continue" ) + parser.add_argument( + "--sweep", action="store_true", help="Run sweep instead of normal training" + ) + parser.add_argument( + "--sweep_count", type=int, default=10, help="Number of sweep runs" + ) args, remaining_argv = parser.parse_known_args() sys.argv = [sys.argv[0]] + remaining_argv @@ -50,10 +56,19 @@ def main(): ) TrainPolicyClass = getattr(policy_module, f"Train{args.policy}") - train = TrainPolicyClass() - train.run() - train.close() + if args.sweep: + print(f"[INFO] Running sweep for policy {args.policy}") + sweep_config = TrainPolicyClass.get_sweep_config() + sweep_train_fn = TrainPolicyClass.sweep_entrypoint() + + sweep_id = wandb.sweep(sweep_config, project="robomanip-act") + wandb.agent(sweep_id, function=sweep_train_fn, count=args.sweep_count) + else: + print(f"[INFO] Running normal training for policy {args.policy}") + train = TrainPolicyClass() + train.run() + train.close() if __name__ == "__main__": main() diff --git a/robo_manip_baselines/common/base/TrainBase.py b/robo_manip_baselines/common/base/TrainBase.py index 65455349..1c986eec 100644 --- a/robo_manip_baselines/common/base/TrainBase.py +++ b/robo_manip_baselines/common/base/TrainBase.py @@ -481,6 +481,7 @@ def save_current_ckpt(self, ckpt_suffix, policy=None): ckpt_path = os.path.join(self.args.checkpoint_dir, f"policy_{ckpt_suffix}.ckpt") torch.save(policy.state_dict(), ckpt_path) + return ckpt_path def save_best_ckpt(self): ckpt_path = os.path.join(self.args.checkpoint_dir, "policy_best.ckpt") @@ -488,6 +489,7 @@ def save_best_ckpt(self): print( f"[{self.__class__.__name__}] Best val loss is {self.best_ckpt_info['loss']:.3f} at epoch {self.best_ckpt_info['epoch']}" ) + return ckpt_path def get_total_memory_usage(self): process = psutil.Process() diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py index 78655d4c..33982961 100644 --- a/robo_manip_baselines/policy/act/TrainAct.py +++ b/robo_manip_baselines/policy/act/TrainAct.py @@ -3,6 +3,7 @@ import torch from tqdm import tqdm +import wandb sys.path.append(os.path.join(os.path.dirname(__file__), "../../../third_party/act")) from detr.models.detr_vae import DETRVAE @@ -71,6 +72,37 @@ def setup_policy(self): print(f" - chunk size: {self.args.chunk_size}") def train_loop(self): + # Update args nếu có wandb.config (sweep mode) + if wandb.run is not None and hasattr(wandb, "config"): + config = wandb.config + self.args.lr = config.lr + self.args.kl_weight = config.kl_weight + self.args.chunk_size = config.chunk_size + self.args.hidden_dim = config.hidden_dim + + # Experiment Tracking & Visualization + wandb.init( + 
project="robomanip-act", + config={ + "learning_rate": self.args.lr, + "epochs": self.args.num_epochs, + "batch_size": self.args.batch_size, + "kl_weight": self.args.kl_weight, + "chunk_size": self.args.chunk_size, + "hidden_dim": self.args.hidden_dim, + "dim_feedforward": self.args.dim_feedforward, + "model": "ACTPolicy", + "dataset": getattr(self.args, "dataset_name", "Unknown"), + "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU", + }, + tags=["experiment-tracking", "sweep-ready"], + notes="Training with wandb integration", + ) + + # Log gradients and parameters + wandb.watch(self.policy, log="all") + + step = 0 for epoch in tqdm(range(self.args.num_epochs)): # Run train step self.policy.train() @@ -81,8 +113,18 @@ def train_loop(self): loss = batch_result["loss"] loss.backward() self.optimizer.step() - batch_result_list.append(self.detach_batch_result(batch_result)) - self.log_epoch_summary(batch_result_list, "train", epoch) + batch_result_detached = self.detach_batch_result(batch_result) + batch_result_list.append(batch_result_detached) + + step += 1 + wandb.log({ + "step": step, + "train_loss_step": loss.item(), + "train_l1_step": batch_result_detached.get("l1", 0), + "train_kl_step": batch_result_detached.get("kl", 0), + }) + + train_epoch_summary = self.log_epoch_summary(batch_result_list, "train", epoch) # Run validation step with torch.inference_mode(): @@ -96,12 +138,52 @@ def train_loop(self): # Update best checkpoint self.update_best_ckpt(epoch_summary) - # Save current checkpoint - if epoch % max(self.args.num_epochs // 10, 1) == 0: - self.save_current_ckpt(f"epoch{epoch:0>3}") - - # Save last checkpoint - self.save_current_ckpt("last") + wandb.log({ + "epoch": epoch, + "train_loss": train_epoch_summary.get("loss", 0), + "train_l1": train_epoch_summary.get("l1", 0), + "train_kl": train_epoch_summary.get("kl", 0), + "val_loss": epoch_summary.get("loss", 0), + "val_l1": epoch_summary.get("l1", 0), + "val_kl": epoch_summary.get("kl", 0), + }) - # Save best checkpoint - self.save_best_ckpt() + if epoch % max(self.args.num_epochs // 10, 1) == 0: + model_path = self.save_current_ckpt(f"epoch{epoch:0>3}") + if model_path: + wandb.save(model_path) + + # Save last model + last_ckpt_path = self.save_current_ckpt("last") + if last_ckpt_path: + wandb.save(last_ckpt_path) + + # Save best model + best_ckpt_path = self.save_best_ckpt() + if best_ckpt_path: + wandb.save(best_ckpt_path) + + wandb.finish() + + # Sweep entrypoint + @classmethod + def sweep_entrypoint(cls): + def sweep_train(): + trainer = cls() + trainer.run() + trainer.close() + return sweep_train + + # Sweep config + @classmethod + def get_sweep_config(cls): + return { + "method": "bayes", + "metric": {"name": "val_loss", "goal": "minimize"}, + "parameters": { + "lr": {"min": 1e-6, "max": 5e-4}, + "kl_weight": {"values": [1, 5, 10, 20]}, + "chunk_size": {"values": [50, 100, 200]}, + "hidden_dim": {"values": [256, 512, 1024]}, + }, + } From 5887532a812502b087416403c6519752bd1e0b72 Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Fri, 20 Jun 2025 16:25:35 +0900 Subject: [PATCH 2/7] fix act import library --- third_party/act | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/act b/third_party/act index ad329e43..5902c412 160000 --- a/third_party/act +++ b/third_party/act @@ -1 +1 @@ -Subproject commit ad329e433dfaefb9579d24f3358bbc7c5d073898 +Subproject commit 5902c4123d3a43aed32eb0a177116a1ed2bcd864 From c9982324f8edaca3a4ca1b2f7036f8c16ceb664c Mon Sep 17 00:00:00 2001 
From: nqtabokado Date: Fri, 20 Jun 2025 17:00:53 +0900 Subject: [PATCH 3/7] Revert "fix act import library" This reverts commit 5887532a812502b087416403c6519752bd1e0b72. --- third_party/act | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/act b/third_party/act index 5902c412..ad329e43 160000 --- a/third_party/act +++ b/third_party/act @@ -1 +1 @@ -Subproject commit 5902c4123d3a43aed32eb0a177116a1ed2bcd864 +Subproject commit ad329e433dfaefb9579d24f3358bbc7c5d073898 From 8dea819b838dd9625c688b31f06b6e06082382e2 Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Fri, 20 Jun 2025 17:31:22 +0900 Subject: [PATCH 4/7] fix format by pre-commit --- robo_manip_baselines/bin/Train.py | 3 ++ robo_manip_baselines/policy/act/TrainAct.py | 47 ++++++++++++--------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/robo_manip_baselines/bin/Train.py b/robo_manip_baselines/bin/Train.py index aad1ea0e..e46d85cc 100644 --- a/robo_manip_baselines/bin/Train.py +++ b/robo_manip_baselines/bin/Train.py @@ -2,8 +2,10 @@ import importlib import re import sys + import wandb + def camel_to_snake(name): """Converts camelCase or PascalCase to snake_case (also converts the first letter to lowercase)""" name = re.sub( @@ -70,5 +72,6 @@ def main(): train.run() train.close() + if __name__ == "__main__": main() diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py index 33982961..c7ed669b 100644 --- a/robo_manip_baselines/policy/act/TrainAct.py +++ b/robo_manip_baselines/policy/act/TrainAct.py @@ -2,8 +2,8 @@ import sys import torch -from tqdm import tqdm import wandb +from tqdm import tqdm sys.path.append(os.path.join(os.path.dirname(__file__), "../../../third_party/act")) from detr.models.detr_vae import DETRVAE @@ -93,7 +93,9 @@ def train_loop(self): "dim_feedforward": self.args.dim_feedforward, "model": "ACTPolicy", "dataset": getattr(self.args, "dataset_name", "Unknown"), - "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU", + "device": torch.cuda.get_device_name(0) + if torch.cuda.is_available() + else "CPU", }, tags=["experiment-tracking", "sweep-ready"], notes="Training with wandb integration", @@ -117,14 +119,18 @@ def train_loop(self): batch_result_list.append(batch_result_detached) step += 1 - wandb.log({ - "step": step, - "train_loss_step": loss.item(), - "train_l1_step": batch_result_detached.get("l1", 0), - "train_kl_step": batch_result_detached.get("kl", 0), - }) - - train_epoch_summary = self.log_epoch_summary(batch_result_list, "train", epoch) + wandb.log( + { + "step": step, + "train_loss_step": loss.item(), + "train_l1_step": batch_result_detached.get("l1", 0), + "train_kl_step": batch_result_detached.get("kl", 0), + } + ) + + train_epoch_summary = self.log_epoch_summary( + batch_result_list, "train", epoch + ) # Run validation step with torch.inference_mode(): @@ -138,15 +144,17 @@ def train_loop(self): # Update best checkpoint self.update_best_ckpt(epoch_summary) - wandb.log({ - "epoch": epoch, - "train_loss": train_epoch_summary.get("loss", 0), - "train_l1": train_epoch_summary.get("l1", 0), - "train_kl": train_epoch_summary.get("kl", 0), - "val_loss": epoch_summary.get("loss", 0), - "val_l1": epoch_summary.get("l1", 0), - "val_kl": epoch_summary.get("kl", 0), - }) + wandb.log( + { + "epoch": epoch, + "train_loss": train_epoch_summary.get("loss", 0), + "train_l1": train_epoch_summary.get("l1", 0), + "train_kl": train_epoch_summary.get("kl", 0), + "val_loss": epoch_summary.get("loss", 0), 
+ "val_l1": epoch_summary.get("l1", 0), + "val_kl": epoch_summary.get("kl", 0), + } + ) if epoch % max(self.args.num_epochs // 10, 1) == 0: model_path = self.save_current_ckpt(f"epoch{epoch:0>3}") @@ -172,6 +180,7 @@ def sweep_train(): trainer = cls() trainer.run() trainer.close() + return sweep_train # Sweep config From 92ebaf3a72623fb82c1de26947003700fcc83e3d Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Mon, 23 Jun 2025 11:55:21 +0900 Subject: [PATCH 5/7] fix comment --- robo_manip_baselines/policy/act/TrainAct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py index c7ed669b..08b47358 100644 --- a/robo_manip_baselines/policy/act/TrainAct.py +++ b/robo_manip_baselines/policy/act/TrainAct.py @@ -72,7 +72,7 @@ def setup_policy(self): print(f" - chunk size: {self.args.chunk_size}") def train_loop(self): - # Update args nếu có wandb.config (sweep mode) + # Update arguments from wandb.config when using W&B sweeps if wandb.run is not None and hasattr(wandb, "config"): config = wandb.config self.args.lr = config.lr From 2505e3a813f111255cc2aea3ea06b4d857ffc7cd Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Wed, 25 Jun 2025 23:29:50 +0900 Subject: [PATCH 6/7] move sweep to Train.py --- robo_manip_baselines/bin/Train.py | 20 ++++++++++++--- robo_manip_baselines/policy/act/TrainAct.py | 28 ++------------------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/robo_manip_baselines/bin/Train.py b/robo_manip_baselines/bin/Train.py index e46d85cc..ae774ace 100644 --- a/robo_manip_baselines/bin/Train.py +++ b/robo_manip_baselines/bin/Train.py @@ -60,11 +60,25 @@ def main(): if args.sweep: print(f"[INFO] Running sweep for policy {args.policy}") - sweep_config = TrainPolicyClass.get_sweep_config() - sweep_train_fn = TrainPolicyClass.sweep_entrypoint() + + sweep_config = { + "method": "bayes", + "metric": {"name": "val_loss", "goal": "minimize"}, + "parameters": { + "lr": {"min": 1e-6, "max": 5e-4}, + "kl_weight": {"values": [1, 5, 10, 20]}, + "chunk_size": {"values": [50, 100, 200]}, + "hidden_dim": {"values": [256, 512, 1024]}, + }, + } + + def sweep_train(): + train = TrainPolicyClass() + train.run() + train.close() sweep_id = wandb.sweep(sweep_config, project="robomanip-act") - wandb.agent(sweep_id, function=sweep_train_fn, count=args.sweep_count) + wandb.agent(sweep_id, function=sweep_train, count=args.sweep_count) else: print(f"[INFO] Running normal training for policy {args.policy}") diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py index 08b47358..0920c175 100644 --- a/robo_manip_baselines/policy/act/TrainAct.py +++ b/robo_manip_baselines/policy/act/TrainAct.py @@ -23,7 +23,7 @@ def set_additional_args(self, parser): parser.set_defaults(image_aug_std=0.1) parser.set_defaults(batch_size=8) - parser.set_defaults(num_epochs=1000) + parser.set_defaults(num_epochs=10) parser.set_defaults(lr=1e-5) parser.add_argument("--kl_weight", type=int, default=10, help="KL weight") @@ -171,28 +171,4 @@ def train_loop(self): if best_ckpt_path: wandb.save(best_ckpt_path) - wandb.finish() - - # Sweep entrypoint - @classmethod - def sweep_entrypoint(cls): - def sweep_train(): - trainer = cls() - trainer.run() - trainer.close() - - return sweep_train - - # Sweep config - @classmethod - def get_sweep_config(cls): - return { - "method": "bayes", - "metric": {"name": "val_loss", "goal": "minimize"}, - "parameters": { - "lr": {"min": 1e-6, "max": 
5e-4}, - "kl_weight": {"values": [1, 5, 10, 20]}, - "chunk_size": {"values": [50, 100, 200]}, - "hidden_dim": {"values": [256, 512, 1024]}, - }, - } + wandb.finish() \ No newline at end of file From c89d923653ff26ef374570a355081aadac0950ff Mon Sep 17 00:00:00 2001 From: nqtabokado Date: Wed, 25 Jun 2025 23:35:32 +0900 Subject: [PATCH 7/7] fix pre-commit --- robo_manip_baselines/policy/act/TrainAct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robo_manip_baselines/policy/act/TrainAct.py b/robo_manip_baselines/policy/act/TrainAct.py index 0920c175..5fc3c2fc 100644 --- a/robo_manip_baselines/policy/act/TrainAct.py +++ b/robo_manip_baselines/policy/act/TrainAct.py @@ -171,4 +171,4 @@ def train_loop(self): if best_ckpt_path: wandb.save(best_ckpt_path) - wandb.finish() \ No newline at end of file + wandb.finish()
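
Usage sketch for the sweep mode introduced in PATCH 1/7 and consolidated into
bin/Train.py in PATCH 6/7. Only the --sweep and --sweep_count flags come from
these diffs; how the policy name is passed to Train.py (shown here as a
positional "Act" argument, matching the Train{args.policy} lookup) and any
dataset or checkpoint options are assumptions, since that part of the argument
parser is outside the patched hunks:

    # launch a W&B Bayesian-optimization sweep over lr, kl_weight, chunk_size,
    # and hidden_dim, minimizing val_loss; 20 trials instead of the default 10
    python robo_manip_baselines/bin/Train.py Act --sweep --sweep_count 20

Each trial started by wandb.agent re-instantiates the Train<Policy> class, and
train_loop in TrainAct.py copies lr, kl_weight, chunk_size, and hidden_dim from
wandb.config into self.args before logging the per-step and per-epoch metrics.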