Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ set(FVDB_CU_FILES
fvdb/detail/ops/GridEdgeNetwork.cu
fvdb/detail/ops/gsplat/FusedSSIM.cu
fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu
fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu
fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu
fvdb/detail/ops/gsplat/GaussianProjectionForward.cu
fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu
Expand Down
115 changes: 115 additions & 0 deletions src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0
//

#include <fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h>
#include <fvdb/detail/ops/gsplat/GaussianUtils.cuh>
#include <fvdb/detail/utils/AccessorHelpers.cuh>
#include <fvdb/detail/utils/Nvtx.h>
#include <fvdb/detail/utils/cuda/GridDim.h>

#include <nanovdb/math/Math.h>

#include <c10/cuda/CUDAGuard.h>

namespace fvdb::detail::ops {

/// Logistic (sigmoid) curve: 1 / (1 + e^{-k (x - x0)}).
///
/// @param x   Value at which the curve is evaluated.
/// @param k   Steepness of the transition (defaults to 100).
/// @param x0  Midpoint of the transition, where the result is 0.5 (defaults to 0.995).
/// @return    The logistic response for x.
template <typename ScalarType>
inline __device__ ScalarType
logistic(ScalarType x, ScalarType k = 100, ScalarType x0 = 0.995) {
    const ScalarType offsetFromMidpoint = x - x0;
    const ScalarType decay              = exp(-k * offsetFromMidpoint);
    return 1 / (1 + decay);
}

/// Perturbs each Gaussian mean in place with opacity-gated, covariance-shaped noise.
///
/// For every row i of outMeans:
///   opacity      = 1 / (1 + exp(-logitOpacities[i]))
///   covar        = quaternionAndScaleToCovariance(quats[i], exp(logScales[i]))
///   outMeans[i] += covar * (baseNoise[i] * logistic(1 - opacity) * noiseScale)
///
/// Launch layout: 1D grid of 1D blocks. Each thread starts at its global index and advances by
/// the total number of launched threads, so any grid size covers every row. No shared memory.
///
/// Precondition: every accessor has at least outMeans.size(0) rows; no bounds are re-checked here.
///
/// @param outMeans        Means, read and updated in place; columns 0..2 of each row are used.
/// @param logScales       Per-Gaussian log-scales; columns 0..2 are exponentiated.
/// @param logitOpacities  Per-Gaussian opacity logits.
/// @param quats           Per-Gaussian quaternions; columns 0..3 are forwarded in order.
/// @param baseNoise       Per-Gaussian base noise; columns 0..2 are used.
/// @param noiseScale      Global multiplier applied to the gated noise.
template <typename ScalarType>
__global__ void
gaussianMCMCAddNoiseKernel(fvdb::TorchRAcc64<ScalarType, 2> outMeans,
                           fvdb::TorchRAcc64<ScalarType, 2> logScales,
                           fvdb::TorchRAcc64<ScalarType, 1> logitOpacities,
                           fvdb::TorchRAcc64<ScalarType, 2> quats,
                           fvdb::TorchRAcc64<ScalarType, 2> baseNoise,
                           ScalarType noiseScale) {
    const int64_t N = outMeans.size(0);

    // Index and stride are computed in 64 bits: a 32-bit counter could wrap around before
    // reaching N and revisit rows (or never terminate) for very large inputs.
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < N;
         idx += stride) {
        // Sigmoid of the stored logit.
        const auto opacity = ScalarType(1.0) / (1 + exp(-logitOpacities[idx]));

        const auto quatAcc     = quats[idx];
        const auto logScaleAcc = logScales[idx];
        const auto covar       = quaternionAndScaleToCovariance<ScalarType>(
            nanovdb::math::Vec4<ScalarType>(quatAcc[0], quatAcc[1], quatAcc[2], quatAcc[3]),
            nanovdb::math::Vec3<ScalarType>(::cuda::std::exp(logScaleAcc[0]),
                                            ::cuda::std::exp(logScaleAcc[1]),
                                            ::cuda::std::exp(logScaleAcc[2])));

        const auto noiseAcc                   = baseNoise[idx];
        nanovdb::math::Vec3<ScalarType> noise = {noiseAcc[0], noiseAcc[1], noiseAcc[2]};

        // The gate is evaluated at (1 - opacity), so it shrinks toward zero as opacity nears 1.
        noise *= logistic(1 - opacity) * noiseScale;
        noise = covar * noise;

        auto meanAcc = outMeans[idx];
        meanAcc[0] += noise[0];
        meanAcc[1] += noise[1];
        meanAcc[2] += noise[2];
    }
}

/// Draws standard-normal base noise shaped like `means` and launches
/// gaussianMCMCAddNoiseKernel on the current CUDA stream to perturb `means` in place.
///
/// The launch is asynchronous; only launch-time errors are checked here.
///
/// @param means           [N, 3] means, updated in place.
/// @param logScales       [N, 3] per-Gaussian log-scales.
/// @param logitOpacities  [N] per-Gaussian opacity logits.
/// @param quats           [N, 4] per-Gaussian quaternions.
/// @param noiseScale      Global multiplier applied to the gated noise.
template <typename ScalarType>
void
launchGaussianMCMCAddNoise(torch::Tensor &means,                // [N, 3]
                           const torch::Tensor &logScales,      // [N, 3]
                           const torch::Tensor &logitOpacities, // [N]
                           const torch::Tensor &quats,          // [N, 4]
                           ScalarType noiseScale) {
    const auto N = means.size(0);

    // Nothing to perturb. Returning early also avoids passing a zero element count to the
    // grid-size computation and the kernel launch.
    if (N == 0) {
        return;
    }

    const int blockDim                = DEFAULT_BLOCK_DIM;
    const int gridDim                 = fvdb::GET_BLOCKS(N, blockDim);
    const at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

    // One standard-normal sample per mean component, drawn from torch's RNG.
    auto baseNoise = torch::randn_like(means);

    gaussianMCMCAddNoiseKernel<ScalarType><<<gridDim, blockDim, 0, stream>>>(
        means.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        logScales.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        logitOpacities.packed_accessor64<ScalarType, 1, torch::RestrictPtrTraits>(),
        quats.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        baseNoise.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        noiseScale);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/// CUDA entry point: validates the inputs, then adds opacity-gated, covariance-shaped noise to
/// `means` in place using single-precision arithmetic.
///
/// Raises c10::Error (via TORCH_CHECK) when a shape or device does not match the layout the
/// kernel indexes with; without these checks a mismatched tensor would be read out of bounds.
///
/// @param means           [N, 3] means, updated in place.
/// @param logScales       [N, 3] per-Gaussian log-scales.
/// @param logitOpacities  [N] per-Gaussian opacity logits.
/// @param quats           [N, 4] per-Gaussian quaternions.
/// @param noiseScale      Scalar multiplier applied to the gated noise.
template <>
void
dispatchGaussianMCMCAddNoise<torch::kCUDA>(torch::Tensor &means,                // [N, 3] input/output
                                           const torch::Tensor &logScales,      // [N, 3]
                                           const torch::Tensor &logitOpacities, // [N]
                                           const torch::Tensor &quats,          // [N, 4]
                                           float noiseScale) {                  // scalar
    FVDB_FUNC_RANGE();

    TORCH_CHECK(means.is_cuda(), "means must be a CUDA tensor");
    TORCH_CHECK(means.dim() == 2 && means.size(1) == 3, "means must have shape [N, 3]");

    // The kernel indexes every input by the row index of `means`, so row counts must agree.
    const auto numGaussians = means.size(0);
    TORCH_CHECK(logScales.dim() == 2 && logScales.size(0) == numGaussians &&
                    logScales.size(1) == 3,
                "logScales must have shape [N, 3]");
    TORCH_CHECK(logitOpacities.dim() == 1 && logitOpacities.size(0) == numGaussians,
                "logitOpacities must have shape [N]");
    TORCH_CHECK(quats.dim() == 2 && quats.size(0) == numGaussians && quats.size(1) == 4,
                "quats must have shape [N, 4]");
    TORCH_CHECK(logScales.device() == means.device() &&
                    logitOpacities.device() == means.device() &&
                    quats.device() == means.device(),
                "logScales, logitOpacities and quats must be on the same device as means");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(means));

    launchGaussianMCMCAddNoise<float>(means, logScales, logitOpacities, quats, noiseScale);
}

/// PrivateUse1 specialization: the operation is not available on this backend, so the call
/// always raises a not-implemented error and never touches its arguments.
template <>
void
dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(
    torch::Tensor &means,                // [N, 3] input/output
    const torch::Tensor &logScales,      // [N, 3]
    const torch::Tensor &logitOpacities, // [N]
    const torch::Tensor &quats,          // [N, 4]
    float noiseScale) {                  // scalar
    TORCH_CHECK_NOT_IMPLEMENTED(false,
                                "GaussianMCMCAddNoise is not implemented for PrivateUse1");
}

/// CPU specialization: the operation has no host implementation, so the call always raises a
/// not-implemented error and never touches its arguments.
template <>
void
dispatchGaussianMCMCAddNoise<torch::kCPU>(
    torch::Tensor &means,                // [N, 3] input/output
    const torch::Tensor &logScales,      // [N, 3]
    const torch::Tensor &logitOpacities, // [N]
    const torch::Tensor &quats,          // [N, 4]
    float noiseScale) {                  // scalar
    TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for CPU");
}

} // namespace fvdb::detail::ops
25 changes: 25 additions & 0 deletions src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0
//

