88import subprocess
99import os
1010import itertools
11- from dataclasses import dataclass
12- from typing import List , Tuple
11+ from dataclasses import dataclass , field
12+ from typing import List , Optional , Tuple
1313from concurrent .futures import ProcessPoolExecutor , as_completed
1414
1515
2626 "frame_ids" : [1 ],
2727 "resolution" : 4 ,
2828 "config_path" : "./configs/actorshq.toml" ,
29+ "test_every" : 8 ,
2930 },
3031 "neural3d" : {
3132 "base_data_dir" : "/synology/Neural_3D_Video" ,
3233 "sequences" : ["coffee_martini" , "cook_spinach" , "cut_roasted_beef" ,
33- # "flame_salmon_1", "flame_steak", "sear_steak"
34+ "flame_salmon_1" , "flame_steak" , "sear_steak"
3435 ],
3536 "frame_ids" : [0 ],
3637 "config_path" : "./configs/actorshq.toml" ,
38+ # "test_every": 8,
39+ "val_indices" : [0 , 20 ],
40+ "train_indices" : [i for i in range (0 , 21 ) if i not in [0 , 20 ]],
3741 },
3842}
3943# =========================================================
@@ -50,6 +54,9 @@ class JobConfig:
5054 config_path : str = "./configs/actorshq.toml"
5155 run_script_path : str = ""
5256 root_run_path : str = ""
57+ test_every : Optional [int ] = None
58+ val_indices : Optional [List [int ]] = None
59+ train_indices : Optional [List [int ]] = None
5360
5461
5562def build_data_dir_actorshq (cfg : dict , actor : str , sequence : str , frame_id : int ) -> str :
@@ -70,6 +77,9 @@ def create_jobs_actorshq(cfg: dict, method: str, cuda_devices: List[str],
7077 sequences = cfg .get ("sequences" , [])
7178 frame_ids = cfg .get ("frame_ids" , [])
7279 config_path = cfg .get ("config_path" , "./configs/actorshq.toml" )
80+ test_every = cfg .get ("test_every" , None )
81+ val_indices = cfg .get ("val_indices" , None )
82+ train_indices = cfg .get ("train_indices" , None )
7383
7484 all_combinations = list (itertools .product (actors , sequences , frame_ids ))
7585 num_gpus = len (cuda_devices )
@@ -91,6 +101,9 @@ def create_jobs_actorshq(cfg: dict, method: str, cuda_devices: List[str],
91101 config_path = config_path ,
92102 run_script_path = run_script_path ,
93103 root_run_path = root_run_path ,
104+ test_every = test_every ,
105+ val_indices = val_indices ,
106+ train_indices = train_indices ,
94107 )
95108 jobs .append (job )
96109
@@ -103,6 +116,9 @@ def create_jobs_neural3d(cfg: dict, method: str, cuda_devices: List[str],
103116 sequences = cfg .get ("sequences" , [])
104117 frame_ids = cfg .get ("frame_ids" , [])
105118 config_path = cfg .get ("config_path" , "./configs/actorshq.toml" )
119+ test_every = cfg .get ("test_every" , None )
120+ val_indices = cfg .get ("val_indices" , None )
121+ train_indices = cfg .get ("train_indices" , None )
106122
107123 all_combinations = list (itertools .product (sequences , frame_ids ))
108124 num_gpus = len (cuda_devices )
@@ -124,6 +140,9 @@ def create_jobs_neural3d(cfg: dict, method: str, cuda_devices: List[str],
124140 config_path = config_path ,
125141 run_script_path = run_script_path ,
126142 root_run_path = root_run_path ,
143+ test_every = test_every ,
144+ val_indices = val_indices ,
145+ train_indices = train_indices ,
127146 )
128147 jobs .append (job )
129148
@@ -158,6 +177,12 @@ def run_single_experiment(config: JobConfig) -> int:
158177 "--config" , config .config_path ,
159178 "--disable_viewer" ,
160179 ]
180+ if config .test_every is not None :
181+ cmd .extend (["--test_every" , str (config .test_every )])
182+ if config .val_indices is not None :
183+ cmd .extend (["--val_indices" , "," .join (str (i ) for i in config .val_indices )])
184+ if config .train_indices is not None :
185+ cmd .extend (["--train_indices" , "," .join (str (i ) for i in config .train_indices )])
161186
162187 result = subprocess .run (cmd , env = env , cwd = config .root_run_path )
163188 return result .returncode
0 commit comments