-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtrain_libero_policy_diff_action.py
More file actions
executable file
·52 lines (41 loc) · 2.26 KB
/
train_libero_policy_diff_action.py
File metadata and controls
executable file
·52 lines (41 loc) · 2.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import os
import subprocess
import sys
# Environment shared with the training command spawned at the end of this
# script (child processes inherit os.environ). The variable is presumably read
# by the HuggingFace `tokenizers` library in the training process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Default trained track-transformer checkpoint directory per suite. Suites that
# are accepted by --suite but absent here (libero_base, libero_complex) need an
# explicit --track-transformer path.
# NOTE(review): libero_10 points at the "libero-100" checkpoint, presumably a
# transformer trained on the combined LIBERO-100 data -- confirm intended.
_TRACK_TRANSFORMER_DIR = "./results/track_transformer"
DEFAULT_TRACK_TRANSFORMERS = {
    suite: f"{_TRACK_TRANSFORMER_DIR}/libero_track_transformer_{tag}/"
    for suite, tag in (
        ("libero_spatial", "libero-spatial"),
        ("libero_object", "libero-object"),
        ("libero_goal", "libero-goal"),
        ("libero_10", "libero-100"),
    )
}
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Launch engine.train_bc_diff_action on every task of a LIBERO suite.")
parser.add_argument("--suite", default="libero_10",
                    choices=['libero_base', "libero_complex", "libero_spatial",
                             "libero_object", "libero_goal", "libero_10"],
                    help="The name of the desired suite, where libero_10 is the alias of libero_long.")
parser.add_argument("-tt", "--track-transformer", default=None,
                    help="The path to the trained track transformer. Required for suites "
                         "that have no entry in DEFAULT_TRACK_TRANSFORMERS.")
parser.add_argument("--config", default='libero_vilt_dino_siglip_wm',
                    help="The config name of yaml.")
parser.add_argument("--root-dir", default="/home/huang/code/ATM/data/atm_libero/",
                    help="Root data directory; must contain one sub-directory per suite, "
                         "each holding one sub-directory per task.")
args = parser.parse_args()

# Resolve the track transformer before touching the filesystem, so a suite
# without a default (e.g. libero_base) fails with a usage message rather than
# a bare KeyError.
track_fn = args.track_transformer or DEFAULT_TRACK_TRANSFORMERS.get(args.suite)
if track_fn is None:
    parser.error(f"no default track transformer for suite '{args.suite}'; "
                 f"pass one explicitly with -tt/--track-transformer")

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
CONFIG_NAME = args.config
train_gpu_ids = [0, ]
NUM_DEMOS = 10
root_dir = args.root_dir
suite_name = args.suite
suite_dir = os.path.join(root_dir, suite_name)
if not os.path.isdir(suite_dir):
    parser.error(f"suite directory not found: {suite_dir}")
# Every sub-directory of the suite directory is one task; stray files are
# skipped because they cannot hold bc_train_*/val datasets. Sorted for a
# deterministic task order.
task_dir_list = sorted(
    entry for entry in os.listdir(suite_dir)
    if os.path.isdir(os.path.join(suite_dir, entry))
)
if not task_dir_list:
    parser.error(f"no task directories found under {suite_dir}")
# NOTE(review): the experiment name is fixed and does not reflect --config;
# runs of the same suite with different configs get the same name -- confirm.
exp_name = 'libero_vilt_dino_siglip_wm_policy'

# Per-task dataset directories, index-aligned with task_dir_list.
train_path_list = [os.path.join(suite_dir, task_dir, f"bc_train_{NUM_DEMOS}")
                   for task_dir in task_dir_list]
val_path_list = [os.path.join(suite_dir, task_dir, "val") for task_dir in task_dir_list]

# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
for seed in range(1):
    # Run without a shell: an argv list needs no quoting, so paths containing
    # spaces or shell metacharacters reach the trainer intact. List-valued
    # overrides are rendered with Python's list repr (e.g. ['a', 'b']), which
    # the trainer's override parser (presumably Hydra, given --config-name) is
    # expected to accept -- confirm if the override syntax changes.
    # sys.executable runs the trainer under the same interpreter/environment
    # as this launcher.
    command = [
        sys.executable, "-m", "engine.train_bc_diff_action",
        f"--config-name={CONFIG_NAME}",
        f"train_gpus={train_gpu_ids}",
        f"experiment={suite_name.replace('_', '-')}_demo{NUM_DEMOS}_{exp_name}",
        f"train_dataset={train_path_list}",
        f"val_dataset={val_path_list}",
        f"model_cfg.track_cfg.track_fn={track_fn}",
        "model_cfg.track_cfg.use_zero_track=False",
        "model_cfg.spatial_transformer_cfg.use_language_token=False",
        "model_cfg.temporal_transformer_cfg.use_language_token=False",
        f"seed={seed}",
    ]
    # check=True: a failed training run makes this launcher exit non-zero
    # (raises CalledProcessError) instead of being silently ignored.
    subprocess.run(command, check=True)