#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H
#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H

#include <torch/types.h>

namespace fvdb {
namespace detail {
namespace ops {

// Adds noise to the Gaussian means in place. Declaration only: the behavior is supplied by
// explicit specializations per device type. In the accompanying .cu file the kCUDA
// specialization performs the update, while the kCPU and kPrivateUse1 specializations raise a
// not-implemented error.
//
// The CUDA path reads three columns of logScales and four columns of quats per Gaussian, and
// treats noiseScale as a single scalar multiplier.
template <torch::DeviceType DeviceType>
void dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] input/output
const torch::Tensor &logScales, // [N, 3]
const torch::Tensor &logitOpacities, // [N]
const torch::Tensor &quats, // [N, 4]
float noiseScale); // scalar noise multiplier

} // namespace ops
} // namespace detail
} // namespace fvdb

#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H
1 change: 1 addition & 0 deletions src/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ ConfigureTest(GaussianProjectionForwardTest "GaussianProjectionForwardTest.cpp")
ConfigureTest(GaussianProjectionBackwardTest "GaussianProjectionBackwardTest.cpp")
ConfigureTest(GaussianRasterizeTopContributorsTest "GaussianRasterizeTopContributorsTest.cpp")
ConfigureTest(GaussianRasterizeContributingGaussianIdsTest "GaussianRasterizeContributingGaussianIdsTest.cpp")
ConfigureTest(GaussianMCMCAddNoiseTest "GaussianMCMCAddNoiseTest.cpp")
ConfigureTest(GaussianRelocationTest "GaussianRelocationTest.cpp")

