Skip to content

Commit b402af3

Browse files
committed
Support manually set val_incides and train_indices
1 parent 713ef77 commit b402af3

File tree

6 files changed

+65
-7
lines changed

6 files changed

+65
-7
lines changed

configs/actorshq.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ scene_id = 0
1616
resolution = 4 # Resolution of the actorshq dataset (1, 2, 4)
1717
data_factor = 1 # Downsample factor for the dataset. ActorsHQ dataset is already downsampled with correct intrinsics.
1818
test_every = 8 # Every N images is a test image
19+
# val_indices = [0, 8, 16, 24] # Manual validation view indices (overrides test_every)
20+
# train_indices = [1, 2, 3, 5, 7] # Manual training view indices (overrides test_every)
21+
1922
# patch_size = null # Random crop size for training (experimental)
2023
# global_scale = 1.0 # Global scale factor for scene size
2124
# normalize_world_space = true # Normalize the world space

examples/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class Config:
5757
result_dir: str = "results/garden"
5858
# Every N images there is a test image
5959
test_every: int = 8
60+
# Manual validation view indices (overrides test_every)
61+
val_indices: Optional[List[int]] = None
62+
# Manual training view indices (overrides test_every)
63+
train_indices: Optional[List[int]] = None
6064
# Random crop size for training (experimental)
6165
patch_size: Optional[int] = None
6266
# A global scaler that applies to the scene size related parameters

examples/datasets/colmap.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__(
6868
factor: int = 1,
6969
normalize: bool = False,
7070
test_every: int = 8,
71+
val_indices: Optional[List[int]] = None,
72+
train_indices: Optional[List[int]] = None,
7173
):
7274
self.data_dir = data_dir
7375
self.factor = factor
@@ -271,11 +273,19 @@ def __init__(
271273
self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
272274
self.transform = transform # np.ndarray, (4, 4)
273275

274-
# Randomly select validation indices to avoid systematic bias in selecting validation cameras
276+
# Determine train/val split
275277
indices = np.arange(len(self.image_names))
276-
num_val = len(self.image_names) // self.test_every
277-
self.val_indices = np.random.choice(indices, size=num_val, replace=False)
278-
print(f"Randomly selected {num_val} validation cameras to avoid systematic bias")
278+
if val_indices is not None:
279+
self.val_indices = np.array(val_indices, dtype=np.int64)
280+
print(f"[Parser] Using {len(self.val_indices)} manual validation indices: {sorted(self.val_indices.tolist())}")
281+
elif train_indices is not None:
282+
train_set = set(train_indices)
283+
self.val_indices = np.array([i for i in indices if i not in train_set], dtype=np.int64)
284+
print(f"[Parser] Using {len(train_indices)} manual training indices; {len(self.val_indices)} views assigned to validation")
285+
else:
286+
num_val = len(self.image_names) // self.test_every
287+
self.val_indices = np.random.choice(indices, size=num_val, replace=False)
288+
print(f"[Parser] Randomly selected {num_val} validation cameras (test_every={self.test_every})")
279289

280290
# load one image to check the size. In the case of tanksandtemples dataset, the
281291
# intrinsics stored in COLMAP corresponds to 2x upsampled images.

examples/simple_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def __init__(
199199
factor=cfg.data_factor,
200200
normalize=cfg.normalize_world_space,
201201
test_every=cfg.test_every,
202+
val_indices=cfg.val_indices,
203+
train_indices=cfg.train_indices,
202204
)
203205
self.trainset = Dataset(
204206
self.parser,

scripts/batch_run.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import subprocess
99
import os
1010
import itertools
11-
from dataclasses import dataclass
12-
from typing import List, Tuple
11+
from dataclasses import dataclass, field
12+
from typing import List, Optional, Tuple
1313
from concurrent.futures import ProcessPoolExecutor, as_completed
1414

1515

@@ -26,14 +26,18 @@
2626
"frame_ids": [1],
2727
"resolution": 4,
2828
"config_path": "./configs/actorshq.toml",
29+
"test_every": 8,
2930
},
3031
"neural3d": {
3132
"base_data_dir": "/synology/Neural_3D_Video",
3233
"sequences": ["coffee_martini", "cook_spinach", "cut_roasted_beef",
33-
# "flame_salmon_1", "flame_steak", "sear_steak"
34+
"flame_salmon_1", "flame_steak", "sear_steak"
3435
],
3536
"frame_ids": [0],
3637
"config_path": "./configs/actorshq.toml",
38+
# "test_every": 8,
39+
"val_indices": [0, 20],
40+
"train_indices": [i for i in range(0, 21) if i not in [0, 20]],
3741
},
3842
}
3943
# =========================================================
@@ -50,6 +54,9 @@ class JobConfig:
5054
config_path: str = "./configs/actorshq.toml"
5155
run_script_path: str = ""
5256
root_run_path: str = ""
57+
test_every: Optional[int] = None
58+
val_indices: Optional[List[int]] = None
59+
train_indices: Optional[List[int]] = None
5360

