@@ -118,62 +118,35 @@ def step_post_backward(
118118 if step >= self .refine_stop_iter :
119119 return True
120120
121- # if step >= self.refine_stop_iter:
122- # # freeze weights of omega
123- # params["omega"].grad = params["omega"].grad * self.omegamask # TODO check if this is proceed as expected
124- # self.rotationmask = torch.logical_not(self.omegamask)
125- # # freeze weights of rotation
126- # params["quats"].grad = params["quats"].grad * self.rotationmask # TODO check if this is proceed as expected
127- # if step % 1000 == 500 :
128- # zmask = params["means"][:,2] < 4.5 # sicheng: actually, it is a really bad impl. in original STG
129- # remove(params=params, optimizers=optimizers, state=state, mask=zmask)
130- # self.omegamask = self._zero_omegabymotion(params, optimizers) # calculate omegamask again to adjust the change of gaussian numbers
131- # torch.cuda.empty_cache()
132- # if step == 10000:
133- # self.removeminmax(params=params, optimizers=optimizers, state=state, maxbounds=maxbounds, minbounds=minbounds)
134- # self.omegamask = self._zero_omegabymotion(params, optimizers) # calculate omegamask again to adjust the change of gaussian numbers
135- # return flag
136-
137121 self ._update_state (params , state , info , packed = packed )
138122
139123 # TODO: need to consider more strategy, there are totally 3 types of strategy in STG (densify = 1,2,3)
140124 # sicheng: in original STG, n3d scenes in night use densify=1, scenes in day use densify=2
141125 # here is a implementation of densify=1
142126 # omega & rotation mask
143-
144- # if step == 8001 :
145- # omegamask = self._zero_omegabymotion(params, optimizers)
146- # self.omegamask = omegamask
147- # # record process
148- # elif step > 8001:
149- # # freeze weights of omega
150- # params["omega"].grad = params["omega"].grad * self.omegamask # this is likely wrong
151- # self.rotationmask = torch.logical_not(self.omegamask)
152- # # freeze weights of rotation
153- # params["quats"].grad = params["quats"].grad * self.rotationmask # this is likely wrong
154127
155128 if (
156129 step > self .refine_start_iter
157130 and step % self .refine_every == 0
158- and step % self .reset_every >= self .pause_refine_after_reset
159131 ):
160- # if flag < desicnt:
161132 # grow GSs
162133 n_dupli , n_split = self ._grow_gs (params , optimizers , state , step )
163134 if self .verbose :
164135 print (
165136 f"Step { step } : { n_dupli } GSs duplicated, { n_split } GSs split. "
166137 f"Now having { len (params ['means' ])} GSs."
167138 )
168- # according to STG, pruning don't proceed here
169-
170- n_prune = self ._prune_gs (params , optimizers , state , step )
171- if self .verbose :
172- print (
173- f"Step { step } : { n_prune } GSs pruned. "
174- f"Now having { len (params ['means' ])} GSs."
175- )
176- # torch.cuda.empty_cache() # check if this is needed
139+
140+ # Don't prune points immediately after a opacity reset, as this might remove many useful points.
141+ if step % self .reset_every >= self .pause_refine_after_reset :
142+ # prune GSs
143+ n_prune = self ._prune_gs (params , optimizers , state , step )
144+ if self .verbose :
145+ print (
146+ f"Step { step } : { n_prune } GSs pruned. "
147+ f"Now having { len (params ['means' ])} GSs."
148+ )
149+ # torch.cuda.empty_cache() # check if this is needed
177150
178151 # reset running stats
179152 state ["grad2d" ].zero_ ()
@@ -182,25 +155,13 @@ def step_post_backward(
182155 state ["radii" ].zero_ ()
183156 torch .cuda .empty_cache ()
184157
185- # flag+=1
186- # else:
187- # if step < 7000 : # defalt 7000.
188- # # prune GSs
189- # n_prune = self._prune_gs(params, optimizers, state, step)
190- # if self.verbose:
191- # print(
192- # f"Step {step}: {n_prune} GSs pruned. "
193- # f"Now having {len(params['means'])} GSs."
194- # )
195- # torch.cuda.empty_cache() # check if this is needed
196-
197- # if step % self.reset_every == 0:
198- # reset_opa(
199- # params=params,
200- # optimizers=optimizers,
201- # state=state,
202- # value=self.prune_opa * 2.0,
203- # )
158+ if step % self .reset_every == 0 :
159+ reset_opa (
160+ params = params ,
161+ optimizers = optimizers ,
162+ state = state ,
163+ value = self .prune_opa * 2.0 ,
164+ )
204165
205166 return flag
206167
@@ -326,11 +287,10 @@ def _prune_gs(
326287 ) -> int :
327288 is_prune = torch .sigmoid (params ["opacities" ].flatten ()) < self .prune_opa
328289 if step > self .reset_every :
329- # In STG, scale is not considered when pruning
330- # is_too_big = (
331- # torch.exp(params["scales"]).max(dim=-1).values
332- # > self.prune_scale3d * state["scene_scale"]
333- # )
290+ is_too_big = (
291+ torch .exp (params ["scales" ]).max (dim = - 1 ).values
292+ > self .prune_scale3d * state ["scene_scale" ]
293+ )
334294
335295 # The official code also implements sreen-size pruning but
336296 # it's actually not being used due to a bug:
@@ -340,7 +300,7 @@ def _prune_gs(
340300 if step < self .refine_scale2d_stop_iter :
341301 is_too_big |= state ["radii" ] > self .prune_scale2d
342302
343- # is_prune = is_prune | is_too_big
303+ is_prune = is_prune | is_too_big
344304
345305 n_prune = is_prune .sum ().item ()
346306 if n_prune > 0 :
0 commit comments