Skip to content

Commit db4b462

Browse files
committed
Add correct transient rendering and loss to nerfacto
Disable by default
1 parent 84881f4 commit db4b462

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

nerfstudio/models/nerfacto.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@
4141
scale_gradients_by_distance_squared,
4242
)
4343
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
44-
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
44+
from nerfstudio.model_components.renderers import (
45+
AccumulationRenderer,
46+
DepthRenderer,
47+
NormalsRenderer,
48+
RGBRenderer,
49+
UncertaintyRenderer,
50+
)
4551
from nerfstudio.model_components.scene_colliders import NearFarCollider
4652
from nerfstudio.model_components.shaders import NormalsShader
4753
from nerfstudio.models.base_model import Model, ModelConfig
@@ -63,9 +69,9 @@ class NerfactoModelConfig(ModelConfig):
6369
"""Dimension of hidden layers"""
6470
hidden_dim_color: int = 64
6571
"""Dimension of hidden layers for color network"""
66-
use_transient_embedding: bool = True
72+
use_transient_embedding: bool = False
6773
"""Whether to use a transient embedding."""
68-
hidden_dim_transient: int = 128
74+
hidden_dim_transient: int = 64
6975
"""Dimension of hidden layers for transient network"""
7076
transient_embed_dim: int = 16
7177
"""Dimension of the transient embedding."""
@@ -240,6 +246,7 @@ def update_schedule(step):
240246
self.renderer_accumulation = AccumulationRenderer()
241247
self.renderer_depth = DepthRenderer(method="median")
242248
self.renderer_expected_depth = DepthRenderer(method="expected")
249+
self.renderer_uncertainty = UncertaintyRenderer()
243250
self.renderer_normals = NormalsRenderer()
244251

245252
# shaders
@@ -311,11 +318,25 @@ def get_outputs(self, ray_bundle: RayBundle):
311318
if self.config.use_gradient_scaling:
312319
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
313320

314-
weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
321+
if self.training and self.config.use_transient_embedding:
322+
static_density = field_outputs[FieldHeadNames.DENSITY]
323+
transient_density = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
324+
weights_static = ray_samples.get_weights(static_density)
325+
weights_transient = ray_samples.get_weights(transient_density)
326+
weights = weights_static
327+
rgb_static_component = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights_static)
328+
rgb_transient_component = self.renderer_rgb(
329+
rgb=field_outputs[FieldHeadNames.TRANSIENT_RGB], weights=weights_transient
330+
)
331+
rgb = rgb_static_component + rgb_transient_component
332+
else:
333+
weights_static = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
334+
weights = weights_static
335+
rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
336+
315337
weights_list.append(weights)
316338
ray_samples_list.append(ray_samples)
317339

318-
rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
319340
with torch.no_grad():
320341
depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
321342
expected_depth = self.renderer_expected_depth(weights=weights, ray_samples=ray_samples)
@@ -351,6 +372,13 @@ def get_outputs(self, ray_bundle: RayBundle):
351372

352373
for i in range(self.config.num_proposal_iterations):
353374
outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])
375+
376+
# transients
377+
if self.training and self.config.use_transient_embedding:
378+
uncertainty = self.renderer_uncertainty(field_outputs[FieldHeadNames.UNCERTAINTY], weights_transient)
379+
outputs["uncertainty"] = uncertainty + 0.1 # NOTE(ethan): 0.1 is the minimum uncertainty, added as a constant offset to the rendered value
380+
outputs["density_transient"] = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
381+
354382
return outputs
355383

356384
def get_metrics_dict(self, outputs, batch):
@@ -375,7 +403,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):
375403
gt_image=image,
376404
)
377405

378-
loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
406+
if self.training and self.config.use_transient_embedding:
407+
# transient loss
408+
betas = outputs["uncertainty"]
409+
loss_dict["uncertainty_loss"] = 3 + torch.log(betas).mean()
410+
loss_dict["density_loss"] = 0.01 * outputs["density_transient"].mean()
411+
loss_dict["rgb_loss"] = (((gt_rgb - pred_rgb) ** 2).sum(-1) / (betas[..., 0] ** 2)).mean()
412+
else:
413+
loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
414+
379415
if self.training:
380416
loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
381417
outputs["weights_list"], outputs["ray_samples_list"]

0 commit comments

Comments
 (0)