24 changes: 21 additions & 3 deletions robo_manip_baselines/bin/Train.py
@@ -3,6 +3,8 @@
import re
import sys

import wandb


def camel_to_snake(name):
"""Converts camelCase or PascalCase to snake_case (also converts the first letter to lowercase)"""
@@ -34,6 +36,12 @@ def main():
parser.add_argument(
"-h", "--help", action="store_true", help="Show this help message and continue"
)
parser.add_argument(
"--sweep", action="store_true", help="Run sweep instead of normal training"
)
parser.add_argument(
"--sweep_count", type=int, default=10, help="Number of sweep runs"
)

args, remaining_argv = parser.parse_known_args()
sys.argv = [sys.argv[0]] + remaining_argv
@@ -50,9 +58,19 @@
)
TrainPolicyClass = getattr(policy_module, f"Train{args.policy}")

train = TrainPolicyClass()
train.run()
train.close()
if args.sweep:
print(f"[INFO] Running sweep for policy {args.policy}")
sweep_config = TrainPolicyClass.get_sweep_config()
sweep_train_fn = TrainPolicyClass.sweep_entrypoint()

sweep_id = wandb.sweep(sweep_config, project="robomanip-act")
wandb.agent(sweep_id, function=sweep_train_fn, count=args.sweep_count)

else:
print(f"[INFO] Running normal training for policy {args.policy}")
train = TrainPolicyClass()
train.run()
train.close()


if __name__ == "__main__":
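For reference, a hypothetical invocation of the new flags (this assumes the policy name is passed as a positional argument, which the diff does not show):

    # Normal training (unchanged default path):
    python bin/Train.py Act

    # W&B sweep with 20 runs instead of the default 10:
    python bin/Train.py Act --sweep --sweep_count 20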
2 changes: 2 additions & 0 deletions robo_manip_baselines/common/base/TrainBase.py
@@ -481,13 +481,15 @@ def save_current_ckpt(self, ckpt_suffix, policy=None):

ckpt_path = os.path.join(self.args.checkpoint_dir, f"policy_{ckpt_suffix}.ckpt")
torch.save(policy.state_dict(), ckpt_path)
return ckpt_path

def save_best_ckpt(self):
ckpt_path = os.path.join(self.args.checkpoint_dir, "policy_best.ckpt")
torch.save(self.best_ckpt_info["state_dict"], ckpt_path)
print(
f"[{self.__class__.__name__}] Best val loss is {self.best_ckpt_info['loss']:.3f} at epoch {self.best_ckpt_info['epoch']}"
)
return ckpt_path

def get_total_memory_usage(self):
process = psutil.Process()
111 changes: 101 additions & 10 deletions robo_manip_baselines/policy/act/TrainAct.py
@@ -2,6 +2,7 @@
import sys

import torch
import wandb
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../third_party/act"))
Expand Down Expand Up @@ -71,6 +72,39 @@ def setup_policy(self):
print(f" - chunk size: {self.args.chunk_size}")

def train_loop(self):
# Update arguments from wandb.config when using W&B sweeps
if wandb.run is not None and hasattr(wandb, "config"):
config = wandb.config
self.args.lr = config.lr
self.args.kl_weight = config.kl_weight
self.args.chunk_size = config.chunk_size
self.args.hidden_dim = config.hidden_dim

# Experiment Tracking & Visualization
wandb.init(
project="robomanip-act",
config={
"learning_rate": self.args.lr,
"epochs": self.args.num_epochs,
"batch_size": self.args.batch_size,
"kl_weight": self.args.kl_weight,
"chunk_size": self.args.chunk_size,
"hidden_dim": self.args.hidden_dim,
"dim_feedforward": self.args.dim_feedforward,
"model": "ACTPolicy",
"dataset": getattr(self.args, "dataset_name", "Unknown"),
"device": torch.cuda.get_device_name(0)
if torch.cuda.is_available()
else "CPU",
},
tags=["experiment-tracking", "sweep-ready"],
notes="Training with wandb integration",
)

# Log gradients and parameters
wandb.watch(self.policy, log="all")

step = 0
for epoch in tqdm(range(self.args.num_epochs)):
# Run train step
self.policy.train()
@@ -81,8 +115,22 @@ def train_loop(self):
loss = batch_result["loss"]
loss.backward()
self.optimizer.step()
batch_result_list.append(self.detach_batch_result(batch_result))
self.log_epoch_summary(batch_result_list, "train", epoch)
batch_result_detached = self.detach_batch_result(batch_result)
batch_result_list.append(batch_result_detached)

step += 1
wandb.log(
{
"step": step,
"train_loss_step": loss.item(),
"train_l1_step": batch_result_detached.get("l1", 0),
"train_kl_step": batch_result_detached.get("kl", 0),
}
)

train_epoch_summary = self.log_epoch_summary(
batch_result_list, "train", epoch
)

# Run validation step
with torch.inference_mode():
@@ -96,12 +144,55 @@
# Update best checkpoint
self.update_best_ckpt(epoch_summary)

# Save current checkpoint
if epoch % max(self.args.num_epochs // 10, 1) == 0:
self.save_current_ckpt(f"epoch{epoch:0>3}")
wandb.log(
{
"epoch": epoch,
"train_loss": train_epoch_summary.get("loss", 0),
"train_l1": train_epoch_summary.get("l1", 0),
"train_kl": train_epoch_summary.get("kl", 0),
"val_loss": epoch_summary.get("loss", 0),
"val_l1": epoch_summary.get("l1", 0),
"val_kl": epoch_summary.get("kl", 0),
}
)

# Save last checkpoint
self.save_current_ckpt("last")

# Save best checkpoint
self.save_best_ckpt()
if epoch % max(self.args.num_epochs // 10, 1) == 0:
model_path = self.save_current_ckpt(f"epoch{epoch:0>3}")
if model_path:
wandb.save(model_path)

# Save last model
last_ckpt_path = self.save_current_ckpt("last")
if last_ckpt_path:
wandb.save(last_ckpt_path)

# Save best model
best_ckpt_path = self.save_best_ckpt()
if best_ckpt_path:
wandb.save(best_ckpt_path)

wandb.finish()

# Sweep entrypoint
@classmethod
def sweep_entrypoint(cls):
[Review comment, Member]
Is it necessary to define this method in a per-policy class?
If it is sufficient to define it in the if args.sweep: block of Train.py, that would be simpler.

[Reply, Author]
Thanks! You're right; I've moved the sweep logic to Train.py's if args.sweep: block as suggested.

def sweep_train():
trainer = cls()
trainer.run()
trainer.close()

return sweep_train
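
For reference, a minimal sketch of the reviewer's suggested simplification, with the closure defined directly in Train.py's if args.sweep: block (names are taken from the diff above; this is an illustration, not the merged code):

    if args.sweep:
        print(f"[INFO] Running sweep for policy {args.policy}")

        # Define the sweep entrypoint inline instead of per policy class
        def sweep_train():
            trainer = TrainPolicyClass()
            trainer.run()
            trainer.close()

        sweep_id = wandb.sweep(TrainPolicyClass.get_sweep_config(), project="robomanip-act")
        wandb.agent(sweep_id, function=sweep_train, count=args.sweep_count)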

# Sweep config
@classmethod
def get_sweep_config(cls):
return {
"method": "bayes",
"metric": {"name": "val_loss", "goal": "minimize"},
"parameters": {
"lr": {"min": 1e-6, "max": 5e-4},
"kl_weight": {"values": [1, 5, 10, 20]},
"chunk_size": {"values": [50, 100, 200]},
"hidden_dim": {"values": [256, 512, 1024]},
},
}
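
A possible refinement, not part of this PR: for a Bayesian sweep over a learning rate spanning more than two orders of magnitude, W&B supports log-scale sampling via the distribution key, e.g.:

    # Hypothetical alternative to the uniform lr range above:
    "lr": {"distribution": "log_uniform_values", "min": 1e-6, "max": 5e-4},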