Skip to content

Commit e6c896b

Browse files
committed
misc
1 parent 99ffc5e commit e6c896b

File tree

3 files changed

+62
-36
lines changed

3 files changed

+62
-36
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ setup_gscodec.py
131131

132132
figs
133133
stats
134+
temp
134135
Readme_GSCodec.md
135136

136137
!examples/benchmarks/compression/results/

examples/simple_trainer_STG.py

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Config:
8484
source_path: str = ""
8585
model_path: str = ""
8686
images: str = "images"
87-
resolution: int = 2 #-1
87+
downscale_factor: int = 2 #-1
8888
white_background: bool = False
8989
veryrify_llff: int = 0
9090
eval: bool = True
@@ -95,7 +95,7 @@ class Config:
9595
# Optimization Params / op
9696
max_steps: int = 30_000
9797
init_opa: float = 0.1 # Initial opacity of GS
98-
batch_size: int = 2 # TODO Do not support batch size = 1 for now
98+
batch_size: int = 2
9999
feature_dim: int = 3
100100
device: str = "cuda"
101101
global_scale: float = 1.0 # A global scaler that applies to the scene size related parameters
@@ -106,14 +106,15 @@ class Config:
106106
refine_stop_iter: int = 9_000 # STG changed this param from 15_000 to 9_000, comprared with 3dgs
107107
reset_every: int = 3_000
108108
refine_every: int = 100
109+
pause_refine_after_reset: int = 0
109110
absgrad: bool = False
110111
packed: bool = False
111112
sparse_grad: bool = False
112113
antialiased: bool = False
113114
duration: int = 50 # 20 # number of frames to train
114115
ssim_lambda: float = 0.2 # Weight for SSIM loss
115-
save_steps: List[int] = field(default_factory=lambda: [i for i in range(9_000, 30_001, 3_000)]) # Steps to save the model
116-
eval_steps: List[int] = field(default_factory=lambda: [i for i in range(9_000, 30_001, 3_000)]) # Steps to evaluate the model # 7_000, 30_000
116+
save_steps: List[int] = field(default_factory=lambda: [i for i in range(9_000, 75_001, 3_000)]) # Steps to save the model
117+
eval_steps: List[int] = field(default_factory=lambda: [i for i in range(0, 75_001, 3_000)]) # Steps to evaluate the model # 7_000, 30_000
117118
# eval_steps: List[int] = field(default_factory=lambda: [1_000, 2_000, 3_000, 4_000, 5_000, 6_000, 7_000, 25_000, 30_000])
118119
# Number of densification
119120
desicnt: int = 6 # default: 6
@@ -215,15 +216,15 @@ def create_splats_with_optimizers(
215216

216217
# Didn't introduce world_rank and world_size
217218
N = points.shape[0]
218-
# quats = torch.rand((N, 4)) # [N, 4]
219-
quats = torch.zeros((N, 4))
219+
# quats = torch.rand((N, 4))
220+
quats = torch.zeros((N, 4)) # [N, 4]
220221
quats[:, 0] = 1
221222
# opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
222223
opacities = inverse_sigmoid(0.1 * torch.ones(N,))
223-
trbf_scale = torch.ones((N, 1))
224+
trbf_scale = torch.log(torch.ones((N, 1))) # [N, 1]
224225
times = parser.timestamp
225226
times = torch.tensor(times)
226-
trbf_center = times.contiguous()
227+
trbf_center = times.contiguous() # [N, 1]
227228
motion = torch.zeros((N, 9))
228229
omega = torch.zeros((N, 4))
229230

@@ -285,7 +286,7 @@ def __init__(self, cfg: Config) -> None:
285286
# only enable when debug!!
286287
if cfg.enable_autograd_detect_anomaly:
287288
torch.autograd.set_detect_anomaly(True)
288-
289+
289290
self.cfg = cfg
290291
# Write cfg file: Skipped
291292
self.device = self.cfg.device
@@ -310,10 +311,10 @@ def __init__(self, cfg: Config) -> None:
310311

311312
# Load data: Training data should contain initial points and colors.
312313
parser = Parser(model_path=self.cfg.model_path, source_path=self.cfg.data_dir, duration=cfg.duration,
313-
shuffle=False, eval=self.cfg.eval, resolution=cfg.resolution, data_device='cpu')
314+
shuffle=False, eval=self.cfg.eval, downscale_factor=cfg.downscale_factor, data_device='cpu', test_view_id=cfg.test_view_id)
314315
self.parser = parser
315316
self.trainset = Dataset(parser=self.parser, split="train", num_views=cfg.batch_size, use_fake_length=True, fake_length=cfg.max_steps+100)
316-
self.testset = Dataset(parser=self.parser, split="test", num_views=1)
317+
self.testset = Dataset(parser=self.parser, split="test", num_views=len(cfg.test_view_id))
317318