if(NOT NANOVDB_EDITOR_SKIP)
Expand Down
161 changes: 161 additions & 0 deletions src/tests/GaussianMCMCAddNoiseTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0

#include "utils/Tensor.h"

#include <fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h>

#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <torch/torch.h>

#include <gtest/gtest.h>

#include <cmath>

namespace {

// Host-side reference for the kernel's logistic gate, using the same constants as the kernel
// defaults (steepness k = 100, midpoint x0 = 0.995).
torch::Tensor
logisticTensor(const torch::Tensor &x) {
    const auto exponent = -100.0f * (x - 0.995f);
    return 1.0f / (1.0f + torch::exp(exponent));
}

// Fixture shared by the GaussianMCMCAddNoise tests: skips when CUDA is unavailable, seeds
// torch's RNG, and offers helpers to snapshot/restore the default CUDA generator state.
class GaussianMCMCAddNoiseTest : public ::testing::Test {
  protected:
    void
    SetUp() override {
        const bool cudaAvailable = torch::cuda::is_available();
        if (!cudaAvailable) {
            GTEST_SKIP() << "CUDA is required for GaussianMCMCAddNoise tests";
        }
        torch::manual_seed(0);
    }

    // Tensor options for float tensors on the CUDA device.
    torch::TensorOptions
    floatOpts() const {
        return fvdb::test::tensorOpts<float>(torch::kCUDA);
    }

    // Snapshot the default CUDA generator state. Restoring it later lets a test redraw the
    // same baseNoise that dispatchGaussianMCMCAddNoise sampled internally.
    torch::Tensor
    saveCudaGeneratorState() {
        auto defaultGenerator = at::cuda::detail::getDefaultCUDAGenerator();
        return defaultGenerator.get_state();
    }

    // Rewind the default CUDA generator to a previously captured state.
    void
    restoreCudaGeneratorState(const torch::Tensor &state) {
        auto defaultGenerator = at::cuda::detail::getDefaultCUDAGenerator();
        defaultGenerator.set_state(state);
    }
};

// Replays the RNG to recover the noise the op drew, then checks the in-place update against a
// CPU reference. With zero log-scales and these quaternions the test treats the covariance as
// the identity, so the expected delta is simply gate * noiseScale * baseNoise.
TEST_F(GaussianMCMCAddNoiseTest, AppliesNoiseWithDeterministicBaseNoise) {
    const auto opts = floatOpts();

    auto means = torch::tensor({{0.0f, 0.0f, 0.0f}, {1.0f, 2.0f, 3.0f}}, opts).contiguous();
    const auto logScales = torch::zeros({2, 3}, opts).contiguous(); // unit covariance
    const auto quats =
        torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, opts).contiguous();

    // Build logits from known opacities so the reference gate can be computed directly.
    const auto opacities      = torch::tensor({0.25f, 0.6f}, opts);
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);

    constexpr float noiseScale = 0.4f;

    // Capture the generator state and the untouched means before the op mutates them.
    const auto stateBeforeOp = saveCudaGeneratorState();
    auto meansBaseline       = means.clone();

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    // Rewind the generator and redraw: this reproduces the op's internal base noise.
    restoreCudaGeneratorState(stateBeforeOp);
    const auto baseNoise = torch::randn_like(meansBaseline);

    // CPU reference.
    auto opacitiesHost  = opacities.cpu();
    auto gate           = logisticTensor(torch::ones_like(opacitiesHost) - opacitiesHost); // [N]
    auto expectedDelta  = baseNoise.cpu() * gate.unsqueeze(1) * noiseScale;              // [N,3]
    const auto expected = meansBaseline.cpu() + expectedDelta;

    EXPECT_TRUE(torch::allclose(means.cpu(), expected, 1e-5, 1e-6));
}

