Commit 3554e21

MCMC add noise kernel and gtests (#377)
Fixes #375. Implements the MCMC optimizer "add noise to gaussian positions" operation in a CUDA kernel.

Signed-off-by: Mark Harris <mharris@nvidia.com>
1 parent 53f374a commit 3554e21
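
For context, the update the new kernel applies to each Gaussian mean (a summary of the code below; it assumes the GaussianUtils helper quaternionAndScaleToCovariance builds the covariance as Sigma_i = R(q_i) diag(exp(l_i))^2 R(q_i)^T from quaternion q_i and log-scales l_i):

$$
\mu_i \;\leftarrow\; \mu_i \;+\; \lambda \,\sigma_{k,x_0}\!\left(1 - o_i\right)\, \Sigma_i\, \varepsilon_i,
\qquad \varepsilon_i \sim \mathcal{N}(0, I_3),\quad
o_i = \operatorname{sigmoid}(\mathrm{logitOpacity}_i),
$$

$$
\sigma_{k,x_0}(x) = \frac{1}{1 + e^{-k (x - x_0)}}, \qquad k = 100,\; x_0 = 0.995,
$$

where lambda is noiseScale. The sharp logistic gate drives the perturbation to zero for nearly opaque Gaussians and leaves low-opacity Gaussians free to move.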

File tree

5 files changed, +303 -0 lines changed


src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ set(FVDB_CU_FILES
     fvdb/detail/ops/GridEdgeNetwork.cu
     fvdb/detail/ops/gsplat/FusedSSIM.cu
     fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu
+    fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu
     fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu
     fvdb/detail/ops/gsplat/GaussianProjectionForward.cu
     fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu

src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0
//

#include <fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h>
#include <fvdb/detail/ops/gsplat/GaussianUtils.cuh>
#include <fvdb/detail/utils/AccessorHelpers.cuh>
#include <fvdb/detail/utils/Nvtx.h>
#include <fvdb/detail/utils/cuda/GridDim.h>

#include <nanovdb/math/Math.h>

#include <c10/cuda/CUDAGuard.h>

namespace fvdb::detail::ops {

template <typename ScalarType>
inline __device__ ScalarType
logistic(ScalarType x, ScalarType k = 100, ScalarType x0 = 0.995) {
    return 1 / (1 + exp(-k * (x - x0)));
}

template <typename ScalarType>
__global__ void
gaussianMCMCAddNoiseKernel(fvdb::TorchRAcc64<ScalarType, 2> outMeans,
                           fvdb::TorchRAcc64<ScalarType, 2> logScales,
                           fvdb::TorchRAcc64<ScalarType, 1> logitOpacities,
                           fvdb::TorchRAcc64<ScalarType, 2> quats,
                           fvdb::TorchRAcc64<ScalarType, 2> baseNoise,
                           ScalarType noiseScale) {
    const auto N = outMeans.size(0);
    for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
         idx += blockDim.x * gridDim.x) {
        auto opacity = ScalarType(1.0) / (1 + exp(-logitOpacities[idx]));

        const auto quatAcc     = quats[idx];
        const auto logScaleAcc = logScales[idx];
        auto covar             = quaternionAndScaleToCovariance<ScalarType>(
            nanovdb::math::Vec4<ScalarType>(quatAcc[0], quatAcc[1], quatAcc[2], quatAcc[3]),
            nanovdb::math::Vec3<ScalarType>(::cuda::std::exp(logScaleAcc[0]),
                                            ::cuda::std::exp(logScaleAcc[1]),
                                            ::cuda::std::exp(logScaleAcc[2])));

        nanovdb::math::Vec3<ScalarType> noise = {
            baseNoise[idx][0], baseNoise[idx][1], baseNoise[idx][2]};
        noise *= logistic(1 - opacity) * noiseScale;
        noise = covar * noise;
        outMeans[idx][0] += noise[0];
        outMeans[idx][1] += noise[1];
        outMeans[idx][2] += noise[2];
    }
}

template <typename ScalarType>
void
launchGaussianMCMCAddNoise(torch::Tensor &means,                 // [N, 3]
                           const torch::Tensor &logScales,       // [N, 3]
                           const torch::Tensor &logitOpacities,  // [N]
                           const torch::Tensor &quats,           // [N, 4]
                           ScalarType noiseScale) {
    const auto N = means.size(0);

    const int blockDim = DEFAULT_BLOCK_DIM;
    const int gridDim  = fvdb::GET_BLOCKS(N, blockDim);
    const at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

    auto baseNoise = torch::randn_like(means);

    gaussianMCMCAddNoiseKernel<ScalarType><<<gridDim, blockDim, 0, stream>>>(
        means.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        logScales.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        logitOpacities.packed_accessor64<ScalarType, 1, torch::RestrictPtrTraits>(),
        quats.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        baseNoise.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>(),
        noiseScale);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <>
void
dispatchGaussianMCMCAddNoise<torch::kCUDA>(torch::Tensor &means,                 // [N, 3]
                                           const torch::Tensor &logScales,       // [N, 3]
                                           const torch::Tensor &logitOpacities,  // [N]
                                           const torch::Tensor &quats,           // [N, 4]
                                           float noiseScale) {
    FVDB_FUNC_RANGE();
    const at::cuda::OptionalCUDAGuard device_guard(device_of(means));

    const auto N = means.size(0);

    launchGaussianMCMCAddNoise<float>(means, logScales, logitOpacities, quats, noiseScale);
}

template <>
void
dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(torch::Tensor &means,                 // [N, 3] input/output
                                                  const torch::Tensor &logScales,       // [N, 3]
                                                  const torch::Tensor &logitOpacities,  // [N]
                                                  const torch::Tensor &quats,           // [N, 4]
                                                  float noiseScale) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for PrivateUse1");
}

template <>
void
dispatchGaussianMCMCAddNoise<torch::kCPU>(torch::Tensor &means,                 // [N, 3] input/output
                                          const torch::Tensor &logScales,       // [N, 3]
                                          const torch::Tensor &logitOpacities,  // [N]
                                          const torch::Tensor &quats,           // [N, 4]
                                          float noiseScale) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for CPU");
}

} // namespace fvdb::detail::ops
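
For readers checking the math, here is a minimal CPU sketch of the same update for the identity-rotation case, where the covariance reduces to diag(exp(logScales))^2. The helper referenceAddNoiseIdentityRotation is hypothetical and not part of this commit; it mirrors the expectation the gtests below compute.

#include <torch/torch.h>

// Hypothetical CPU reference (not in this commit): identity rotation only, so the
// covariance is diagonal, diag(exp(logScales))^2. Returns the perturbed means.
torch::Tensor
referenceAddNoiseIdentityRotation(const torch::Tensor &means,          // [N, 3]
                                  const torch::Tensor &logScales,      // [N, 3]
                                  const torch::Tensor &logitOpacities, // [N]
                                  const torch::Tensor &baseNoise,      // [N, 3], ~ N(0, I)
                                  float noiseScale) {
    const auto opacity = torch::sigmoid(logitOpacities); // [N]
    // Same sharp gate as the kernel: logistic(1 - opacity) with k = 100, x0 = 0.995.
    const auto gate = 1.0f / (1.0f + torch::exp(-100.0f * ((1.0f - opacity) - 0.995f))); // [N]
    const auto covarDiag = torch::exp(logScales).pow(2); // [N, 3]
    return means + covarDiag * (baseNoise * gate.unsqueeze(1) * noiseScale);
}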

src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0
//

#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H
#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H

#include <torch/types.h>

namespace fvdb {
namespace detail {
namespace ops {

template <torch::DeviceType DeviceType>
void dispatchGaussianMCMCAddNoise(torch::Tensor &means,                 // [N, 3] input/output
                                  const torch::Tensor &logScales,       // [N, 3]
                                  const torch::Tensor &logitOpacities,  // [N]
                                  const torch::Tensor &quats,           // [N, 4]
                                  float noiseScale);

} // namespace ops
} // namespace detail
} // namespace fvdb

#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H
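
A sketch of a call site for the dispatcher declared above (assumed usage, not part of this commit): all tensors are float32 on the same CUDA device with the documented shapes, and means is perturbed in place. The identity quaternion {1, 0, 0, 0} follows the convention the gtests below use.

#include <fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h>
#include <torch/torch.h>

void
addMCMCNoiseExample(int64_t N) {
    const auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto means          = torch::randn({N, 3}, opts); // updated in place by the op
    auto logScales      = torch::zeros({N, 3}, opts); // unit scales
    auto logitOpacities = torch::zeros({N}, opts);    // opacity = 0.5 everywhere
    auto quats          = torch::zeros({N, 4}, opts);
    quats.select(1, 0).fill_(1.0f);                   // identity rotations {1, 0, 0, 0}

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, /*noiseScale=*/5e-2f);
}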

src/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ ConfigureTest(GaussianProjectionForwardTest "GaussianProjectionForwardTest.cpp")
 ConfigureTest(GaussianProjectionBackwardTest "GaussianProjectionBackwardTest.cpp")
 ConfigureTest(GaussianRasterizeTopContributorsTest "GaussianRasterizeTopContributorsTest.cpp")
 ConfigureTest(GaussianRasterizeContributingGaussianIdsTest "GaussianRasterizeContributingGaussianIdsTest.cpp")
+ConfigureTest(GaussianMCMCAddNoiseTest "GaussianMCMCAddNoiseTest.cpp")
 ConfigureTest(GaussianRelocationTest "GaussianRelocationTest.cpp")

 if(NOT NANOVDB_EDITOR_SKIP)

src/tests/GaussianMCMCAddNoiseTest.cpp

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: Apache-2.0

#include "utils/Tensor.h"

#include <fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h>

#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <torch/torch.h>

#include <gtest/gtest.h>

#include <cmath>

namespace {

// Match kernel logistic parameters (k = 100, x0 = 0.995).
torch::Tensor
logisticTensor(const torch::Tensor &x) {
    return 1.0f / (1.0f + torch::exp(-100.0f * (x - 0.995f)));
}

class GaussianMCMCAddNoiseTest : public ::testing::Test {
  protected:
    void
    SetUp() override {
        if (!torch::cuda::is_available()) {
            GTEST_SKIP() << "CUDA is required for GaussianMCMCAddNoise tests";
        }
        torch::manual_seed(0);
    }

    torch::TensorOptions
    floatOpts() const {
        return fvdb::test::tensorOpts<float>(torch::kCUDA);
    }

    // Save the current CUDA RNG state so we can reproduce the baseNoise that
    // dispatchGaussianMCMCAddNoise draws internally.
    torch::Tensor
    saveCudaGeneratorState() {
        auto gen = at::cuda::detail::getDefaultCUDAGenerator();
        return gen.get_state();
    }

    void
    restoreCudaGeneratorState(const torch::Tensor &state) {
        auto gen = at::cuda::detail::getDefaultCUDAGenerator();
        gen.set_state(state);
    }
};

TEST_F(GaussianMCMCAddNoiseTest, AppliesNoiseWithDeterministicBaseNoise) {
    auto means = torch::tensor({{0.0f, 0.0f, 0.0f}, {1.0f, 2.0f, 3.0f}}, floatOpts()).contiguous();
    const auto logScales = torch::zeros({2, 3}, floatOpts()).contiguous(); // unit covariance
    const auto opacities = torch::tensor({0.25f, 0.6f}, floatOpts());
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);
    const auto quats =
        torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, floatOpts())
            .contiguous();
    constexpr float noiseScale = 0.4f;

    const auto rngState = saveCudaGeneratorState();
    auto meansBaseline   = means.clone();

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    restoreCudaGeneratorState(rngState);
    const auto baseNoise = torch::randn_like(meansBaseline);

    // Expected delta on CPU: gate * noiseScale * baseNoise (the covariance is the identity here).
    auto opacityCpu = opacities.cpu();
    auto gate       = logisticTensor(torch::ones_like(opacityCpu) - opacityCpu); // [N]
    auto delta      = baseNoise.cpu() * gate.unsqueeze(1) * noiseScale;          // [N, 3]
    const auto expected = meansBaseline.cpu() + delta;

    EXPECT_TRUE(torch::allclose(means.cpu(), expected, 1e-5, 1e-6));
}

TEST_F(GaussianMCMCAddNoiseTest, RespectsAnisotropicScales) {
    auto means = torch::zeros({1, 3}, floatOpts()).contiguous();
    const auto scales =
        torch::tensor({std::log(2.0f), std::log(1.0f), std::log(0.5f)}, floatOpts());
    const auto logScales      = scales.view({1, 3}).contiguous();
    const auto opacities      = torch::tensor({0.3f}, floatOpts());
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);
    const auto quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, floatOpts()).contiguous();
    constexpr float noiseScale = 1.0f;

    const auto rngState = saveCudaGeneratorState();

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    restoreCudaGeneratorState(rngState);
    const auto baseNoise = torch::randn_like(means);

    auto gate = logisticTensor(torch::ones({1}, torch::kFloat32) - opacities.cpu()); // [1]
    const auto covarDiag = torch::pow(torch::exp(logScales.cpu()), 2);               // [1, 3]
    const auto expected  = (baseNoise.cpu() * gate.unsqueeze(1) * noiseScale) * covarDiag;

    // With identity rotation, covariance is diagonal; check elementwise scaling.
    EXPECT_TRUE(torch::allclose(means.cpu(), expected, 1e-5, 1e-6));
}

TEST_F(GaussianMCMCAddNoiseTest, HighOpacitySuppressesNoise) {
    auto means = torch::zeros({2, 3}, floatOpts()).contiguous();
    const auto logScales      = torch::zeros({2, 3}, floatOpts()).contiguous();
    const auto logitOpacities = torch::full({2}, 10.0f, floatOpts()); // opacity ~ 1
    const auto quats =
        torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, floatOpts())
            .contiguous();
    constexpr float noiseScale = 1.0f;

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, noiseScale);

    // Gate approaches zero when opacity ~ 1; expect negligible movement.
    const auto maxAbs = torch::abs(means).max().item<float>();
    EXPECT_LT(maxAbs, 1e-5f);
}

TEST_F(GaussianMCMCAddNoiseTest, ZeroNoiseScaleNoOp) {
    auto means = torch::rand({3, 3}, floatOpts()).contiguous();
    const auto origMeans      = means.clone();
    const auto logScales      = torch::zeros({3, 3}, floatOpts()).contiguous();
    const auto opacities      = torch::tensor({0.2f, 0.5f, 0.8f}, floatOpts());
    const auto logitOpacities = torch::log(opacities) - torch::log1p(-opacities);
    const auto quats          = torch::tensor(
        {{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}},
        floatOpts());

    fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCUDA>(
        means, logScales, logitOpacities, quats, /*noiseScale=*/0.0f);

    EXPECT_TRUE(torch::allclose(means, origMeans));
}

TEST_F(GaussianMCMCAddNoiseTest, CpuAndPrivateUseNotImplemented) {
    auto means = torch::zeros({1, 3}, fvdb::test::tensorOpts<float>(torch::kCPU));
    const auto logScales = torch::zeros({1, 3}, fvdb::test::tensorOpts<float>(torch::kCPU));
    const auto logitOpacities = torch::zeros({1}, fvdb::test::tensorOpts<float>(torch::kCPU));
    const auto quats =
        torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, fvdb::test::tensorOpts<float>(torch::kCPU));

    EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kCPU>(
                     means, logScales, logitOpacities, quats, 1.0f)),
                 c10::Error);

    auto meansCuda          = means.cuda();
    auto logScalesCuda      = logScales.cuda();
    auto logitOpacitiesCuda = logitOpacities.cuda();
    auto quatsCuda          = quats.cuda();
    EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise<torch::kPrivateUse1>(
                     meansCuda, logScalesCuda, logitOpacitiesCuda, quatsCuda, 1.0f)),
                 c10::Error);
}

} // namespace