318319
self.trainloader = torch.utils.data.DataLoader(
319320
self.trainset,
@@ -352,7 +353,7 @@ def __init__(self, cfg: Config) -> None:
352353
self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=0.0001)
353354

354355
currentxyz = self.splats["means"]
355-
maxx, maxy, maxz = torch.amax(currentxyz[:,0]), torch.amax(currentxyz[:,1]), torch.amax(currentxyz[:,2])# z wrong...
356+
maxx, maxy, maxz = torch.amax(currentxyz[:,0]), torch.amax(currentxyz[:,1]), torch.amax(currentxyz[:,2])
356357
minx, miny, minz = torch.amin(currentxyz[:,0]), torch.amin(currentxyz[:,1]), torch.amin(currentxyz[:,2])
357358
self.maxbounds = [maxx, maxy, maxz]
358359
self.minbounds = [minx, miny, minz]
@@ -878,6 +879,13 @@ def eval(self, step: int, stage: str = "val"):
878879
ellipse_time = 0
879880
metrics = defaultdict(list)
880881
pbar = tqdm.tqdm(range(0, len(self.testloader)))
882+
883+
# save path
884+
eval_save_path = f"{self.render_dir}/{stage}_step{step}"
885+
os.makedirs(eval_save_path, exist_ok=True)
886+
887+
## init writer(s) based on num. of test views
888+
writers = [imageio.get_writer(f"{eval_save_path}/{stage}_step{step}_testv{i}.mp4", fps=30, quality=10) for i in range(len(cfg.test_view_id))]
881889

882890
for t_idx, batch in enumerate(self.testloader): # t_idx
883891

@@ -887,10 +895,6 @@ def eval(self, step: int, stage: str = "val"):
887895
timestamp = batch["timestamp"][0].float().to(device)
888896
rays = batch["ray"][0].float().to(device)
889897
camtoworld = batch['camtoworld'][0].float().to(device)
890-
891-
# R = camtoworld.cpu().numpy()[0, :3,:3]
892-
# T = torch.inverse(camtoworld).cpu().numpy()[0, :3,-1]
893-
# new_rays = self.get_rays(R, T, Ks[0,0,0], Ks[0,1,1], width, height).float().to(device)
894898

895899
torch.cuda.synchronize()
896900
tic = time.time()
@@ -906,29 +910,39 @@ def eval(self, step: int, stage: str = "val"):
906910
torch.cuda.synchronize()
907911
ellipse_time += time.time() - tic
908912

909-
colors = torch.clamp(colors, 0.0, 1.0)
913+
colors = torch.clamp(colors, 0.0, 1.0) # colors: [N, H, W, C]
910914
canvas_list = [pixels, colors]
911915

