Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fvdb/_fvdb_cpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class GaussianSplat3d:
n_max: int,
min_opacity: float,
) -> tuple[torch.Tensor, torch.Tensor]: ...
def add_noise_to_means(self, noise_scale: float) -> None: ...
def add_noise_to_means(self, noise_scale: float, t: float = ..., k: float = ...) -> None: ...
def reset_accumulated_gradient_state(self) -> None: ...
def save_ply(self, filename: str, metadata: dict[str, str | int | float | torch.Tensor] | None) -> None: ...
@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions fvdb/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,14 +2592,16 @@ def relocate_gaussians(
min_opacity,
)

def add_noise_to_means(self, noise_scale: float) -> None:
def add_noise_to_means(self, noise_scale: float, t: float = 0.005, k: float = 100.0) -> None:
"""
Add noise to the Gaussian positions (means), scaled by ``noise_scale``.

Args:
noise_scale (float): Noise scale factor applied to scale-dependent noise.
t (float): Parameter t for noise scaling. Defaults to 0.005.
k (float): Parameter k for noise scaling. Defaults to 100.0.
"""
self._impl.add_noise_to_means(noise_scale)
self._impl.add_noise_to_means(noise_scale, t, k)

def reset_accumulated_gradient_state(self) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/fvdb/GaussianSplat3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,10 +1066,10 @@ GaussianSplat3d::relocateGaussians(const torch::Tensor &logScales,
}

void
GaussianSplat3d::addNoiseToMeans(const float noiseScale) {
GaussianSplat3d::addNoiseToMeans(const float noiseScale, const float t, const float k) {
FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() {
return detail::ops::dispatchGaussianMCMCAddNoise<DeviceTag>(
mMeans, mLogScales, mLogitOpacities, mQuats, noiseScale);
mMeans, mLogScales, mLogitOpacities, mQuats, noiseScale, t, k);
});
}

Expand Down
4 changes: 3 additions & 1 deletion src/fvdb/GaussianSplat3d.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,9 @@ class GaussianSplat3d {

/// @brief Add noise to the Gaussian positions (means), scaled by noiseScale.
/// @param noiseScale Noise scale
void addNoiseToMeans(const float noiseScale);
/// @param t Cutoff for opacity scaling
/// @param k Exponent for opacity scaling
void addNoiseToMeans(const float noiseScale, const float t = 0.005, const float k = 100.0);

/// @brief Select a subset of the Gaussians in this scene based on the given slice.
/// @param begin The start index of the slice (inclusive)
Expand Down
41 changes: 29 additions & 12 deletions src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace fvdb::detail::ops {

template <typename ScalarType>
inline __device__ ScalarType
logistic(ScalarType x, ScalarType k = 100, ScalarType x0 = 0.995) {
return 1 / (1 + exp(-k * (x - x0)));
sigmoid(ScalarType x) {
return ScalarType(1) / (ScalarType(1) + ::cuda::std::exp(-x));
}

template <typename ScalarType>
Expand All @@ -27,23 +27,30 @@ gaussianMCMCAddNoiseKernel(fvdb::TorchRAcc64<ScalarType, 2> outMeans,
fvdb::TorchRAcc64<ScalarType, 1> logitOpacities,
fvdb::TorchRAcc64<ScalarType, 2> quats,
fvdb::TorchRAcc64<ScalarType, 2> baseNoise,
ScalarType noiseScale) {
const ScalarType noiseScale,
const ScalarType t,
const ScalarType k) {
const auto N = outMeans.size(0);
for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
idx += blockDim.x * gridDim.x) {
auto opacity = ScalarType(1.0) / (1 + exp(-logitOpacities[idx]));
const auto opacity = sigmoid(logitOpacities[idx]);

const auto quatAcc = quats[idx];
const auto logScaleAcc = logScales[idx];
auto covar = quaternionAndScaleToCovariance<ScalarType>(
const auto covar = quaternionAndScaleToCovariance<ScalarType>(
nanovdb::math::Vec4<ScalarType>(quatAcc[0], quatAcc[1], quatAcc[2], quatAcc[3]),
nanovdb::math::Vec3<ScalarType>(::cuda::std::exp(logScaleAcc[0]),
::cuda::std::exp(logScaleAcc[1]),
::cuda::std::exp(logScaleAcc[2])));

nanovdb::math::Vec3<ScalarType> noise = {
baseNoise[idx][0], baseNoise[idx][1], baseNoise[idx][2]};
noise *= logistic(1 - opacity) * noiseScale;

// The noise term is scaled down based on the opacity of the Gaussian.
// More opaque Gaussians get less noise added to them.
// The parameters t and k control the transition point and sharpness
// of the scaling function.
noise *= sigmoid(-k * (opacity - t)) * noiseScale;
noise = covar * noise;
outMeans[idx][0] += noise[0];
outMeans[idx][1] += noise[1];
Expand All @@ -57,7 +64,9 @@ launchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3]
const torch::Tensor &logScales, // [N, 3]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
ScalarType noiseScale) {
const ScalarType noiseScale,
const ScalarType t,
const ScalarType k) {
const auto N = means.size(0);

const int blockDim = DEFAULT_BLOCK_DIM;
Expand All @@ -72,7 +81,9 @@ launchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3]
logitOpacities.packed_accessor64<ScalarType, 1, torch::RestrictPtrTraits>(),
quats.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
baseNoise.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
noiseScale);
noiseScale,
t,
k);

C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand All @@ -83,13 +94,15 @@ dispatchGaussianMCMCAddNoise<torch::kCUDA>(torch::Tensor &means,
const torch::Tensor &logScales, // [N]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
float noiseScale) { // [N]
const float noiseScale,
const float t,
const float k) {
FVDB_FUNC_RANGE();
const at::cuda::OptionalCUDAGuard device_guard(device_of(means));

const auto N = means.size(0);

launchGaussianMCMCAddNoise<float>(means, logScales, logitOpacities, quats, noiseScale);
launchGaussianMCMCAddNoise<float>(means, logScales, logitOpacities, quats, noiseScale, t, k);
}

template <>
Expand All @@ -98,7 +111,9 @@ dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(torch::Tensor &means, // [N, 3
const torch::Tensor &logScales, // [N, 3]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
float noiseScale) { // [N]
const float noiseScale,
const float t,
const float k) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for PrivateUse1");
}

