Skip to content

Commit d82742e

Browse files
committed
misc: update the scripts for exporting ply files
1 parent efb8bad commit d82742e

File tree

3 files changed

+86
-40
lines changed

3 files changed

+86
-40
lines changed

examples/benchmarks/dyngs/dyngs.sh

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ TEST_VIEWS=(
77
["Bartender"]="8 10 12"
88
)
99

10+
declare -A START_FRAMES
11+
START_FRAMES=(
12+
["CBA"]=0
13+
["Bartender"]=50
14+
)
15+
1016
RESULT_DIR="results/dyngs"
1117

1218
NUM_FRAME=65
@@ -15,28 +21,29 @@ run_single_scene() {
1521
local GPU_ID=$1
1622
local SCENE=$2
1723
local TEST_VIEW_IDS=${TEST_VIEWS[$SCENE]}
24+
local START_FRAME=${START_FRAMES[$SCENE]}
1825

19-
echo "Running $SCENE"
26+
echo "Running $SCENE START_FRAME @ ${START_FRAME}"
2027

21-
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer_dyngs.py compression_sim \
22-
--model_path $RESULT_DIR/$SCENE/ \
23-
--data_dir $SCENE_DIR/$SCENE/colmap/colmap_50 \
24-
--result_dir $RESULT_DIR/$SCENE/ \
25-
--downscale_factor 1 \
26-
--duration $NUM_FRAME \
27-
--batch_size 2 \
28-
--max_steps 60_000 \
29-
--refine_start_iter 3_000 \
30-
--refine_stop_iter 30_000 \
31-
--refine_every 100 \
32-
--reset_every 6_000 \
33-
--pause_refine_after_reset 500 \
34-
--strategy Modified_STG_Strategy \
35-
--test_view_id $TEST_VIEW_IDS
28+
# CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer_dyngs.py compression_sim \
29+
# --model_path $RESULT_DIR/$SCENE/ \
30+
# --data_dir $SCENE_DIR/$SCENE/colmap/colmap_${START_FRAME} \
31+
# --result_dir $RESULT_DIR/$SCENE/ \
32+
# --downscale_factor 1 \
33+
# --duration $NUM_FRAME \
34+
# --batch_size 2 \
35+
# --max_steps 60_000 \
36+
# --refine_start_iter 3_000 \
37+
# --refine_stop_iter 30_000 \
38+
# --refine_every 100 \
39+
# --reset_every 6_000 \
40+
# --pause_refine_after_reset 500 \
41+
# --strategy Modified_STG_Strategy \
42+
# --test_view_id $TEST_VIEW_IDS
3643

3744
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer_dyngs.py default \
3845
--model_path $RESULT_DIR/$SCENE/ \
39-
--data_dir $SCENE_DIR/$SCENE/colmap/colmap_50 \
46+
--data_dir $SCENE_DIR/$SCENE/colmap/colmap_${START_FRAME} \
4047
--result_dir $RESULT_DIR/$SCENE/ \
4148
--downscale_factor 1 \
4249
--duration $NUM_FRAME \
@@ -46,7 +53,7 @@ run_single_scene() {
4653
--test_view_id $TEST_VIEW_IDS
4754
}
4855

49-
GPU_LIST=(7)
56+
GPU_LIST=(6)
5057
GPU_COUNT=${#GPU_LIST[@]}
5158

5259
SCENE_IDX=-1
Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,51 @@
1+
# Option for pruning
2+
PRUNE_SPLATS=False
3+
echo "PRUNE_SPLATS is set to: $PRUNE_SPLATS"
4+
5+
# Datadir and dataset
16
SCENE_DIR="data/GSC"
2-
SCENE_LIST="Bartender" # CBA Bartender
7+
SCENE_LIST="CBA" # CBA Bartender
38

49
declare -A TEST_VIEWS
510
TEST_VIEWS=(
611
["CBA"]="7 22"
712
["Bartender"]="8 10 12"
813
)
914

10-
RESULT_DIR="results/dyngs"
15+
declare -A START_FRAMES
16+
START_FRAMES=(
17+
["CBA"]=0
18+
["Bartender"]=50
19+
)
1120

21+
RESULT_DIR="results/dyngs"
1222
NUM_FRAME=65
1323

1424
run_single_scene() {
1525
local GPU_ID=$1
1626
local SCENE=$2
1727
local TEST_VIEW_IDS=${TEST_VIEWS[$SCENE]}
28+
local START_FRAME=${START_FRAMES[$SCENE]}
1829

1930
echo "Running $SCENE"
2031

21-
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer_dyngs.py default \
32+
CMD="CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer_dyngs.py default \
2233
--model_path $RESULT_DIR/$SCENE/ \
23-
--data_dir $SCENE_DIR/$SCENE/colmap/colmap_50 \
34+
--data_dir $SCENE_DIR/$SCENE/colmap/colmap_${START_FRAME} \
2435
--result_dir $RESULT_DIR/$SCENE/ \
2536
--downscale_factor 1 \
2637
--duration $NUM_FRAME \
2738
--lpips_net vgg \
28-
--compression stg \
2939
--ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_best_rank0.pt \
3040
--test_view_id $TEST_VIEW_IDS \
31-
--enable_dyn_splats_export
41+
--enable_dyn_splats_export"
42+
43+
if [ "$PRUNE_SPLATS" = "True" ]; then
44+
CMD="$CMD --temp_opa_vis_pruning"
45+
echo "Pruning splat is $PRUNE_SPLATS"
46+
fi
47+
48+
eval "$CMD"
3249
}
3350

3451
GPU_LIST=(7)
@@ -39,8 +56,19 @@ SCENE_IDX=-1
3956
for SCENE in $SCENE_LIST;
4057
do
4158
SCENE_IDX=$((SCENE_IDX + 1))
42-
{
59+
{
60+
# export plys
4361
run_single_scene ${GPU_LIST[$SCENE_IDX]} $SCENE
4462
} #&
4563

46-
done
64+
# pack up all splats in ply dir
65+
if [ "$PRUNE_SPLATS" = "True" ]; then
66+
echo "Zip pruned splats to $RESULT_DIR/$SCENE/${SCENE}_pruned_splats.zip"
67+
zip -r $RESULT_DIR/$SCENE/${SCENE}_pruned_splats.zip $RESULT_DIR/$SCENE/plys
68+
else
69+
echo "Zip splats to $RESULT_DIR/$SCENE/${SCENE}_splats.zip"
70+
zip -r $RESULT_DIR/$SCENE/${SCENE}_splats.zip $RESULT_DIR/$SCENE/plys
71+
fi
72+
73+
done
74+

examples/simple_trainer_dyngs.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ class Config:
177177
enable_autograd_detect_anomaly: bool = False
178178

179179
# Enable exporting dynamic splats to per-frame ply files
180-
enable_dyn_splats_export: bool = True
180+
enable_dyn_splats_export: bool = False
181181
# Time-variant opacity based pruning
182-
temp_opa_vis_pruning: bool = True
182+
temp_opa_vis_pruning: bool = False
183183
# ply file saving mode
184184
sliced_splats_saving_mode: Literal["splats", "pcd"] = "splats"
185185

@@ -1161,6 +1161,11 @@ def export_dyn_splats_to_ply_sequence(self,):
11611161
shutil.rmtree(ply_dir)
11621162
os.makedirs(ply_dir, exist_ok=True)
11631163

1164+
if self.cfg.temp_opa_vis_pruning:
1165+
print("[Info] Pruning splats whose opacities are below threshold before saving splats into .ply files.")
1166+
else:
1167+
print("[Info] Directly saving splats into .ply files, without pruning.")
1168+
11641169
for f_id in range(cfg.duration):
11651170
timestamp = f_id/cfg.duration
11661171
sliced_splats = self.get_sliced_splats_from_dyn_splats(self.splats, timestamp, self.cfg.temp_opa_vis_pruning)
@@ -1203,7 +1208,7 @@ def trbfunction(x):
12031208
means_motion = means + motion[:, 0:3] * tforpoly + motion[:, 3:6] * tforpoly * tforpoly + motion[:, 6:9] * tforpoly *tforpoly * tforpoly
12041209
# Calculate rotations
12051210
rotations = torch.nn.functional.normalize(quats + tforpoly * omega)
1206-
import pdb; pdb.set_trace()
1211+
# import pdb; pdb.set_trace()
12071212
if opa_vis_mask:
12081213
vis_mask = trbfoutput.squeeze() > 0.05
12091214
means_motion = means_motion[vis_mask]
@@ -1215,17 +1220,20 @@ def trbfunction(x):
12151220
num_vis_mask = vis_mask.sum()
12161221
num_all_splats = vis_mask.shape[0]
12171222

1223+
def inverse_sigmoid(x):
1224+
return torch.log(x/(1-x))
1225+
12181226
sliced_splats = {
12191227
"means": means_motion,
1220-
"scales": scales,
1228+
"scales": torch.log(scales),
12211229
"quats": rotations,
1222-
"opacities": opacity,
1230+
"opacities": inverse_sigmoid(opacity),
12231231
"rgb": color
12241232
}
12251233

12261234
return sliced_splats
12271235

1228-
def save_static_splats_to_ply(self, path, splats, mode="splats"):
1236+
def save_static_splats_to_ply(self, path, splats, mode="splats", use_text=False):
12291237
from plyfile import PlyElement, PlyData
12301238

12311239
xyz = splats["means"].detach().cpu().numpy()
@@ -1235,29 +1243,31 @@ def save_static_splats_to_ply(self, path, splats, mode="splats"):
12351243
rgb = splats["rgb"].detach().contiguous().cpu().numpy()
12361244

12371245
if mode == "splats":
1246+
sh0 = rgb_to_sh(rgb)
1247+
12381248
dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()]
12391249
elements = np.empty(xyz.shape[0], dtype=dtype_full)
1240-
attributes = np.concatenate((xyz, scales, quats, opacities, rgb), axis=1)
1250+
attributes = np.concatenate((xyz, scales, quats, opacities, sh0), axis=1)
12411251
elements[:] = list(map(tuple, attributes))
1252+
12421253
elif mode == "pcd":
12431254
rgb = np.round(np.clip(rgb, 0, 1) * 255).astype(np.uint8)
1255+
12441256
dtype_full = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
12451257
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
1246-
12471258
elements = np.empty(len(xyz), dtype=dtype_full)
12481259

12491260
elements['x'] = xyz[:, 0]
12501261
elements['y'] = xyz[:, 1]
12511262
elements['z'] = xyz[:, 2]
1252-
12531263
elements['red'] = rgb[:, 0]
12541264
elements['green'] = rgb[:, 1]
12551265
elements['blue'] = rgb[:, 2]
12561266

12571267
el = PlyElement.describe(elements, 'vertex')
1258-
PlyData([el], text=True).write(path)
1268+
PlyData([el], text=use_text).write(path)
12591269

1260-
print(f"Save splats to: {path}")
1270+
print(f"Save {xyz.shape[0]} splats to: {path}")
12611271

12621272
def construct_list_of_attributes():
12631273
l = ['x', 'y', 'z']
@@ -1267,7 +1277,7 @@ def construct_list_of_attributes():
12671277
l.append('rot_{}'.format(i))
12681278
l.append('opacity')
12691279
for i in range(3):
1270-
l.append('sh0_{}'.format(i))
1280+
l.append('f_dc_{}'.format(i))
12711281

12721282
return l
12731283

@@ -1287,8 +1297,9 @@ def main(cfg: Config):
12871297
runner.decoder.load_state_dict(ckpts[0]["decoder"])
12881298
step = ckpts[0]["step"]
12891299

1290-
print(f"Evaluate ckpt saved at step {step}")
1291-
runner.eval(step=step)
1300+
if cfg.compression is not None:
1301+
print(f"Evaluate ckpt saved at step {step}")
1302+
runner.eval(step=step)
12921303

12931304
# print(f"Render trajectory using ckpt saved at step {step}")
12941305
# runner.render_traj(step=step)

0 commit comments

Comments
 (0)