4141 scale_gradients_by_distance_squared ,
4242)
4343from nerfstudio .model_components .ray_samplers import ProposalNetworkSampler , UniformSampler
44- from nerfstudio .model_components .renderers import AccumulationRenderer , DepthRenderer , NormalsRenderer , RGBRenderer
44+ from nerfstudio .model_components .renderers import (
45+ AccumulationRenderer ,
46+ DepthRenderer ,
47+ NormalsRenderer ,
48+ RGBRenderer ,
49+ UncertaintyRenderer ,
50+ )
4551from nerfstudio .model_components .scene_colliders import NearFarCollider
4652from nerfstudio .model_components .shaders import NormalsShader
4753from nerfstudio .models .base_model import Model , ModelConfig
@@ -63,9 +69,9 @@ class NerfactoModelConfig(ModelConfig):
6369 """Dimension of hidden layers"""
6470 hidden_dim_color : int = 64
6571 """Dimension of hidden layers for color network"""
66- use_transient_embedding : bool = True
72+ use_transient_embedding : bool = False
6773 """Whether to use an transient embedding."""
68- hidden_dim_transient : int = 128
74+ hidden_dim_transient : int = 64
6975 """Dimension of hidden layers for transient network"""
7076 transient_embed_dim : int = 16
7177 """Dimension of the transient embedding."""
@@ -240,6 +246,7 @@ def update_schedule(step):
240246 self .renderer_accumulation = AccumulationRenderer ()
241247 self .renderer_depth = DepthRenderer (method = "median" )
242248 self .renderer_expected_depth = DepthRenderer (method = "expected" )
249+ self .renderer_uncertainty = UncertaintyRenderer ()
243250 self .renderer_normals = NormalsRenderer ()
244251
245252 # shaders
@@ -311,11 +318,25 @@ def get_outputs(self, ray_bundle: RayBundle):
311318 if self .config .use_gradient_scaling :
312319 field_outputs = scale_gradients_by_distance_squared (field_outputs , ray_samples )
313320
314- weights = ray_samples .get_weights (field_outputs [FieldHeadNames .DENSITY ])
321+ if self .training and self .config .use_transient_embedding :
322+ static_density = field_outputs [FieldHeadNames .DENSITY ]
323+ transient_density = field_outputs [FieldHeadNames .TRANSIENT_DENSITY ]
324+ weights_static = ray_samples .get_weights (static_density )
325+ weights_transient = ray_samples .get_weights (transient_density )
326+ weights = weights_static
327+ rgb_static_component = self .renderer_rgb (rgb = field_outputs [FieldHeadNames .RGB ], weights = weights_static )
328+ rgb_transient_component = self .renderer_rgb (
329+ rgb = field_outputs [FieldHeadNames .TRANSIENT_RGB ], weights = weights_transient
330+ )
331+ rgb = rgb_static_component + rgb_transient_component
332+ else :
333+ weights_static = ray_samples .get_weights (field_outputs [FieldHeadNames .DENSITY ])
334+ weights = weights_static
335+ rgb = self .renderer_rgb (rgb = field_outputs [FieldHeadNames .RGB ], weights = weights )
336+
315337 weights_list .append (weights )
316338 ray_samples_list .append (ray_samples )
317339
318- rgb = self .renderer_rgb (rgb = field_outputs [FieldHeadNames .RGB ], weights = weights )
319340 with torch .no_grad ():
320341 depth = self .renderer_depth (weights = weights , ray_samples = ray_samples )
321342 expected_depth = self .renderer_expected_depth (weights = weights , ray_samples = ray_samples )
@@ -351,6 +372,13 @@ def get_outputs(self, ray_bundle: RayBundle):
351372
352373 for i in range (self .config .num_proposal_iterations ):
353374 outputs [f"prop_depth_{ i } " ] = self .renderer_depth (weights = weights_list [i ], ray_samples = ray_samples_list [i ])
375+
376+ # transients
377+ if self .training and self .config .use_transient_embedding :
378+ uncertainty = self .renderer_uncertainty (field_outputs [FieldHeadNames .UNCERTAINTY ], weights_transient )
379+ outputs ["uncertainty" ] = uncertainty + 0.1 # NOTE(ethan): this is the uncertainty min
380+ outputs ["density_transient" ] = field_outputs [FieldHeadNames .TRANSIENT_DENSITY ]
381+
354382 return outputs
355383
356384 def get_metrics_dict (self , outputs , batch ):
@@ -375,7 +403,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):
375403 gt_image = image ,
376404 )
377405
378- loss_dict ["rgb_loss" ] = self .rgb_loss (gt_rgb , pred_rgb )
406+ if self .training and self .config .use_transient_embedding :
407+ # transient loss
408+ betas = outputs ["uncertainty" ]
409+ loss_dict ["uncertainty_loss" ] = 3 + torch .log (betas ).mean ()
410+ loss_dict ["density_loss" ] = 0.01 * outputs ["density_transient" ].mean ()
411+ loss_dict ["rgb_loss" ] = (((gt_rgb - pred_rgb ) ** 2 ).sum (- 1 ) / (betas [..., 0 ] ** 2 )).mean ()
412+ else :
413+ loss_dict ["rgb_loss" ] = self .rgb_loss (gt_rgb , pred_rgb )
414+
379415 if self .training :
380416 loss_dict ["interlevel_loss" ] = self .config .interlevel_loss_mult * interlevel_loss (
381417 outputs ["weights_list" ], outputs ["ray_samples_list" ]
0 commit comments