// Checks per-axis scaling: the test expects the displacement along each axis to be the gated
// noise multiplied by the squared scale of that axis (diagonal covariance).
TEST_F(GaussianMCMCAddNoiseTest, RespectsAnisotropicScales) {
    const auto opts = floatOpts();

    auto means = torch::zeros({1, 3}, opts).contiguous();

    // Log-scales for per-axis scales of 2, 1 and 0.5.
    const auto flatLogScales =
        torch::tensor({std::log(2.0f), std::log(1.0f), std::log(0.5f)}, opts);
    const auto logScales = flatLogScales.view({1, 3}).contiguous();

    const auto opacities      = torch::tensor({0.3f}, opts);
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);
    const auto quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, opts).contiguous();

    constexpr float noiseScale = 1.0f;

    const auto stateBeforeOp = saveCudaGeneratorState();

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    // Rewind the generator and redraw to recover the op's internal base noise.
    restoreCudaGeneratorState(stateBeforeOp);
    const auto baseNoise = torch::randn_like(means);

    auto gate = logisticTensor(torch::ones({1}, torch::kFloat32) - opacities.cpu()); // [1]
    const auto covarDiag    = torch::pow(torch::exp(logScales.cpu()), 2);           // [1,3]
    const auto initialMeans = torch::zeros_like(baseNoise.cpu());
    const auto expected =
        initialMeans + (baseNoise.cpu() * gate.unsqueeze(1) * noiseScale) * covarDiag;

    // With identity rotation, covariance is diagonal; check elementwise scaling.
    EXPECT_TRUE(torch::allclose(means.cpu(), expected, 1e-5, 1e-6));
}

// Nearly opaque Gaussians should barely move: the gate is evaluated at (1 - opacity), which is
// far below the logistic midpoint when opacity is close to 1.
TEST_F(GaussianMCMCAddNoiseTest, HighOpacitySuppressesNoise) {
    const auto opts = floatOpts();

    auto means                = torch::zeros({2, 3}, opts).contiguous();
    const auto logScales      = torch::zeros({2, 3}, opts).contiguous();
    const auto logitOpacities = torch::full({2}, 10.0f, opts); // opacity ~ 1
    const auto quats =
        torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, opts).contiguous();

    constexpr float noiseScale = 1.0f;

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    // Means started at zero, so the largest absolute entry is the largest displacement.
    const auto largestDisplacement = torch::abs(means).max().item<float>();
    EXPECT_LT(largestDisplacement, 1e-5f);
}

// A noise scale of zero must leave the means unchanged, whatever the opacities are.
TEST_F(GaussianMCMCAddNoiseTest, ZeroNoiseScaleNoOp) {
    const auto opts = floatOpts();

    auto means           = torch::rand({3, 3}, opts).contiguous();
    const auto origMeans = means.clone();

    const auto logScales      = torch::zeros({3, 3}, opts).contiguous();
    const auto opacities      = torch::tensor({0.2f, 0.5f, 0.8f}, opts);
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);
    const auto quats          = torch::tensor(
        {{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, opts);

    constexpr float zeroScale = 0.0f;
    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, zeroScale);

    EXPECT_TRUE(torch::allclose(means, origMeans));
}

// The CPU and PrivateUse1 specializations are stubs; both must raise c10::Error.
TEST_F(GaussianMCMCAddNoiseTest, CpuAndPrivateUseNotImplemented) {
    const auto cpuOpts = fvdb::test::tensorOpts<float>(torch::kCPU);

    auto means                = torch::zeros({1, 3}, cpuOpts);
    const auto logScales      = torch::zeros({1, 3}, cpuOpts);
    const auto logitOpacities = torch::zeros({1}, cpuOpts);
    const auto quats          = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, cpuOpts);

    // CPU specialization, called with host tensors.
    EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCPU>(
                     means, logScales, logitOpacities, quats, 1.0f)),
                 c10::Error);

    // PrivateUse1 specialization, called with device copies of the same inputs.
    auto meansCuda          = means.cuda();
    auto logScalesCuda      = logScales.cuda();
    auto logitOpacitiesCuda = logitOpacities.cuda();
    auto quatsCuda          = quats.cuda();
    EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(
                     meansCuda, logScalesCuda, logitOpacitiesCuda, quatsCuda, 1.0f)),
                 c10::Error);
}

} // namespace
Loading