Skip to content

Commit 96a6359

Browse files
committed
bugfix: fix the color shift problem in STG and update the dynamic viewer
1 parent 39881f0 commit 96a6359

File tree

3 files changed

+25
-65
lines changed

3 files changed

+25
-65
lines changed

examples/helper/STG/dataset_readers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def get_center_and_diag(cam_centers):
7272
C2W_list.append(C2W)
7373

7474
center, diagonal = get_center_and_diag(cam_centers)
75-
radius = diagonal * 1.1
75+
radius = diagonal
7676

7777
translate = -center
7878

examples/simple_viewer_dyn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def main(local_rank: int, world_rank, world_size: int, args):
247247

248248
dyn_gs = DynGSRenderer(args)
249249

250-
gui = ViserViewer(port=8080)
250+
gui = ViserViewer(port=args.port)
251251

252252
gui.set_scene_rep(dyn_gs)
253253

gsplat/strategy/modified_stg.py

Lines changed: 23 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -118,62 +118,35 @@ def step_post_backward(
118118
if step >= self.refine_stop_iter:
119119
return True
120120

121-
# if step >= self.refine_stop_iter:
122-
# # freeze weights of omega
123-
# params["omega"].grad = params["omega"].grad * self.omegamask # TODO check if this is proceed as expected
124-
# self.rotationmask = torch.logical_not(self.omegamask)
125-
# # freeze weights of rotation
126-
# params["quats"].grad = params["quats"].grad * self.rotationmask # TODO check if this is proceed as expected
127-
# if step % 1000 == 500 :
128-
# zmask = params["means"][:,2] < 4.5 # sicheng: actually, it is a really bad impl. in original STG
129-
# remove(params=params, optimizers=optimizers, state=state, mask=zmask)
130-
# self.omegamask = self._zero_omegabymotion(params, optimizers) # calculate omegamask again to adjust the change of gaussian numbers
131-
# torch.cuda.empty_cache()
132-
# if step == 10000:
133-
# self.removeminmax(params=params, optimizers=optimizers, state=state, maxbounds=maxbounds, minbounds=minbounds)
134-
# self.omegamask = self._zero_omegabymotion(params, optimizers) # calculate omegamask again to adjust the change of gaussian numbers
135-
# return flag
136-
137121
self._update_state(params, state, info, packed=packed)
138122

139123
# TODO: need to consider more strategy, there are totally 3 types of strategy in STG (densify = 1,2,3)
140124
# sicheng: in original STG, n3d scenes in night use densify=1, scenes in day use densify=2
141125
	# here is an implementation of densify=1
142126
# omega & rotation mask
143-
144-
# if step == 8001 :
145-
# omegamask = self._zero_omegabymotion(params, optimizers)
146-
# self.omegamask = omegamask
147-
# # record process
148-
# elif step > 8001:
149-
# # freeze weights of omega
150-
# params["omega"].grad = params["omega"].grad * self.omegamask # this is likely wrong
151-
# self.rotationmask = torch.logical_not(self.omegamask)
152-
# # freeze weights of rotation
153-
# params["quats"].grad = params["quats"].grad * self.rotationmask # this is likely wrong
154127

155128
if (
156129
step > self.refine_start_iter
157130
and step % self.refine_every == 0
158-
and step % self.reset_every >= self.pause_refine_after_reset
159131
):
160-
# if flag < desicnt:
161132
# grow GSs
162133
n_dupli, n_split = self._grow_gs(params, optimizers, state, step)
163134
if self.verbose:
164135
print(
165136
f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. "
166137
f"Now having {len(params['means'])} GSs."
167138
)
168-
	# according to STG, pruning doesn't proceed here
169-
170-
n_prune = self._prune_gs(params, optimizers, state, step)
171-
if self.verbose:
172-
print(
173-
f"Step {step}: {n_prune} GSs pruned. "
174-
f"Now having {len(params['means'])} GSs."
175-
)
176-
# torch.cuda.empty_cache() # check if this is needed
139+
140+
	# Don't prune points immediately after an opacity reset, as this might remove many useful points.
141+
if step % self.reset_every >= self.pause_refine_after_reset:
142+
# prune GSs
143+
n_prune = self._prune_gs(params, optimizers, state, step)
144+
if self.verbose:
145+
print(
146+
f"Step {step}: {n_prune} GSs pruned. "
147+
f"Now having {len(params['means'])} GSs."
148+
)
149+
# torch.cuda.empty_cache() # check if this is needed
177150

178151
# reset running stats
179152
state["grad2d"].zero_()
@@ -182,25 +155,13 @@ def step_post_backward(
182155
state["radii"].zero_()
183156
torch.cuda.empty_cache()
184157

185-
# flag+=1
186-
# else:
187-
	# if step < 7000 : # default 7000.
188-
# # prune GSs
189-
# n_prune = self._prune_gs(params, optimizers, state, step)
190-
# if self.verbose:
191-
# print(
192-
# f"Step {step}: {n_prune} GSs pruned. "
193-
# f"Now having {len(params['means'])} GSs."
194-
# )
195-
# torch.cuda.empty_cache() # check if this is needed
196-
197-
# if step % self.reset_every == 0:
198-
# reset_opa(
199-
# params=params,
200-
# optimizers=optimizers,
201-
# state=state,
202-
# value=self.prune_opa * 2.0,
203-
# )
158+
if step % self.reset_every == 0:
159+
reset_opa(
160+
params=params,
161+
optimizers=optimizers,
162+
state=state,
163+
value=self.prune_opa * 2.0,
164+
)
204165

205166
return flag
206167

@@ -326,11 +287,10 @@ def _prune_gs(
326287
) -> int:
327288
is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa
328289
if step > self.reset_every:
329-
# In STG, scale is not considered when pruning
330-
# is_too_big = (
331-
# torch.exp(params["scales"]).max(dim=-1).values
332-
# > self.prune_scale3d * state["scene_scale"]
333-
# )
290+
is_too_big = (
291+
torch.exp(params["scales"]).max(dim=-1).values
292+
> self.prune_scale3d * state["scene_scale"]
293+
)
334294

335295
	# The official code also implements screen-size pruning but
336296
# it's actually not being used due to a bug:
@@ -340,7 +300,7 @@ def _prune_gs(
340300
if step < self.refine_scale2d_stop_iter:
341301
is_too_big |= state["radii"] > self.prune_scale2d
342302

343-
# is_prune = is_prune | is_too_big
303+
is_prune = is_prune | is_too_big
344304

345305
n_prune = is_prune.sum().item()
346306
if n_prune > 0:

0 commit comments

Comments
 (0)