Skip to content
Closed
14 changes: 12 additions & 2 deletions fvdb_reality_capture/cli/frgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,24 @@
from ._mesh_basic import MeshBasic
from ._mesh_dlnr import MeshDLNR
from ._points import Points
from ._reconstruct import Reconstruct
from ._reconstruct import Reconstruct, ReconstructMCMC
from ._resume import Resume
from ._show import Show
from ._show_data import ShowData


def frgs():
    """
    Entry point for the ``frgs`` command-line tool.

    Parses ``sys.argv`` with :func:`tyro.cli` into exactly one of the
    registered subcommand dataclasses and executes it.
    """
    # The union below is the full set of available subcommands; tyro turns each
    # member into a CLI subcommand. (The pasted diff kept both the old one-line
    # union and the new multi-line one inside the same call, which is invalid
    # Python — only the new union, which adds ReconstructMCMC, is kept here.)
    cmd: BaseCommand = tyro.cli(
        Download
        | Reconstruct
        | ReconstructMCMC
        | Convert
        | ShowData
        | Show
        | Resume
        | Evaluate
        | MeshBasic
        | MeshDLNR
        | Points
    )
    cmd.execute()
14 changes: 14 additions & 0 deletions fvdb_reality_capture/cli/frgs/_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fvdb_reality_capture.cli import BaseCommand
from fvdb_reality_capture.radiance_fields import (
GaussianSplatOptimizerConfig,
GaussianSplatOptimizerMCMCConfig,
GaussianSplatReconstruction,
GaussianSplatReconstructionConfig,
GaussianSplatReconstructionWriter,
Expand Down Expand Up @@ -351,3 +352,16 @@ def execute(self) -> None:
self._run_single_reconstruction(sfm_scene, writer, viz_scene)
else:
self._run_chunked_reconstruction(sfm_scene, writer, viz_scene)


@dataclass
class ReconstructMCMC(Reconstruct):
    """
    Reconstruct a Gaussian Splat Radiance Field with the MCMC optimizer strategy.

    Behaves exactly like :class:`Reconstruct`; the only difference is that the
    ``--opt.*`` options come from
    :class:`~fvdb_reality_capture.radiance_fields.GaussianSplatOptimizerMCMCConfig`
    instead of the default optimizer configuration.
    """

    # Override the inherited optimizer config with the MCMC variant.
    opt: GaussianSplatOptimizerMCMCConfig = field(default_factory=GaussianSplatOptimizerMCMCConfig)
6 changes: 6 additions & 0 deletions fvdb_reality_capture/radiance_fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
InsertionGrad2dThresholdMode,
SpatialScaleMode,
)
from .gaussian_splat_optimizer_mcmc import (
GaussianSplatOptimizerMCMC,
GaussianSplatOptimizerMCMCConfig,
)
from .gaussian_splat_reconstruction import (
GaussianSplatReconstruction,
GaussianSplatReconstructionConfig,
Expand All @@ -30,6 +34,8 @@
"BaseGaussianSplatOptimizer",
"GaussianSplatOptimizer",
"GaussianSplatOptimizerConfig",
"GaussianSplatOptimizerMCMC",
"GaussianSplatOptimizerMCMCConfig",
"InsertionGrad2dThresholdMode",
"SpatialScaleMode",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,38 @@
# SPDX-License-Identifier: Apache-2.0
#
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, TypeVar

import torch
from fvdb import GaussianSplat3d

# Maps the stable name of each registered optimizer (see
# BaseGaussianSplatOptimizer.name) to its class, so optimizers can be looked up
# when restoring from a serialized state dict.
REGISTERED_OPTIMIZERS: dict[str, type] = {}


DerivedOptimizer = TypeVar("DerivedOptimizer", bound=type)


def splat_optimizer(cls: DerivedOptimizer) -> DerivedOptimizer:
    """
    Decorator to register an optimizer class which inherits from :class:`BaseGaussianSplatOptimizer`.

    Registration is keyed by ``cls.name()``; registering a class whose name is
    already present replaces the previous entry.

    Args:
        cls: The optimizer class to register.

    Returns:
        cls: The registered optimizer class, unchanged.

    Raises:
        TypeError: If ``cls`` does not inherit from :class:`BaseGaussianSplatOptimizer`.
    """
    if not issubclass(cls, BaseGaussianSplatOptimizer):
        raise TypeError(f"Optimizer {cls} must inherit from BaseGaussianSplatOptimizer.")

    # A plain assignment already overwrites any prior registration; the original
    # del-then-set only changed dict insertion order, which lookups never use.
    REGISTERED_OPTIMIZERS[cls.name()] = cls

    return cls


class BaseGaussianSplatOptimizer(ABC):
"""
Expand All @@ -19,6 +46,15 @@ class BaseGaussianSplatOptimizer(ABC):
`original Gaussian Splatting paper <https://arxiv.org/abs/2308.04079>`_.
"""

@classmethod
def name(cls) -> str:
    """
    Return the stable identifier used for optimizer (de)serialization and registry lookup.

    Defaults to the class's ``__name__``; subclasses may override this when a
    different stable identifier is needed.
    """
    return cls.__name__

@classmethod
@abstractmethod
def from_state_dict(cls, model: GaussianSplat3d, state_dict: dict[str, Any]) -> "BaseGaussianSplatOptimizer":
Expand All @@ -32,7 +68,14 @@ def from_state_dict(cls, model: GaussianSplat3d, state_dict: dict[str, Any]) ->
Returns:
optimizer (BaseGaussianSplatOptimizer): A new :class:`BaseGaussianSplatOptimizer` instance.
"""
pass
OptimizerType = REGISTERED_OPTIMIZERS.get(state_dict["name"], None)
if OptimizerType is None:
raise ValueError(
f"Optimizer '{state_dict['name']}' is not registered. Optimizer classes must be registered "
f"with the `splat_optimizer` decorator which will be called when the optimizer is defined. "
f"Ensure the optimizer class uses the `splat_optimizer` decorator and was imported before calling from_state_dict."
)
return OptimizerType.from_state_dict(model, state_dict)

@abstractmethod
def state_dict(self) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from fvdb_reality_capture.sfm_scene import SfmScene

from .base_gaussian_splat_optimizer import BaseGaussianSplatOptimizer
from .base_gaussian_splat_optimizer import BaseGaussianSplatOptimizer, splat_optimizer


class InsertionGrad2dThresholdMode(str, Enum):
Expand Down Expand Up @@ -248,7 +248,15 @@ class GaussianSplatOptimizerConfig:
Learning rate for the specular spherical harmonics (order > 0).
"""

def make_optimizer(self, model: GaussianSplat3d, sfm_scene: SfmScene) -> BaseGaussianSplatOptimizer:
    """
    Construct a :class:`GaussianSplatOptimizer` for ``model`` from this configuration.

    Args:
        model: The Gaussian splat model the optimizer will update.
        sfm_scene: The SfM scene forwarded to the optimizer factory.

    Returns:
        A new :class:`GaussianSplatOptimizer` built via ``from_model_and_scene``.
    """
    return GaussianSplatOptimizer.from_model_and_scene(model=model, sfm_scene=sfm_scene, config=self)


@splat_optimizer
class GaussianSplatOptimizer(BaseGaussianSplatOptimizer):
"""
Optimizer for reconstructing a scene using Gaussian Splat radiance fields over a collection of posed images.
Expand Down Expand Up @@ -324,6 +332,7 @@ def __init__(
# This hook corrects the count even if backward() is called multiple times per iteration.
self._num_grad_accumulation_steps = 1 # Number of times we've called backward since zeroing the gradients

@torch.utils.hooks.unserializable_hook
def _count_accumulation_steps_backward_hook(_):
self._num_grad_accumulation_steps += 1

Expand Down Expand Up @@ -497,6 +506,7 @@ def state_dict(self) -> dict[str, Any]:
state_dict (dict[str, Any]): A state dict containing the state of the optimizer.
"""
return {
"name": self.__class__.name(),
"optimizer": self._optimizer.state_dict(),
"means_lr_decay_exponent": self._means_lr_decay_exponent,
"insertion_grad_2d_abs_threshold": self._insertion_grad_2d_abs_threshold,
Expand Down
Loading
Loading