38 changes: 35 additions & 3 deletions robo_manip_baselines/bin/Train.py
@@ -3,6 +3,8 @@
import re
import sys

import wandb


def camel_to_snake(name):
"""Converts camelCase or PascalCase to snake_case (also converts the first letter to lowercase)"""
@@ -34,6 +36,12 @@ def main():
parser.add_argument(
"-h", "--help", action="store_true", help="Show this help message and continue"
)
parser.add_argument(
"--sweep", action="store_true", help="Run sweep instead of normal training"
)
parser.add_argument(
"--sweep_count", type=int, default=10, help="Number of sweep runs"
)

args, remaining_argv = parser.parse_known_args()
sys.argv = [sys.argv[0]] + remaining_argv
@@ -50,9 +58,33 @@ def main():
)
TrainPolicyClass = getattr(policy_module, f"Train{args.policy}")

train = TrainPolicyClass()
train.run()
train.close()
if args.sweep:
print(f"[INFO] Running sweep for policy {args.policy}")

sweep_config = {
"method": "bayes",
"metric": {"name": "val_loss", "goal": "minimize"},
"parameters": {
"lr": {"min": 1e-6, "max": 5e-4},
"kl_weight": {"values": [1, 5, 10, 20]},
"chunk_size": {"values": [50, 100, 200]},
"hidden_dim": {"values": [256, 512, 1024]},
},
}

def sweep_train():
train = TrainPolicyClass()
train.run()
train.close()

sweep_id = wandb.sweep(sweep_config, project="robomanip-act")
wandb.agent(sweep_id, function=sweep_train, count=args.sweep_count)

else:
print(f"[INFO] Running normal training for policy {args.policy}")
train = TrainPolicyClass()
train.run()
train.close()


if __name__ == "__main__":
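For reference, a minimal invocation sketch of the two modes added above ("Act" shown as the example policy and passed positionally, as assumed from the existing parser outside this diff; the bracketed placeholder stands for the script's other, unchanged training options):

    # Normal training (unchanged behavior)
    python ./robo_manip_baselines/bin/Train.py Act [existing training options]

    # Bayesian hyperparameter sweep: 20 W&B trials of the same entry point
    python ./robo_manip_baselines/bin/Train.py Act [existing training options] --sweep --sweep_count 20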
2 changes: 2 additions & 0 deletions robo_manip_baselines/common/base/TrainBase.py
@@ -481,13 +481,15 @@ def save_current_ckpt(self, ckpt_suffix, policy=None):

ckpt_path = os.path.join(self.args.checkpoint_dir, f"policy_{ckpt_suffix}.ckpt")
torch.save(policy.state_dict(), ckpt_path)
return ckpt_path

def save_best_ckpt(self):
ckpt_path = os.path.join(self.args.checkpoint_dir, "policy_best.ckpt")
torch.save(self.best_ckpt_info["state_dict"], ckpt_path)
print(
f"[{self.__class__.__name__}] Best val loss is {self.best_ckpt_info['loss']:.3f} at epoch {self.best_ckpt_info['epoch']}"
)
return ckpt_path

def get_total_memory_usage(self):
process = psutil.Process()
85 changes: 76 additions & 9 deletions robo_manip_baselines/policy/act/TrainAct.py
@@ -2,6 +2,7 @@
import sys

import torch
import wandb
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../third_party/act"))
@@ -22,7 +23,7 @@ def set_additional_args(self, parser):
parser.set_defaults(image_aug_std=0.1)

parser.set_defaults(batch_size=8)
parser.set_defaults(num_epochs=1000)
parser.set_defaults(num_epochs=10)
parser.set_defaults(lr=1e-5)

parser.add_argument("--kl_weight", type=int, default=10, help="KL weight")
@@ -71,6 +72,39 @@ def setup_policy(self):
print(f" - chunk size: {self.args.chunk_size}")

def train_loop(self):
# Update arguments from wandb.config when using W&B sweeps
if wandb.run is not None and hasattr(wandb, "config"):
config = wandb.config
self.args.lr = config.lr
self.args.kl_weight = config.kl_weight
self.args.chunk_size = config.chunk_size
self.args.hidden_dim = config.hidden_dim

# Experiment Tracking & Visualization
wandb.init(
project="robomanip-act",
config={
"learning_rate": self.args.lr,
"epochs": self.args.num_epochs,
"batch_size": self.args.batch_size,
"kl_weight": self.args.kl_weight,
"chunk_size": self.args.chunk_size,
"hidden_dim": self.args.hidden_dim,
"dim_feedforward": self.args.dim_feedforward,
"model": "ACTPolicy",
"dataset": getattr(self.args, "dataset_name", "Unknown"),
"device": torch.cuda.get_device_name(0)
if torch.cuda.is_available()
else "CPU",
},
tags=["experiment-tracking", "sweep-ready"],
notes="Training with wandb integration",
)

# Log gradients and parameters
wandb.watch(self.policy, log="all")

step = 0
for epoch in tqdm(range(self.args.num_epochs)):
# Run train step
self.policy.train()
@@ -81,8 +115,22 @@ def train_loop(self):
loss = batch_result["loss"]
loss.backward()
self.optimizer.step()
batch_result_list.append(self.detach_batch_result(batch_result))
self.log_epoch_summary(batch_result_list, "train", epoch)
batch_result_detached = self.detach_batch_result(batch_result)
batch_result_list.append(batch_result_detached)

step += 1
wandb.log(
{
"step": step,
"train_loss_step": loss.item(),
"train_l1_step": batch_result_detached.get("l1", 0),
"train_kl_step": batch_result_detached.get("kl", 0),
}
)

train_epoch_summary = self.log_epoch_summary(
batch_result_list, "train", epoch
)

# Run validation step
with torch.inference_mode():
@@ -96,12 +144,31 @@ def train_loop(self):
# Update best checkpoint
self.update_best_ckpt(epoch_summary)

# Save current checkpoint
wandb.log(
{
"epoch": epoch,
"train_loss": train_epoch_summary.get("loss", 0),
"train_l1": train_epoch_summary.get("l1", 0),
"train_kl": train_epoch_summary.get("kl", 0),
"val_loss": epoch_summary.get("loss", 0),
"val_l1": epoch_summary.get("l1", 0),
"val_kl": epoch_summary.get("kl", 0),
}
)

if epoch % max(self.args.num_epochs // 10, 1) == 0:
self.save_current_ckpt(f"epoch{epoch:0>3}")
model_path = self.save_current_ckpt(f"epoch{epoch:0>3}")
if model_path:
wandb.save(model_path)

# Save last model
last_ckpt_path = self.save_current_ckpt("last")
if last_ckpt_path:
wandb.save(last_ckpt_path)

# Save last checkpoint
self.save_current_ckpt("last")
# Save best model
best_ckpt_path = self.save_best_ckpt()
if best_ckpt_path:
wandb.save(best_ckpt_path)

# Save best checkpoint
self.save_best_ckpt()
wandb.finish()
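For readers unfamiliar with the W&B mechanism this diff relies on, a minimal self-contained sketch follows (not part of the change; the project name, trial count, and placeholder metric are illustrative): wandb.agent() creates one run per trial, the wandb.init() call inside the training code then attaches to that run, and wandb.config carries the hyperparameters chosen by the sweep controller.

    import wandb

    # Sweep definition: same shape as the dict in Train.py, reduced to one parameter
    sweep_config = {
        "method": "bayes",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {"lr": {"min": 1e-6, "max": 5e-4}},
    }

    def trial():
        run = wandb.init()  # attaches to the run the agent created for this trial
        lr = wandb.config.lr  # value sampled by the sweep controller
        val_loss = (lr - 1e-4) ** 2  # placeholder metric, stands in for real validation
        wandb.log({"val_loss": val_loss})  # the sweep's optimization target
        run.finish()

    sweep_id = wandb.sweep(sweep_config, project="robomanip-act")
    wandb.agent(sweep_id, function=trial, count=2)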