912916
desc = ""
913917
if world_rank == 0:
914-
# write GT-vs-rendered image
915918
try:
916-
# new version - fpng
917-
# canvas = torch.cat(canvas_list, dim=2).squeeze(0).contiguous().cpu().numpy() # fpnge needs [H,W,C]
918-
# import pdb; pdb.set_trace()
919-
canvas = canvas_list[-1].squeeze(0).contiguous().cpu().numpy()
920-
canvas = (canvas * 255).astype(np.uint8)
921-
import fpnge
922-
png = fpnge.fromNP(canvas)
923-
with open(f"{self.render_dir}/{stage}_step{step}_{t_idx:04d}.png", 'wb') as f:
924-
f.write(png)
925-
926-
canvas = canvas_list[0].squeeze(0).contiguous().cpu().numpy()
927-
canvas = (canvas * 255).astype(np.uint8)
928919
import fpnge
929-
png = fpnge.fromNP(canvas)
930-
with open(f"{self.render_dir}/{stage}_step{step}_{t_idx:04d}_gt.png", 'wb') as f:
931-
f.write(png)
920+
canvases = torch.cat(canvas_list, dim=2) # canvas: [N, H, 2*W, C]
921+
for i in range(canvases.shape[0]): # loop on test views
922+
# save side-by-side comparison
923+
canvas = canvases[i].contiguous().cpu().numpy()
924+
canvas = (canvas * 255).astype(np.uint8)
925+
# new version - fpng
926+
927+
png = fpnge.fromNP(canvas) # fpnge needs tensor in order as [H,W,C]
928+
# with open(f"{self.render_dir}/{stage}_step{step}_{t_idx:04d}_testv{i}.png", 'wb') as f:
929+
# f.write(png)
930+
with open(f"{eval_save_path}/sidebyside_testv{i}_fid{t_idx:04d}.png", 'wb') as f:
931+
f.write(png)
932+
933+
# save gt
934+
gt = pixels[i].contiguous().cpu().numpy()
935+
gt = (gt * 255).astype(np.uint8)
936+
png = fpnge.fromNP(gt)
937+
with open(f"{eval_save_path}/gt_testv{i}_fid{t_idx:04d}.png", 'wb') as f:
938+
f.write(png)
939+
940+
# save rendered
941+
rendered = colors[i].contiguous().cpu().numpy()
942+
rendered = (rendered * 255).astype(np.uint8)
943+
png = fpnge.fromNP(rendered)
944+
with open(f"{eval_save_path}/rendered_testv{i}_fid{t_idx:04d}.png", 'wb') as f:
945+
f.write(png)
932946

933947
except:
934948
# original version - imageio
@@ -939,6 +953,12 @@ def eval(self, step: int, stage: str = "val"):
939953
canvas,
940954
)
941955

956+
# save rendered test-view videos
957+
for i in range(colors.shape[0]):
958+
color = colors[i].cpu().numpy()
959+
color = (color * 255).astype(np.uint8)
960+
writers[i].append_data(color)
961+
942962
# write difference image
943963
# difference = abs(colors - pixels).squeeze().detach().cpu().numpy()
944964
# imageio.imwrite(
@@ -956,6 +976,10 @@ def eval(self, step: int, stage: str = "val"):
956976
pbar.update(1)
957977
pbar.close()
958978

979+
for i, writer in enumerate(writers):
980+
writer.close()
981+
print(f"Video saved to {self.render_dir}/{stage}_step{step}_testv{i}.mp4")
982+
959983
if world_rank == 0:
960984
ellipse_time /= len(self.testloader)
961985

@@ -1221,9 +1245,11 @@ def main(cfg: Config):
12211245
runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
12221246
runner.decoder.load_state_dict(ckpts[0]["decoder"])
12231247
step = ckpts[0]["step"]
1224-
runner.render_traj(step=step)
1225-
# runner.eval(step=step)
1248+
print(f"Evaluate ckpt saved at step {step}")
1249+
# runner.render_traj(step=step)
1250+
runner.eval(step=step)
12261251
if cfg.compression is not None:
1252+
print(f"Compress ckpt saved at step {step}")
12271253
runner.run_compression(step=step)
12281254

12291255
else:

gsplat/strategy/ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def split(
150150
sel = torch.where(mask)[0]
151151
rest = torch.where(~mask)[0]
152152

153-
# spatial resampling
154153
scales = torch.exp(params["scales"][sel])
155154
quats = F.normalize(params["quats"][sel], dim=-1)
156155
rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3]

0 commit comments

Comments
 (0)