5461

5562
def build_data_dir_actorshq(cfg: dict, actor: str, sequence: str, frame_id: int) -> str:
@@ -70,6 +77,9 @@ def create_jobs_actorshq(cfg: dict, method: str, cuda_devices: List[str],
7077
sequences = cfg.get("sequences", [])
7178
frame_ids = cfg.get("frame_ids", [])
7279
config_path = cfg.get("config_path", "./configs/actorshq.toml")
80+
test_every = cfg.get("test_every", None)
81+
val_indices = cfg.get("val_indices", None)
82+
train_indices = cfg.get("train_indices", None)
7383

7484
all_combinations = list(itertools.product(actors, sequences, frame_ids))
7585
num_gpus = len(cuda_devices)
@@ -91,6 +101,9 @@ def create_jobs_actorshq(cfg: dict, method: str, cuda_devices: List[str],
91101
config_path=config_path,
92102
run_script_path=run_script_path,
93103
root_run_path=root_run_path,
104+
test_every=test_every,
105+
val_indices=val_indices,
106+
train_indices=train_indices,
94107
)
95108
jobs.append(job)
96109

@@ -103,6 +116,9 @@ def create_jobs_neural3d(cfg: dict, method: str, cuda_devices: List[str],
103116
sequences = cfg.get("sequences", [])
104117
frame_ids = cfg.get("frame_ids", [])
105118
config_path = cfg.get("config_path", "./configs/actorshq.toml")
119+
test_every = cfg.get("test_every", None)
120+
val_indices = cfg.get("val_indices", None)
121+
train_indices = cfg.get("train_indices", None)
106122

107123
all_combinations = list(itertools.product(sequences, frame_ids))
108124
num_gpus = len(cuda_devices)
@@ -124,6 +140,9 @@ def create_jobs_neural3d(cfg: dict, method: str, cuda_devices: List[str],
124140
config_path=config_path,
125141
run_script_path=run_script_path,
126142
root_run_path=root_run_path,
143+
test_every=test_every,
144+
val_indices=val_indices,
145+
train_indices=train_indices,
127146
)
128147
jobs.append(job)
129148

@@ -158,6 +177,12 @@ def run_single_experiment(config: JobConfig) -> int:
158177
"--config", config.config_path,
159178
"--disable_viewer",
160179
]
180+
if config.test_every is not None:
181+
cmd.extend(["--test_every", str(config.test_every)])
182+
if config.val_indices is not None:
183+
cmd.extend(["--val_indices", ",".join(str(i) for i in config.val_indices)])
184+
if config.train_indices is not None:
185+
cmd.extend(["--train_indices", ",".join(str(i) for i in config.train_indices)])
161186

162187
result = subprocess.run(cmd, env=env, cwd=config.root_run_path)
163188
return result.returncode

scripts/run_actorshq.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ def parse_args():
2727
help="Path to config file")
2828
parser.add_argument("--disable_viewer", action="store_true",
2929
help="Disable the viewer")
30+
parser.add_argument("--test_every", type=int, default=None,
31+
help="Every N images is a test image (overrides config)")
32+
parser.add_argument("--val_indices", type=str, default=None,
33+
help="Comma-separated validation view indices (overrides test_every)")
34+
parser.add_argument("--train_indices", type=str, default=None,
35+
help="Comma-separated training view indices (overrides test_every)")
3036
return parser.parse_args()
3137

3238
def run_experiment(config: Config, dist=False):
@@ -149,6 +155,14 @@ def build_exp_name(cfg: Config, prefix: str = "actorshq") -> str:
149155
# Build experiment name
150156
exp_name = build_exp_name(cfg, args.exp_name_prefix)
151157

158+
# Override test_every / val/train indices if provided via CLI
159+
if args.test_every is not None:
160+
cfg.test_every = args.test_every
161+
if args.val_indices is not None:
162+
cfg.val_indices = [int(x) for x in args.val_indices.split(",")]
163+
if args.train_indices is not None:
164+
cfg.train_indices = [int(x) for x in args.train_indices.split(",")]
165+
152166
# Set viewer
153167
cfg.disable_viewer = args.disable_viewer
154168

0 commit comments

Comments
 (0)