Skip to content

Commit 2fcf679

Browse files
committed
host side check for num sigma points
Signed-off-by: Francis Williams <francis@fwilliams.info>
1 parent 9e8e1bf commit 2fcf679

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,14 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
768768
TORCH_CHECK_VALUE(false, "Unknown DistortionModel for GaussianProjectionForwardUT");
769769
}
770770

771+
// This kernel currently implements only the canonical 3D UT with 2D+1 sigma points (7).
772+
// Validate on the host so misconfiguration is reported loudly instead of silently discarding.
773+
TORCH_CHECK_VALUE(
774+
utParams.numSigmaPoints == 7,
775+
"GaussianProjectionForwardUT currently supports only utParams.numSigmaPoints == 7 "
776+
"(3D UT with 2D+1 sigma points). Got ",
777+
utParams.numSigmaPoints);
778+
771779
const at::cuda::OptionalCUDAGuard device_guard(device_of(means));
772780

773781
const auto N = means.size(0); // number of gaussians

src/tests/GaussianProjectionUTTest.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,4 +963,65 @@ TEST_F(GaussianProjectionUTTestFixture, RollingShutterNone_DepthUsesStartPoseNot
963963
EXPECT_NEAR(depths_cpu[0][0].item<float>(), z, 1e-4f);
964964
}
965965

966+
TEST_F(GaussianProjectionUTTestFixture, RejectsNonSevenSigmaPointsOnHost) {
967+
const int64_t C = 1;
968+
969+
means = torch::tensor({{0.0f, 0.0f, 5.0f}}, torch::kFloat32);
970+
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
971+
logScales = torch::log(torch::tensor({{0.2f, 0.2f, 0.2f}}, torch::kFloat32));
972+
973+
worldToCamMatricesStart =
974+
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
975+
worldToCamMatricesEnd = worldToCamMatricesStart.clone();
976+
977+
projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32));
978+
auto projectionMatricesAcc = projectionMatrices.accessor<float, 3>();
979+
projectionMatricesAcc[0][0][0] = 100.0f;
980+
projectionMatricesAcc[0][1][1] = 100.0f;
981+
projectionMatricesAcc[0][0][2] = 320.0f;
982+
projectionMatricesAcc[0][1][2] = 240.0f;
983+
projectionMatricesAcc[0][2][2] = 1.0f;
984+
985+
distortionModel = DistortionModel::NONE;
986+
distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32);
987+
988+
imageWidth = 640;
989+
imageHeight = 480;
990+
eps2d = 0.3f;
991+
nearPlane = 0.1f;
992+
farPlane = 100.0f;
993+
minRadius2d = 0.0f;
994+
995+
utParams = UTParams{};
996+
utParams.numSigmaPoints = 5; // invalid: kernel supports only 7
997+
998+
means = means.cuda();
999+
quats = quats.cuda();
1000+
logScales = logScales.cuda();
1001+
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
1002+
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
1003+
projectionMatrices = projectionMatrices.cuda();
1004+
distortionCoeffs = distortionCoeffs.cuda();
1005+
1006+
EXPECT_THROW((dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
1007+
quats,
1008+
logScales,
1009+
worldToCamMatricesStart,
1010+
worldToCamMatricesEnd,
1011+
projectionMatrices,
1012+
RollingShutterType::NONE,
1013+
utParams,
1014+
distortionModel,
1015+
distortionCoeffs,
1016+
imageWidth,
1017+
imageHeight,
1018+
eps2d,
1019+
nearPlane,
1020+
farPlane,
1021+
minRadius2d,
1022+
false,
1023+
false)),
1024+
c10::Error);
1025+
}
1026+
9661027
} // namespace fvdb::detail::ops

0 commit comments

Comments
 (0)