Expand All @@ -108,7 +123,9 @@ dispatchGaussianMCMCAddNoise<torch::kCPU>(torch::Tensor &means, // [N,
const torch::Tensor &logScales, // [N, 3]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
float noiseScale) { // [N]
const float noiseScale,
const float t,
const float k) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for CPU");
}

Expand Down
4 changes: 3 additions & 1 deletion src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ void dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3]
const torch::Tensor &logScales, // [N]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
float noiseScale); // [N]
const float noiseScale,
const float t,
const float k);

} // namespace ops
} // namespace detail
Expand Down
6 changes: 5 additions & 1 deletion src/python/GaussianSplatBinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,11 @@ bind_gaussian_splat3d(py::module &m) {
py::arg("n_max"),
py::arg("min_opacity"))

.def("add_noise_to_means", &fvdb::GaussianSplat3d::addNoiseToMeans, py::arg("noise_scale"))
.def("add_noise_to_means",
&fvdb::GaussianSplat3d::addNoiseToMeans,
py::arg("noise_scale"),
py::arg("t") = 0.005,
py::arg("k") = 100.0)

.def("index_select", &fvdb::GaussianSplat3d::indexSelect, py::arg("indices"))
.def("mask_select", &fvdb::GaussianSplat3d::maskSelect, py::arg("mask"))
Expand Down
12 changes: 6 additions & 6 deletions src/tests/GaussianMCMCAddNoiseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ TEST_F(GaussianMCMCAddNoiseTest, AppliesNoiseWithDeterministicBaseNoise) {
auto meansBaseline = means.clone();

fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
means, logScales, logitOpacities, quats, noiseScale);
means, logScales, logitOpacities, quats, noiseScale, 0.005, 100.0);

restoreCudaGeneratorState(rngState);
const auto baseNoise = torch::randn_like(meansBaseline);
Expand All @@ -91,7 +91,7 @@ TEST_F(GaussianMCMCAddNoiseTest, RespectsAnisotropicScales) {
const auto rngState = saveCudaGeneratorState();

fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
means, logScales, logitOpacities, quats, noiseScale);
means, logScales, logitOpacities, quats, noiseScale, 0.005, 100);

restoreCudaGeneratorState(rngState);
const auto baseNoise = torch::randn_like(means);
Expand All @@ -115,7 +115,7 @@ TEST_F(GaussianMCMCAddNoiseTest, HighOpacitySuppressesNoise) {
constexpr float noiseScale = 1.0f;

fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
means, logScales, logitOpacities, quats, noiseScale);
means, logScales, logitOpacities, quats, noiseScale, 0.005, 100.0);

// Gate approaches zero when opacity ~1; expect negligible movement.
const auto maxAbs = torch::abs(means).max().item<float>();
Expand All @@ -133,7 +133,7 @@ TEST_F(GaussianMCMCAddNoiseTest, ZeroNoiseScaleNoOp) {
floatOpts());

fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
means, logScales, logitOpacities, quats, /*noiseScale=*/0.0f);
means, logScales, logitOpacities, quats, /*noiseScale=*/0.0f, 0.005, 100.0);

EXPECT_TRUE(torch::allclose(means, origMeans));
}
Expand All @@ -146,15 +146,15 @@ TEST_F(GaussianMCMCAddNoiseTest, CpuAndPrivateUseNotImplemented) {
torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, fvdb::test::tensorOpts<float>(torch::kCPU));

EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCPU>(
means, logScales, logitOpacities, quats, 1.0f)),
means, logScales, logitOpacities, quats, 1.0f, 0.005, 100)),
c10::Error);

auto meansCuda = means.cuda();
auto logScalesCuda = logScales.cuda();
auto logitOpacitiesCuda = logitOpacities.cuda();
auto quatsCuda = quats.cuda();
EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(
meansCuda, logScalesCuda, logitOpacitiesCuda, quatsCuda, 1.0f)),
meansCuda, logScalesCuda, logitOpacitiesCuda, quatsCuda, 1.0f, 0.005, 100.0)),
c10::Error);
}

Expand Down
Loading