diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 14268a59..ce8dcc07 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -110,6 +110,7 @@ set(FVDB_CU_FILES fvdb/detail/ops/gsplat/GaussianMCMCRelocation.cu fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu fvdb/detail/ops/gsplat/GaussianProjectionForward.cu + fvdb/detail/ops/gsplat/GaussianProjectionUT.cu fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.cu fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu diff --git a/src/fvdb/detail/GridBatchImpl.cu b/src/fvdb/detail/GridBatchImpl.cu index af948f34..4ebb2cb8 100644 --- a/src/fvdb/detail/GridBatchImpl.cu +++ b/src/fvdb/detail/GridBatchImpl.cu @@ -1307,7 +1307,7 @@ GridBatchImpl::deserializeV0(const torch::Tensor &serialized) { }; TORCH_CHECK(serialized.scalar_type() == torch::kInt8, "Serialized data must be of type int8"); - TORCH_CHECK(serialized.numel() >= sizeof(V01Header), + TORCH_CHECK(serialized.numel() >= static_cast(sizeof(V01Header)), "Serialized data is too small to be a valid grid handle"); const int8_t *serializedPtr = serialized.data_ptr(); @@ -1316,7 +1316,7 @@ GridBatchImpl::deserializeV0(const torch::Tensor &serialized) { TORCH_CHECK(header->magic == 0x0F0F0F0F0F0F0F0F, "Serialized data is not a valid grid handle. Bad magic."); TORCH_CHECK(header->version == 0, "Serialized data is not a valid grid handle. Bad version."); - TORCH_CHECK(serialized.numel() == header->totalBytes, + TORCH_CHECK(static_cast(serialized.numel()) == header->totalBytes, "Serialized data is not a valid grid handle. Bad total bytes."); const uint64_t numGrids = header->numGrids; @@ -1752,7 +1752,7 @@ GridBatchImpl::dilate(const int64_t dilationAmount) { c10::intrusive_ptr GridBatchImpl::dilate(const std::vector dilationAmount) { c10::DeviceGuard guard(device()); - TORCH_CHECK_VALUE(dilationAmount.size() == batchSize(), + TORCH_CHECK_VALUE(static_cast(dilationAmount.size()) == batchSize(), "dilationAmount should have same size as batch size, got ", dilationAmount.size(), " != ", diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu b/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu index e4cb2f84..950cec6e 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu +++ b/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu @@ -91,15 +91,7 @@ template struct ProjectionForward { mOutDepthsAcc(outDepths.packed_accessor64()), mOutConicsAcc(outConics.packed_accessor64()), mOutCompensationsAcc(outCompensations.defined() ? 
outCompensations.data_ptr() - : nullptr) { - mMeansAcc = means.packed_accessor64(); - mQuatsAcc = quats.packed_accessor64(); - mLogScalesAcc = logScales.packed_accessor64(); - mWorldToCamMatricesAcc = - worldToCamMatrices.packed_accessor32(); - mProjectionMatricesAcc = - projectionMatrices.packed_accessor32(); - } + : nullptr) {} inline __device__ Mat3 computeCovarianceMatrix(int64_t gid) const { diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu b/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu new file mode 100644 index 00000000..4c1125a2 --- /dev/null +++ b/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu @@ -0,0 +1,1240 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace fvdb::detail::ops { + +namespace { + +// OpenCV camera distortion conventions: +// https://docs.opencv.org/4.x/d9/d0c/group__calib3d.html +// Distortion coefficients are interpreted as: +// - Radial (rational): k1..k6 +// - Tangential: p1, p2 +// - Thin prism: s1..s4 +// +/// @brief OpenCV camera model (pinhole intrinsics + distortion). +/// +/// This is an internal helper used by the UT projection kernel. +/// It owns the camera intrinsics pointer `K` and the distortion coefficient pointers and can +/// project `p_cam -> pixel`. +/// @see https://docs.opencv.org/4.x/d9/d0c/group__calib3d.html +/// +/// @tparam T Scalar type +template class OpenCVCameraModel { + public: + using Vec2 = nanovdb::math::Vec2; + using Vec3 = nanovdb::math::Vec3; + using Mat3 = nanovdb::math::Mat3; + + /// @brief Internal distortion evaluation mode. + /// + /// This is intentionally separate from the public API `CameraModel` to keep the math paths + /// explicit and avoid implying that `CameraModel` is “just distortion” (since it also includes + /// projection as well). + enum class Model : uint8_t { + NONE = 0, // no distortion + RADTAN_5 = 5, // k1,k2,p1,p2,k3 (polynomial radial up to r^6) + RATIONAL_8 = 8, // k1,k2,p1,p2,k3,k4,k5,k6 (rational radial) + RADTAN_THIN_9 = 9, // RADTAN_5 + thin prism s1..s4 (polynomial radial + thin-prism) + THIN_PRISM_12 = 12, // RATIONAL_8 + s1,s2,s3,s4 + }; + + /// @brief Construct a camera model for projection. + /// + /// For OpenCV models, coefficients are read from a per-camera packed layout: + /// `[k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]`. + /// + /// For `CameraModel::PINHOLE` / `CameraModel::ORTHOGRAPHIC`, `distortionCoeffs` is ignored. + /// + /// Preconditions are asserted on-device (trap) rather than returning status codes. + /// + /// @param[in] cameraModel Public camera model selector. + /// @param[in] K_in 3x3 intrinsics matrix (typically backed by shared memory). + /// @param[in] distortionCoeffs Pointer to per-camera distortion coefficients (layout depends on + /// cameraModel). + __device__ __forceinline__ + OpenCVCameraModel(const CameraModel cameraModel, const Mat3 &K_in, const T *distortionCoeffs) + : K(K_in) { + // ORTHOGRAPHIC is implemented as pinhole intrinsics without the perspective divide. + // Distortion is intentionally not supported for orthographic projection. 
+ orthographic = (cameraModel == CameraModel::ORTHOGRAPHIC); + + if (cameraModel == CameraModel::ORTHOGRAPHIC) { + radial = tangential = thinPrism = nullptr; + numRadial = numTangential = numThinPrism = 0; + model = Model::NONE; + return; + } + + if (cameraModel == CameraModel::PINHOLE) { + radial = tangential = thinPrism = nullptr; + numRadial = numTangential = numThinPrism = 0; + model = Model::NONE; + return; + } + + // OpenCV models require the packed coefficient layout. + deviceAssertOrTrap(distortionCoeffs != nullptr); + const T *radial_in = distortionCoeffs + kRadialOffset; // k1..k6 + const T *tang_in = distortionCoeffs + kTangentialOffset; // p1,p2 + const T *thin_in = distortionCoeffs + kThinPrismOffset; // s1..s4 + + if (cameraModel == CameraModel::OPENCV_RADTAN_5) { + radial = radial_in; + numRadial = 3; // k1,k2,k3 (polynomial) + tangential = tang_in; + numTangential = 2; + thinPrism = nullptr; + numThinPrism = 0; + model = Model::RADTAN_5; + return; + } + if (cameraModel == CameraModel::OPENCV_RATIONAL_8) { + radial = radial_in; + numRadial = 6; // k1..k6 (rational) + tangential = tang_in; + numTangential = 2; + thinPrism = nullptr; + numThinPrism = 0; + model = Model::RATIONAL_8; + return; + } + if (cameraModel == CameraModel::OPENCV_THIN_PRISM_12) { + radial = radial_in; + numRadial = 6; // k1..k6 (rational) + tangential = tang_in; + numTangential = 2; + thinPrism = thin_in; + numThinPrism = 4; // s1..s4 + model = Model::THIN_PRISM_12; + return; + } + if (cameraModel == CameraModel::OPENCV_RADTAN_THIN_PRISM_9) { + // Polynomial radial + thin-prism; ignore k4..k6 by construction. + radial = radial_in; + numRadial = 3; // k1,k2,k3 (polynomial) + tangential = tang_in; + numTangential = 2; + thinPrism = thin_in; + numThinPrism = 4; // s1..s4 + model = Model::RADTAN_THIN_9; + return; + } + + // Unknown camera model: should be unreachable if host validation is correct. + deviceAssertOrTrap(false); + } + + /// @brief Project a 3D point in camera coordinates to pixel coordinates. + /// + /// - Perspective (`PINHOLE`/`OPENCV_*`): normalize by depth \(x/z, y/z\). + /// - Orthographic (`ORTHOGRAPHIC`): no divide; uses \(x,y\) directly. + /// + /// @param[in] p_cam Point in camera coordinates. + /// @return Pixel coordinates (u,v). + __device__ Vec2 + project(const Vec3 &p_cam) const { + // Normalize to camera plane. + Vec2 p_normalized; + if (orthographic) { + p_normalized = Vec2(p_cam[0], p_cam[1]); + } else { + // For perspective models, callers are expected to reject points with small/invalid + // depth before calling `project()`. Avoid clamping z here so the pinhole math remains + // consistent near the camera plane. + const T z_inv = T(1) / p_cam[2]; + p_normalized = Vec2(p_cam[0] * z_inv, p_cam[1] * z_inv); + } + + const Vec2 p_distorted = applyDistortion(p_normalized); + + // Project to pixel coordinates. + const T fx = K[0][0]; + const T fy = K[1][1]; + const T cx = K[0][2]; + const T cy = K[1][2]; + return Vec2(fx * p_distorted[0] + cx, fy * p_distorted[1] + cy); + } + + /// @brief Whether this camera model is orthographic. + /// @return True for `CameraModel::ORTHOGRAPHIC`. + __device__ __forceinline__ bool + isOrthographic() const { + return orthographic; + } + + private: + /// @brief Device-side assert that always traps on failure. + /// + /// @param[in] cond Condition that must hold. 
+ __device__ __forceinline__ static void + deviceAssertOrTrap(const bool cond) { + if (!cond) { + // `assert()` is typically compiled out in release builds; use a trap to guarantee a + // loud failure when invariants are violated. + asm volatile("trap;"); + } + } + + // Camera intrinsics (typically backed by shared memory). + const Mat3 &K; + bool orthographic = false; + + // Packed OpenCV coefficient layout offsets: + // [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4] + static constexpr int kRadialOffset = 0; // k1..k6 + static constexpr int kTangentialOffset = 6; // p1,p2 + static constexpr int kThinPrismOffset = 8; // s1..s4 + + // Coefficients for the distortion model. + const T *radial = nullptr; // k1..k6 (but k4..k6 only used in rational model) + int numRadial = 0; // Number of radial coefficients + const T *tangential = nullptr; // p1,p2 + int numTangential = 0; // Number of tangential coefficients + const T *thinPrism = nullptr; // s1..s4 + int numThinPrism = 0; // Number of thin prism coefficients + Model model = Model::NONE; // Distortion model + + /// @brief Read a coefficient if present, otherwise return 0. + /// + /// @param[in] ptr Coefficient array (may be null). + /// @param[in] n Number of coefficients in ptr. + /// @param[in] i Index to read. + /// @return Coefficient value or 0 if out-of-range / null. + __host__ __device__ inline T + coeffOrZero(const T *ptr, const int n, const int i) const { + return (ptr != nullptr && i >= 0 && i < n) ? ptr[i] : T(0); + } + + /// @brief Apply OpenCV distortion to a normalized camera-plane point. + /// + /// Input coordinates are assumed to already be normalized to the camera plane: + /// - perspective: \((x/z, y/z)\) + /// - orthographic: \((x, y)\) + /// + /// @param[in] p_normalized Normalized camera-plane coordinates. + /// @return Distorted normalized coordinates. + __device__ Vec2 + applyDistortion(const Vec2 &p_normalized) const { + const T x = p_normalized[0]; + const T y = p_normalized[1]; + const T x2 = x * x; + const T y2 = y * y; + const T xy = x * y; + + const T r2 = x2 + y2; + const T r4 = r2 * r2; + const T r6 = r4 * r2; + + // Radial distortion. + T radial_dist = T(1); + if (model == Model::RATIONAL_8 || model == Model::THIN_PRISM_12) { + const T k1 = coeffOrZero(radial, numRadial, 0); + const T k2 = coeffOrZero(radial, numRadial, 1); + const T k3 = coeffOrZero(radial, numRadial, 2); + const T k4 = coeffOrZero(radial, numRadial, 3); + const T k5 = coeffOrZero(radial, numRadial, 4); + const T k6 = coeffOrZero(radial, numRadial, 5); + const T num = T(1) + r2 * (k1 + r2 * (k2 + r2 * k3)); + const T den = T(1) + r2 * (k4 + r2 * (k5 + r2 * k6)); + radial_dist = (den != T(0)) ? (num / den) : T(0); + } else if (model == Model::RADTAN_5 || model == Model::RADTAN_THIN_9) { + // Polynomial radial (up to k3 / r^6). Thin-prism terms are applied below if enabled. + const T k1 = coeffOrZero(radial, numRadial, 0); + const T k2 = coeffOrZero(radial, numRadial, 1); + const T k3 = coeffOrZero(radial, numRadial, 2); + radial_dist = T(1) + k1 * r2 + k2 * r4 + k3 * r6; + } + + T x_dist = x * radial_dist; + T y_dist = y * radial_dist; + + // Tangential distortion. + const T p1 = coeffOrZero(tangential, numTangential, 0); + const T p2 = coeffOrZero(tangential, numTangential, 1); + x_dist += T(2) * p1 * xy + p2 * (r2 + T(2) * x2); + y_dist += p1 * (r2 + T(2) * y2) + T(2) * p2 * xy; + + // Thin-prism distortion. 
+ if (model == Model::THIN_PRISM_12 || model == Model::RADTAN_THIN_9) { + const T s1 = coeffOrZero(thinPrism, numThinPrism, 0); + const T s2 = coeffOrZero(thinPrism, numThinPrism, 1); + const T s3 = coeffOrZero(thinPrism, numThinPrism, 2); + const T s4 = coeffOrZero(thinPrism, numThinPrism, 3); + x_dist += s1 * r2 + s2 * r4; + y_dist += s3 * r2 + s4 * r4; + } + + return Vec2(x_dist, y_dist); + } +}; + +/// @brief UT-local rigid transform (cached rotation + translation). +/// +/// Quaternion is stored as \([w,x,y,z]\) and is assumed to represent a rotation. +/// The corresponding rotation matrix \(R(q)\) is cached to avoid recomputing it for every point +/// transform (UT sigma points, rolling-shutter iterations, depth culls, etc.). +template struct RigidTransform { + nanovdb::math::Mat3 R; + nanovdb::math::Vec4 q; + nanovdb::math::Vec3 t; + + /// @brief Default constructor (identity transform). + /// + /// Initializes to unit quaternion \([1,0,0,0]\) and zero translation. + __device__ + RigidTransform() + : R(nanovdb::math::Mat3(nanovdb::math::Vec3(T(1), T(0), T(0)), + nanovdb::math::Vec3(T(0), T(1), T(0)), + nanovdb::math::Vec3(T(0), T(0), T(1)))), + q(T(1), T(0), T(0), T(0)), t(T(0), T(0), T(0)) {} + + /// @brief Construct from quaternion and translation. + /// @param[in] q_in Rotation quaternion \([w,x,y,z]\). + /// @param[in] t_in Translation vector. + __device__ + RigidTransform(const nanovdb::math::Vec4 &q_in, const nanovdb::math::Vec3 &t_in) + : R(quaternionToRotationMatrix(q_in)), q(q_in), t(t_in) {} + + /// @brief Construct from rotation matrix and translation. + /// @param[in] R_in Rotation matrix. + /// @param[in] t_in Translation vector. + __device__ + RigidTransform(const nanovdb::math::Mat3 &R_in, const nanovdb::math::Vec3 &t_in) + : R(R_in), q(rotationMatrixToQuaternion(R_in)), t(t_in) {} + + /// @brief Apply the transform to a 3D point: \(R(q)\,p + t\). + /// @param[in] p_world Point to transform. + /// @return Transformed point. + __device__ __forceinline__ nanovdb::math::Vec3 + apply(const nanovdb::math::Vec3 &p_world) const { + // p_cam = R * p_world + t + return R * p_world + t; + } + + /// @brief Interpolate between two rigid transforms. + /// + /// Translation is linearly interpolated; rotation uses NLERP along the shortest arc. + /// + /// @param[in] u Interpolation parameter in \([0,1]\). + /// @param[in] start Start transform. + /// @param[in] end End transform. + /// @return Interpolated transform. + static inline __device__ RigidTransform + interpolate(const T u, const RigidTransform &start, const RigidTransform &end) { + const nanovdb::math::Vec3 t_interp = start.t + u * (end.t - start.t); + const nanovdb::math::Vec4 q_interp = nlerpQuaternionShortestPath(start.q, end.q, u); + return RigidTransform(q_interp, t_interp); + } +}; + +/// @brief Projection status for a single world point. +/// +/// The kernel treats `BehindCamera` as a hard failure (discontinuous projection), while +/// `OutOfBounds` may still be usable depending on UTParams. +enum class ProjStatus : uint8_t { BehindCamera, OutOfBounds, InImage }; + +/// @brief World-space point -> pixel transform with rolling shutter. +/// +/// This wraps the camera model, rolling shutter policy, and in-image bounds checks used when +/// projecting UT sigma points. +/// +/// @tparam ScalarType Scalar type (float). 
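+///
+/// Illustrative usage sketch (not part of the kernel); `cam`, `poseStart`, `poseEnd`, and
+/// `pWorld` are assumed to be a valid camera model, shutter-start/end poses, and a world point:
+/// @code
+/// WorldToPixelTransform<float> w2p(cam, RollingShutterType::VERTICAL,
+///                                  /*imageWidth=*/1920, /*imageHeight=*/1080,
+///                                  /*inImageMargin=*/0.1f, poseStart, poseEnd);
+/// nanovdb::math::Vec2<float> pix;
+/// const ProjStatus status = w2p.projectWorldPoint(pWorld, pix);
+/// // `pix` is always written; only ProjStatus::InImage marks it as usable.
+/// @endcode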
+template struct WorldToPixelTransform { + using Vec2 = nanovdb::math::Vec2; + using Vec3 = nanovdb::math::Vec3; + using Mat3 = nanovdb::math::Mat3; + + const OpenCVCameraModel &camera; + RollingShutterType rollingShutterType; + int64_t imageWidth; + int64_t imageHeight; + ScalarType inImageMargin; + RigidTransform worldToCamStart; + RigidTransform worldToCamEnd; + + __device__ __forceinline__ + WorldToPixelTransform(const OpenCVCameraModel &camera_in, + const RollingShutterType rollingShutterType_in, + const int64_t imageWidth_in, + const int64_t imageHeight_in, + const ScalarType inImageMargin_in, + const RigidTransform &worldToCamStart_in, + const RigidTransform &worldToCamEnd_in) + : camera(camera_in), rollingShutterType(rollingShutterType_in), imageWidth(imageWidth_in), + imageHeight(imageHeight_in), inImageMargin(inImageMargin_in), + worldToCamStart(worldToCamStart_in), worldToCamEnd(worldToCamEnd_in) {} + + /// @brief Helper: whether a projection status is in-image. + /// @param[in] s Projection status. + /// @return True if s is InImage. + __device__ __forceinline__ static bool + isInImage(const ProjStatus s) { + return s == ProjStatus::InImage; + } + + /// @brief Transform a world-space point with a given world->cam transform and project to pixel. + /// + /// @param[in] p_world World-space point. + /// @param[in] xf World->camera transform. + /// @param[out] out_pix Pixel coordinate output (always written). + /// @return Projection status. + __device__ __forceinline__ ProjStatus + projectWithTransform(const Vec3 &p_world, + const RigidTransform &xf, + Vec2 &out_pix) const { + const Vec3 p_cam = xf.apply(p_world); + // Reject points close to/behind the camera plane. + // + // For perspective cameras, we reject z <= z_eps to avoid numerical instability and to avoid + // clamping z in the projection math (which would change the pinhole model near z=0). + // For ORTHOGRAPHIC this is a policy choice (not a mathematical necessity); we keep the + // original z<=0 behavior. + const ScalarType z_eps = camera.isOrthographic() ? ScalarType(0) : ScalarType(1e-6); + if (p_cam[2] <= z_eps) { + // Ensure deterministic output to avoid UB on callers that assign/read even on invalid + // projections. This value is ignored when we treat BehindCamera as a hard reject. + out_pix = Vec2(ScalarType(0), ScalarType(0)); + return ProjStatus::BehindCamera; + } + + out_pix = camera.project(p_cam); + const ScalarType margin_x = ScalarType(imageWidth) * inImageMargin; + const ScalarType margin_y = ScalarType(imageHeight) * inImageMargin; + const bool in_img = + (out_pix[0] >= -margin_x) && (out_pix[0] < ScalarType(imageWidth) + margin_x) && + (out_pix[1] >= -margin_y) && (out_pix[1] < ScalarType(imageHeight) + margin_y); + return in_img ? ProjStatus::InImage : ProjStatus::OutOfBounds; + } + + /// @brief Project a world-space point to pixel coordinates. + /// + /// For rolling shutter modes, this uses a small fixed-point iteration that estimates shutter + /// time from the current pixel coordinate (row/col -> time). + /// + /// @param[in] p_world World-space point. + /// @param[out] out_pixel Pixel coordinate output (always written). + /// @return Projection status. + __device__ __forceinline__ ProjStatus + projectWorldPoint(const Vec3 &p_world, Vec2 &out_pixel) const { + // Rolling shutter: iterate pose based on the current pixel estimate (row/col -> time). + + // Start/end projections for initialization. 
+ const RigidTransform &pose_start = worldToCamStart; + const RigidTransform &pose_end = worldToCamEnd; + Vec2 pix_start(ScalarType(0), ScalarType(0)); + Vec2 pix_end(ScalarType(0), ScalarType(0)); + const ProjStatus status_start = projectWithTransform(p_world, pose_start, pix_start); + const ProjStatus status_end = projectWithTransform(p_world, pose_end, pix_end); + + if (rollingShutterType == RollingShutterType::NONE) { + out_pixel = pix_start; + return status_start; + } + + // If both endpoints are behind the camera, treat as a hard invalid (discontinuous). + if (status_start == ProjStatus::BehindCamera && status_end == ProjStatus::BehindCamera) { + out_pixel = pix_end; + return ProjStatus::BehindCamera; + } + + // If neither endpoint is in-image (but at least one is in front), treat as invalid. + // (We require an in-image seed for the fixed-point iteration.) + if (!isInImage(status_start) && !isInImage(status_end)) { + out_pixel = (status_end != ProjStatus::BehindCamera) ? pix_end : pix_start; + return ProjStatus::OutOfBounds; + } + + Vec2 pix_prev = isInImage(status_start) ? pix_start : pix_end; + // Fixed small iteration count (good enough for convergence in practice). + constexpr int kIters = 6; + for (int it = 0; it < kIters; ++it) { + ScalarType t_rs = ScalarType(0); + if (rollingShutterType == RollingShutterType::VERTICAL) { + t_rs = floor(pix_prev[1]) / max(ScalarType(1), ScalarType(imageHeight - 1)); + } else if (rollingShutterType == RollingShutterType::HORIZONTAL) { + t_rs = floor(pix_prev[0]) / max(ScalarType(1), ScalarType(imageWidth - 1)); + } + t_rs = min(ScalarType(1), max(ScalarType(0), t_rs)); + const RigidTransform pose_rs = + RigidTransform::interpolate(t_rs, worldToCamStart, worldToCamEnd); + Vec2 pix_rs(ScalarType(0), ScalarType(0)); + const ProjStatus status_rs = projectWithTransform(p_world, pose_rs, pix_rs); + pix_prev = pix_rs; + if (status_rs == ProjStatus::BehindCamera) { + out_pixel = pix_rs; + return ProjStatus::BehindCamera; + } + if (!isInImage(status_rs)) { + out_pixel = pix_rs; + return ProjStatus::OutOfBounds; + } + } + + out_pixel = pix_prev; + return ProjStatus::InImage; + } +}; + +/// @brief Generate 3D UT sigma points and weights (fixed 7-point UT in 3D). +/// +/// Sigma points are generated in **world space** from \((\mu, R, s)\) where \(R\) comes from the +/// input quaternion and \(s\) are the axis scales. For a 3D UT with the canonical \(2D+1\) +/// formulation, D=3 => 7 sigma points. +/// +/// @tparam T Scalar type. +/// @param[in] mean_world Mean in world space. +/// @param[in] quat_wxyz Rotation quaternion \([w,x,y,z]\). +/// @param[in] scale_world Axis-aligned scale in world space (per-axis standard deviation). +/// @param[in] params UT hyperparameters. +/// @param[out] sigma_points Output sigma points (size 7). +/// @param[out] weights_mean UT mean weights (size 7). +/// @param[out] weights_cov UT covariance weights (size 7). +template +__device__ void +generateWorldSigmaPoints(const nanovdb::math::Vec3 &mean_world, + const nanovdb::math::Vec4 &quat_wxyz, + const nanovdb::math::Vec3 &scale_world, + const UTParams ¶ms, + nanovdb::math::Vec3 (&sigma_points)[7], + T (&weights_mean)[7], + T (&weights_cov)[7]) { + constexpr int D = 3; + // This kernel currently supports only the canonical 3D UT with 2D+1 points. + // (We keep the arrays fixed-size for performance and simplicity.) 
+ const T alpha = T(params.alpha); + const T beta = T(params.beta); + const T kappa = T(params.kappa); + const T lambda = alpha * alpha * (T(D) + kappa) - T(D); + const T denom = T(D) + lambda; + + // Rotation matrix from quaternion. NOTE: `quaternionToRotationMatrix` expects [w,x,y,z]. + const nanovdb::math::Mat3 R = quaternionToRotationMatrix(quat_wxyz); + + sigma_points[0] = mean_world; + weights_mean[0] = lambda / denom; + weights_cov[0] = lambda / denom + (T(1) - alpha * alpha + beta); + + const T wi = T(1) / (T(2) * denom); + for (int i = 0; i < 2 * D; ++i) { + weights_mean[i + 1] = wi; + weights_cov[i + 1] = wi; + } + + // sqrt(D + lambda) scaling + const T gamma = sqrt(max(T(0), denom)); + + // For covariance C = R * diag(scale^2) * R^T, the columns of R are the singular vectors. + // Generate sigma points: mean +/- gamma * scale[i] * col_i(R) + for (int i = 0; i < D; ++i) { + const nanovdb::math::Vec3 col_i(R[0][i], R[1][i], R[2][i]); + const nanovdb::math::Vec3 delta = (gamma * scale_world[i]) * col_i; + sigma_points[i + 1] = mean_world + delta; + sigma_points[i + 1 + D] = mean_world - delta; + } +} + +/// @brief Reconstruct a 2D covariance matrix from projected sigma points. +/// +/// This computes \(\Sigma = \sum_i w_i (x_i-\mu)(x_i-\mu)^T\). +/// +/// @tparam T Scalar type. +/// @param[in] projected_points Projected sigma points (size num_points). +/// @param[in] weights_cov Covariance weights (size num_points). +/// @param[in] num_points Number of sigma points. +/// @param[in] mean2d Precomputed 2D mean. +/// @return 2x2 covariance matrix. +template +__device__ nanovdb::math::Mat2 +reconstructCovarianceFromSigmaPoints(const nanovdb::math::Vec2 (&projected_points)[7], + const T (&weights_cov)[7], + const nanovdb::math::Vec2 &mean2d) { + nanovdb::math::Mat2 covar2d(T(0), T(0), T(0), T(0)); + constexpr int kNumSigmaPoints = 7; + for (int i = 0; i < kNumSigmaPoints; ++i) { + const nanovdb::math::Vec2 diff = projected_points[i] - mean2d; + covar2d[0][0] += weights_cov[i] * diff[0] * diff[0]; + covar2d[0][1] += weights_cov[i] * diff[0] * diff[1]; + covar2d[1][0] += weights_cov[i] * diff[1] * diff[0]; + covar2d[1][1] += weights_cov[i] * diff[1] * diff[1]; + } + return covar2d; +} + +/// @brief Enforce a positive-semidefinite 2x2 covariance matrix. +/// +/// UT covariance reconstruction can produce an indefinite (or even negative definite) matrix due +/// to negative covariance weights combined with the nonlinear projection. This function clamps the +/// eigenvalues to a minimum threshold and reconstructs the matrix, ensuring downstream operations +/// (sqrt, inverse) remain numerically well-defined. +template +__device__ __forceinline__ void +enforcePSD2x2(const T minEigen, nanovdb::math::Mat2 &covar2d) { + using Vec2 = nanovdb::math::Vec2; + + // Symmetrize defensively. + const T a = covar2d[0][0]; + const T c = covar2d[1][1]; + const T b = T(0.5) * (covar2d[0][1] + covar2d[1][0]); + + const T trace = a + c; + const T det = a * c - b * b; + + const T half_trace = T(0.5) * trace; + T disc = half_trace * half_trace - det; + disc = max(T(0), disc); + const T s = sqrt(disc); + + // Eigenvalues (v1 >= v2). + const T v1 = half_trace + s; + const T v2 = half_trace - s; + + // Clamp eigenvalues to ensure PSD + invertibility. + const T v1c = max(v1, minEigen); + const T v2c = max(v2, minEigen); + + // Eigenvector for v1. For a 2x2 symmetric matrix, we can form a stable vector from either: + // [b, v1-a] or [v1-c, b] + Vec2 u(T(1), T(0)); + const T eps = (sizeof(T) == sizeof(float)) ? 
T(1e-8) : T(1e-12); + if (::cuda::std::fabs(b) > eps || ::cuda::std::fabs(v1 - a) > eps || + ::cuda::std::fabs(v1 - c) > eps) { + T ux = b; + T uy = v1 - a; + // Prefer the formulation with the larger component to avoid cancellation. + if (::cuda::std::fabs(v1 - c) > ::cuda::std::fabs(v1 - a)) { + ux = v1 - c; + uy = b; + } + const T n = sqrt(ux * ux + uy * uy); + if (n > eps) { + u = Vec2(ux / n, uy / n); + } + } else { + // Diagonal (or near-diagonal) case. + u = (a >= c) ? Vec2(T(1), T(0)) : Vec2(T(0), T(1)); + } + + // Orthonormal basis. + const Vec2 v(-u[1], u[0]); + + // Reconstruct: cov = Q * diag(v1c, v2c) * Q^T + covar2d[0][0] = v1c * u[0] * u[0] + v2c * v[0] * v[0]; + covar2d[0][1] = v1c * u[0] * u[1] + v2c * v[0] * v[1]; + covar2d[1][0] = covar2d[0][1]; + covar2d[1][1] = v1c * u[1] * u[1] + v2c * v[1] * v[1]; +} + +} // namespace + +/// @brief CUDA kernel functor for UT forward projection. +/// +/// This struct owns tensor accessors, shared memory pointers, and scalar configuration for +/// projecting N gaussians into C camera views. +template struct ProjectionForwardUT { + using Mat3 = nanovdb::math::Mat3; + using Vec3 = nanovdb::math::Vec3; + using Vec4 = nanovdb::math::Vec4; + using Mat2 = nanovdb::math::Mat2; + using Vec2 = nanovdb::math::Vec2; + + // Scalar Inputs + const int64_t C; + const int64_t N; + const int32_t mImageWidth; + const int32_t mImageHeight; + const ScalarType mEps2d; + const ScalarType mNearPlane; + const ScalarType mFarPlane; + const ScalarType mRadiusClip; + const RollingShutterType mRollingShutterType; + const UTParams mUTParams; + const CameraModel mCameraModel; + const int64_t + mNumDistortionCoeffs; // Number of distortion coeffs per camera (e.g. 12 for OPENCV) + + // Tensor Inputs + const fvdb::TorchRAcc64 mMeansAcc; // [N, 3] + const fvdb::TorchRAcc64 mQuatsAcc; // [N, 4] + const fvdb::TorchRAcc64 mLogScalesAcc; // [N, 3] + const fvdb::TorchRAcc32 mWorldToCamMatricesStartAcc; // [C, 4, 4] + const fvdb::TorchRAcc32 mWorldToCamMatricesEndAcc; // [C, 4, 4] + const fvdb::TorchRAcc32 mProjectionMatricesAcc; // [C, 3, 3] + const fvdb::TorchRAcc64 mDistortionCoeffsAcc; // [C, K] + + // Outputs + fvdb::TorchRAcc64 mOutRadiiAcc; // [C, N] + fvdb::TorchRAcc64 mOutMeans2dAcc; // [C, N, 2] + fvdb::TorchRAcc64 mOutDepthsAcc; // [C, N] + fvdb::TorchRAcc64 mOutConicsAcc; // [C, N, 3] + + // Optional Outputs + // + // NOTE: This is intentionally a raw pointer to represent optional (nullable) outputs. + // Required inputs are passed/stored as references where possible to avoid null-deref hazards. + ScalarType *__restrict__ mOutCompensationsAcc; // [C, N] optional + + // Shared memory pointers + Mat3 *__restrict__ projectionMatsShared = nullptr; + Mat3 *__restrict__ worldToCamRotMatsStartShared = nullptr; + Mat3 *__restrict__ worldToCamRotMatsEndShared = nullptr; + Vec3 *__restrict__ worldToCamTranslationStartShared = nullptr; + Vec3 *__restrict__ worldToCamTranslationEndShared = nullptr; + ScalarType *__restrict__ distortionCoeffsShared = nullptr; + + /// @brief Construct the functor with configuration and tensor references. + /// + /// @param[in] imageWidth Image width in pixels. + /// @param[in] imageHeight Image height in pixels. + /// @param[in] eps2d Blur epsilon added to covariance for numerical stability. + /// @param[in] nearPlane Near-plane threshold for depth culling. + /// @param[in] farPlane Far-plane threshold for depth culling. + /// @param[in] minRadius2d Minimum radius threshold; smaller gaussians are discarded. 
+ /// @param[in] rollingShutterType Rolling shutter policy. + /// @param[in] utParams UT hyperparameters. + /// @param[in] cameraModel Camera model selector. + /// @param[in] calcCompensations Whether to compute compensation factors. + /// @param[in] means [N,3] tensor. + /// @param[in] quats [N,4] tensor. + /// @param[in] logScales [N,3] tensor. + /// @param[in] worldToCamMatricesStart [C,4,4] tensor. + /// @param[in] worldToCamMatricesEnd [C,4,4] tensor. + /// @param[in] projectionMatrices [C,3,3] tensor. + /// @param[in] distortionCoeffs [C,K] tensor (K=0 for PINHOLE/ORTHOGRAPHIC; K=12 for OPENCV). + /// @param[out] outRadii [C,N] tensor. + /// @param[out] outMeans2d [C,N,2] tensor. + /// @param[out] outDepths [C,N] tensor. + /// @param[out] outConics [C,N,3] tensor. + /// @param[out] outCompensations [C,N] tensor (optional, may be undefined). + ProjectionForwardUT(const int64_t imageWidth, + const int64_t imageHeight, + const ScalarType eps2d, + const ScalarType nearPlane, + const ScalarType farPlane, + const ScalarType minRadius2d, + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const CameraModel cameraModel, + const bool calcCompensations, + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const torch::Tensor &distortionCoeffs, // [C, K] + torch::Tensor &outRadii, // [C, N] + torch::Tensor &outMeans2d, // [C, N, 2] + torch::Tensor &outDepths, // [C, N] + torch::Tensor &outConics, // [C, N, 3] + torch::Tensor &outCompensations // [C, N] optional + ) + : C(projectionMatrices.size(0)), N(means.size(0)), + mImageWidth(static_cast(imageWidth)), + mImageHeight(static_cast(imageHeight)), mEps2d(eps2d), mNearPlane(nearPlane), + mFarPlane(farPlane), mRadiusClip(minRadius2d), mRollingShutterType(rollingShutterType), + mUTParams(utParams), mCameraModel(cameraModel), + mNumDistortionCoeffs(distortionCoeffs.size(1)), + mMeansAcc(means.packed_accessor64()), + mQuatsAcc(quats.packed_accessor64()), + mLogScalesAcc(logScales.packed_accessor64()), + mWorldToCamMatricesStartAcc( + worldToCamMatricesStart.packed_accessor32()), + mWorldToCamMatricesEndAcc( + worldToCamMatricesEnd.packed_accessor32()), + mProjectionMatricesAcc( + projectionMatrices.packed_accessor32()), + mDistortionCoeffsAcc( + distortionCoeffs.packed_accessor64()), + mOutRadiiAcc(outRadii.packed_accessor64()), + mOutMeans2dAcc(outMeans2d.packed_accessor64()), + mOutDepthsAcc(outDepths.packed_accessor64()), + mOutConicsAcc(outConics.packed_accessor64()), + mOutCompensationsAcc(outCompensations.defined() ? outCompensations.data_ptr() + : nullptr) {} + + /// @brief Load per-camera matrices/coeffs into shared memory for faster access. + /// + /// Layout is `[K, R_start, R_end, t_start, t_end, distortionCoeffs]` per camera. + inline __device__ void + loadCameraInfoIntoSharedMemory() { + // Load projection matrices and world-to-camera matrices into shared memory + alignas(Mat3) extern __shared__ char sharedMemory[]; + + // Alignment sanity checks for the shared-memory layout below. If any of these fail, the + // pointer-bump scheme could produce misaligned pointers and UB. 
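+        // For reference (assuming float and no padding in Mat3/Vec3), each camera consumes:
+        //   3 * sizeof(Mat3) = 108 B   (K, R_start, R_end)
+        //   2 * sizeof(Vec3) =  24 B   (t_start, t_end)
+        //   12 coeffs * 4 B  =  48 B   (for the 12-coefficient OpenCV layout)
+        // i.e. roughly 180 bytes per camera, which must match SHARED_MEM_SIZE in the dispatch
+        // function below.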
+ static_assert(alignof(Mat3) >= alignof(Vec3), "Mat3 alignment must cover Vec3 alignment"); + static_assert(alignof(Mat3) >= alignof(ScalarType), + "Mat3 alignment must cover ScalarType alignment"); + + constexpr int64_t kMat3Elements = 9; // 3x3 + constexpr int64_t kVec3Elements = 3; // 3 + + // Keep a running pointer which we increment to assign shared memory blocks + uint8_t *pointer = reinterpret_cast(sharedMemory); + + projectionMatsShared = reinterpret_cast(pointer); + pointer += C * sizeof(Mat3); + + worldToCamRotMatsStartShared = reinterpret_cast(pointer); + pointer += C * sizeof(Mat3); + + worldToCamRotMatsEndShared = reinterpret_cast(pointer); + pointer += C * sizeof(Mat3); + + worldToCamTranslationStartShared = reinterpret_cast(pointer); + pointer += C * sizeof(Vec3); + + worldToCamTranslationEndShared = reinterpret_cast(pointer); + pointer += C * sizeof(Vec3); + + distortionCoeffsShared = + mNumDistortionCoeffs > 0 ? reinterpret_cast(pointer) : nullptr; + pointer += C * mNumDistortionCoeffs * sizeof(ScalarType); + + // Layout in element units: + const int64_t projectionOffset = 0; + const int64_t rotStartOffset = projectionOffset + C * kMat3Elements; + const int64_t rotEndOffset = rotStartOffset + C * kMat3Elements; + const int64_t transStartOffset = rotEndOffset + C * kMat3Elements; + const int64_t transEndOffset = transStartOffset + C * kVec3Elements; + const int64_t distortionOffset = transEndOffset + C * kVec3Elements; + const int64_t totalElements = distortionOffset + C * mNumDistortionCoeffs; + + for (int64_t i = threadIdx.x; i < totalElements; i += blockDim.x) { + if (i < rotStartOffset) { + const auto camId = (i - projectionOffset) / kMat3Elements; + const auto entryId = (i - projectionOffset) % kMat3Elements; + const auto rowId = entryId / kVec3Elements; + const auto colId = entryId % kVec3Elements; + projectionMatsShared[camId][rowId][colId] = + mProjectionMatricesAcc[camId][rowId][colId]; + } else if (i < rotEndOffset) { + const auto camId = (i - rotStartOffset) / kMat3Elements; + const auto entryId = (i - rotStartOffset) % kMat3Elements; + const auto rowId = entryId / kVec3Elements; + const auto colId = entryId % kVec3Elements; + worldToCamRotMatsStartShared[camId][rowId][colId] = + mWorldToCamMatricesStartAcc[camId][rowId][colId]; + } else if (i < transStartOffset) { + const auto camId = (i - rotEndOffset) / kMat3Elements; + const auto entryId = (i - rotEndOffset) % kMat3Elements; + const auto rowId = entryId / kVec3Elements; + const auto colId = entryId % kVec3Elements; + worldToCamRotMatsEndShared[camId][rowId][colId] = + mWorldToCamMatricesEndAcc[camId][rowId][colId]; + } else if (i < transEndOffset) { + const auto camId = (i - transStartOffset) / kVec3Elements; + const auto entryId = (i - transStartOffset) % kVec3Elements; + worldToCamTranslationStartShared[camId][entryId] = + mWorldToCamMatricesStartAcc[camId][entryId][3]; + } else if (i < distortionOffset) { + const auto camId = (i - transEndOffset) / kVec3Elements; + const auto entryId = (i - transEndOffset) % kVec3Elements; + worldToCamTranslationEndShared[camId][entryId] = + mWorldToCamMatricesEndAcc[camId][entryId][3]; + } else if (mNumDistortionCoeffs > 0) { + const auto baseIdx = i - distortionOffset; + const auto camId = baseIdx / mNumDistortionCoeffs; + const auto entryId = baseIdx % mNumDistortionCoeffs; + distortionCoeffsShared[camId * mNumDistortionCoeffs + entryId] = + mDistortionCoeffsAcc[camId][entryId]; + } + } + } + + /// @brief Project one gaussian for one camera. 
+ /// + /// @param[in] idx Flattened index in \([0, C*N)\) mapping to (camId, gaussianId). + /// @return true if the gaussian is projected successfully, false otherwise. + inline __device__ void + projectionForward(int64_t idx) { + if (idx >= C * N) { + return; + } + + const int64_t camId = idx / N; + const int64_t gaussianId = idx % N; + + // Get camera parameters + const Mat3 &projectionMatrix = projectionMatsShared[camId]; + const Mat3 &worldToCamRotStart = worldToCamRotMatsStartShared[camId]; + const Mat3 &worldToCamRotEnd = worldToCamRotMatsEndShared[camId]; + const Vec3 &worldToCamTransStart = worldToCamTranslationStartShared[camId]; + const Vec3 &worldToCamTransEnd = worldToCamTranslationEndShared[camId]; + const ScalarType *distortionCoeffs = + (mNumDistortionCoeffs > 0) ? &distortionCoeffsShared[camId * mNumDistortionCoeffs] + : nullptr; + + // Define the camera model (projection and distortion) using the shared memory pointers + OpenCVCameraModel camera(mCameraModel, projectionMatrix, distortionCoeffs); + + // Define the world-to-camera transforms using the shared memory pointers at the start and + // end of the shutter period + const RigidTransform worldToCamStart(worldToCamRotStart, + worldToCamTransStart); // t=0.0 + const RigidTransform worldToCamEnd(worldToCamRotEnd, + worldToCamTransEnd); // t=1.0 + + // Get Gaussian parameters + const Vec3 meanWorldSpace( + mMeansAcc[gaussianId][0], mMeansAcc[gaussianId][1], mMeansAcc[gaussianId][2]); + const Vec4 quat_wxyz(mQuatsAcc[gaussianId][0], + mQuatsAcc[gaussianId][1], + mQuatsAcc[gaussianId][2], + mQuatsAcc[gaussianId][3]); + const Vec3 scale_world(::cuda::std::exp(mLogScalesAcc[gaussianId][0]), + ::cuda::std::exp(mLogScalesAcc[gaussianId][1]), + ::cuda::std::exp(mLogScalesAcc[gaussianId][2])); + + // Depth culling should use the same shutter pose as projection: + // - RollingShutterType::NONE: use start pose (t=0.0), matching + // `WorldToPixelTransform::projectWorldPoint` which uses the start transform when NONE. + // - Rolling shutter modes: use center pose (t=0.5) as a conservative/stable cull. + { + const RigidTransform shutter_pose = + (mRollingShutterType == RollingShutterType::NONE) + ? worldToCamStart + : RigidTransform::interpolate( + ScalarType(0.5), worldToCamStart, worldToCamEnd); + const Vec3 meanCam = shutter_pose.apply(meanWorldSpace); + if (meanCam[2] < mNearPlane || meanCam[2] > mFarPlane) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + } + + // Generate world-space sigma points (7) and UT weights (mean/cov). + nanovdb::math::Vec3 sigma_points_world[7]; + ScalarType weights_mean[7]; + ScalarType weights_cov[7]; + generateWorldSigmaPoints(meanWorldSpace, + quat_wxyz, + scale_world, + mUTParams, + sigma_points_world, + weights_mean, + weights_cov); + + const WorldToPixelTransform worldToPixel(camera, + mRollingShutterType, + mImageWidth, + mImageHeight, + ScalarType(mUTParams.inImageMargin), + worldToCamStart, + worldToCamEnd); + + // Project sigma points through camera model + nanovdb::math::Vec2 projected_points[7]; + bool valid_any = false; + constexpr int kNumSigmaPoints = 7; + for (int i = 0; i < kNumSigmaPoints; ++i) { + Vec2 pix; + const ProjStatus status = worldToPixel.projectWorldPoint(sigma_points_world[i], pix); + // Hard reject if any sigma point is behind the camera since the projection will be + // discontinuous. 
+ if (status == ProjStatus::BehindCamera) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + const bool valid_i = (status == ProjStatus::InImage); + projected_points[i] = pix; + valid_any |= valid_i; + if (mUTParams.requireAllSigmaPointsInImage && !valid_i) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + } + + if (!mUTParams.requireAllSigmaPointsInImage && !valid_any) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + + // Compute mean of projected points + nanovdb::math::Vec2 mean2d(ScalarType(0), ScalarType(0)); + for (int i = 0; i < kNumSigmaPoints; ++i) { + mean2d[0] += weights_mean[i] * projected_points[i][0]; + mean2d[1] += weights_mean[i] * projected_points[i][1]; + } + + // Reconstruct 2D covariance from projected sigma points + Mat2 covar2d = reconstructCovarianceFromSigmaPoints(projected_points, weights_cov, mean2d); + + // Add blur for numerical stability + ScalarType compensation; + const ScalarType det_blur = addBlur(mEps2d, covar2d, compensation); + if (det_blur <= ScalarType(0)) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + + // Ensure reconstructed covariance is PSD to avoid NaNs when taking square-roots or + // inverting. + enforcePSD2x2(mEps2d, covar2d); + + const ScalarType det_psd = covar2d[0][0] * covar2d[1][1] - covar2d[0][1] * covar2d[1][0]; + if (!(det_psd > ScalarType(0))) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + + const Mat2 covar2dInverse = covar2d.inverse(); + + // Compute bounding box radius (similar to standard projection) + const ScalarType b = 0.5f * (covar2d[0][0] + covar2d[1][1]); + const ScalarType tmp = sqrtf(max(0.01f, b * b - det_psd)); + const ScalarType v1 = b + tmp; // larger eigenvalue + const ScalarType extend = 3.0f; // 3 sigma + ScalarType r1 = extend * sqrtf(v1); + ScalarType radius_x = ceilf(min(extend * sqrtf(covar2d[0][0]), r1)); + ScalarType radius_y = ceilf(min(extend * sqrtf(covar2d[1][1]), r1)); + + if (radius_x <= mRadiusClip && radius_y <= mRadiusClip) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + + // Mask out gaussians outside the image region + if (mean2d[0] + radius_x <= 0 || mean2d[0] - radius_x >= mImageWidth || + mean2d[1] + radius_y <= 0 || mean2d[1] - radius_y >= mImageHeight) { + mOutRadiiAcc[camId][gaussianId] = 0; + return; + } + + // Write outputs (using radius_x for compatibility, but could use both) + mOutRadiiAcc[camId][gaussianId] = int32_t(max(radius_x, radius_y)); + mOutMeans2dAcc[camId][gaussianId][0] = mean2d[0]; + mOutMeans2dAcc[camId][gaussianId][1] = mean2d[1]; + // For depth we use the same shutter pose as the cull check above. + { + const ScalarType t_depth = (mRollingShutterType == RollingShutterType::NONE) + ? ScalarType(0.0) + : ScalarType(0.5); + const RigidTransform shutter_pose = + RigidTransform::interpolate(t_depth, worldToCamStart, worldToCamEnd); + const Vec3 meanCam = shutter_pose.apply(meanWorldSpace); + mOutDepthsAcc[camId][gaussianId] = meanCam[2]; + } + mOutConicsAcc[camId][gaussianId][0] = covar2dInverse[0][0]; + mOutConicsAcc[camId][gaussianId][1] = covar2dInverse[0][1]; + mOutConicsAcc[camId][gaussianId][2] = covar2dInverse[1][1]; + if (mOutCompensationsAcc != nullptr) { + mOutCompensationsAcc[idx] = compensation; + } + } +}; + +/// @brief CUDA kernel wrapper for `ProjectionForwardUT`. +/// +/// Each thread processes multiple (camera, gaussian) pairs in a grid-stride loop. 
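+///
+/// Launch sketch (mirrors the dispatch code below; `functor` is a hypothetical fully-constructed
+/// ProjectionForwardUT<float> and `numDistortionCoeffs` is distortionCoeffs.size(1)):
+/// @code
+/// const size_t sharedBytes =
+///     C * (3 * sizeof(nanovdb::math::Mat3<float>) + 2 * sizeof(nanovdb::math::Vec3<float>)) +
+///     C * numDistortionCoeffs * sizeof(float);
+/// projectionForwardUTKernel<float>
+///     <<<GET_BLOCKS(C * N, 256), 256, sharedBytes, stream>>>(/*offset=*/0, /*count=*/C * N, functor);
+/// @endcode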
+template +__global__ __launch_bounds__(256) void +projectionForwardUTKernel(int64_t offset, + int64_t count, + ProjectionForwardUT projectionForward) { + projectionForward.loadCameraInfoIntoSharedMemory(); + __syncthreads(); + + // parallelize over C * N + for (auto idx = blockIdx.x * blockDim.x + threadIdx.x; idx < count; + idx += blockDim.x * gridDim.x) { + projectionForward.projectionForward(idx + offset); + } +} + +/// @brief CUDA specialization for UT forward projection dispatch. +/// +/// Performs host-side validation and launches `projectionForwardUTKernel`. +template <> +std::tuple +dispatchGaussianProjectionForwardUT( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const CameraModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV, [C,0] for NONE + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations) { + FVDB_FUNC_RANGE(); + + TORCH_CHECK_VALUE(means.is_cuda(), "means must be a CUDA tensor"); + TORCH_CHECK_VALUE(quats.is_cuda(), "quats must be a CUDA tensor"); + TORCH_CHECK_VALUE(logScales.is_cuda(), "logScales must be a CUDA tensor"); + TORCH_CHECK_VALUE(worldToCamMatricesStart.is_cuda(), + "worldToCamMatricesStart must be a CUDA tensor"); + TORCH_CHECK_VALUE(worldToCamMatricesEnd.is_cuda(), + "worldToCamMatricesEnd must be a CUDA tensor"); + TORCH_CHECK_VALUE(projectionMatrices.is_cuda(), "projectionMatrices must be a CUDA tensor"); + TORCH_CHECK_VALUE(distortionCoeffs.is_cuda(), "distortionCoeffs must be a CUDA tensor"); + TORCH_CHECK_VALUE(distortionCoeffs.dim() == 2, "distortionCoeffs must be 2D"); + + // Validate UT hyperparameters on the host to avoid inf/NaNs from invalid scaling/weights. + // In the 3D UT, D=3 and: + // lambda = alpha^2 * (D + kappa) - D + // denom = D + lambda = alpha^2 * (D + kappa) + // denom must be finite and strictly positive. + constexpr float kUtDim = 3.0f; + TORCH_CHECK_VALUE(std::isfinite(utParams.alpha), "utParams.alpha must be finite"); + TORCH_CHECK_VALUE(std::isfinite(utParams.beta), "utParams.beta must be finite"); + TORCH_CHECK_VALUE(std::isfinite(utParams.kappa), "utParams.kappa must be finite"); + TORCH_CHECK_VALUE(utParams.alpha > 0.0f, "utParams.alpha must be > 0"); + TORCH_CHECK_VALUE(kUtDim + utParams.kappa > 0.0f, + "utParams.kappa must satisfy (D + kappa) > 0 for the 3D UT (D=3)"); + const float denom = utParams.alpha * utParams.alpha * (kUtDim + utParams.kappa); + TORCH_CHECK_VALUE(std::isfinite(denom) && denom > 0.0f, + "Invalid UTParams: expected denom = alpha^2*(D+kappa) to be finite and > 0"); + + if (cameraModel == CameraModel::PINHOLE || cameraModel == CameraModel::ORTHOGRAPHIC) { + // Distortion coefficients are ignored for these camera models. + // (Intrinsics `projectionMatrices` are always used.) 
+ } else if (cameraModel == CameraModel::OPENCV_RADTAN_5 || + cameraModel == CameraModel::OPENCV_RATIONAL_8 || + cameraModel == CameraModel::OPENCV_RADTAN_THIN_PRISM_9 || + cameraModel == CameraModel::OPENCV_THIN_PRISM_12) { + TORCH_CHECK_VALUE(distortionCoeffs.size(1) == 12, + "For CameraModel::OPENCV_* , distortionCoeffs must have shape [C,12] " + "as [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]"); + } else { + TORCH_CHECK_VALUE(false, "Unknown CameraModel for GaussianProjectionForwardUT"); + } + + // This kernel implements only the canonical 3D UT with 2D+1 sigma points (7). + + const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); + + const auto N = means.size(0); // number of gaussians + const auto C = projectionMatrices.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(means.device().index()); + + TORCH_CHECK_VALUE(distortionCoeffs.size(0) == C, + "distortionCoeffs must have shape [C,K] matching projectionMatrices.size(0)"); + + torch::Tensor outRadii = torch::empty({C, N}, means.options().dtype(torch::kInt32)); + torch::Tensor outMeans2d = torch::empty({C, N, 2}, means.options()); + torch::Tensor outDepths = torch::empty({C, N}, means.options()); + torch::Tensor outConics = torch::empty({C, N, 3}, means.options()); + torch::Tensor outCompensations; + if (calcCompensations) { + outCompensations = torch::zeros({C, N}, means.options()); + } + + if (N == 0 || C == 0) { + return std::make_tuple(outRadii, outMeans2d, outDepths, outConics, outCompensations); + } + + using scalar_t = float; + + const size_t NUM_BLOCKS = GET_BLOCKS(C * N, 256); + // Orthographic is supported only for CameraModel::ORTHOGRAPHIC (undistorted). + + const size_t SHARED_MEM_SIZE = C * (3 * sizeof(nanovdb::math::Mat3) + + 2 * sizeof(nanovdb::math::Vec3)) + + C * distortionCoeffs.size(1) * sizeof(scalar_t); + + ProjectionForwardUT projectionForward(imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + rollingShutterType, + utParams, + cameraModel, + calcCompensations, + means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + distortionCoeffs, + outRadii, + outMeans2d, + outDepths, + outConics, + outCompensations); + + projectionForwardUTKernel + <<>>(0, C * N, projectionForward); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(outRadii, outMeans2d, outDepths, outConics, outCompensations); +} + +/// @brief CPU specialization (not implemented). +template <> +std::tuple +dispatchGaussianProjectionForwardUT( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const CameraModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV, [C,0] for NONE + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianProjectionForwardUT not implemented on the CPU"); +} + +/// @brief PrivateUse1 specialization (not implemented). 
+template <> +std::tuple +dispatchGaussianProjectionForwardUT( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const CameraModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV, [C,0] for NONE + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "GaussianProjectionForwardUT not implemented for this device type"); +} + +} // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h b/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h new file mode 100644 index 00000000..6fc780c6 --- /dev/null +++ b/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h @@ -0,0 +1,141 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H +#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H + +#include +#include + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +enum class RollingShutterType { NONE = 0, VERTICAL = 1, HORIZONTAL = 2 }; + +/// @brief Camera model for projection in the UT kernel. +/// +/// This enum describes the camera projection family used by the UT kernel. It is intentionally +/// broader than "distortion model" so we can add more complex models (e.g. RPC) later. +/// +/// Notes: +/// - `PINHOLE` and `ORTHOGRAPHIC` ignore `distortionCoeffs`. +/// - `OPENCV_*` variants are pinhole + OpenCV-style distortion, and require packed coefficients. +enum class CameraModel : int32_t { + // Pinhole intrinsics only (no distortion). + PINHOLE = 0, + + // Orthographic intrinsics (no distortion). + ORTHOGRAPHIC = 5, + + // OpenCV variants which are just pinhole intrinsics + optional distortion (all of them use the + // same [C,12] distortion coefficients layout: [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]). + OPENCV_RADTAN_5 = 1, // polynomial radial (k1,k2,k3) + tangential (p1,p2)). + OPENCV_RATIONAL_8 = 2, // rational radial (k1..k6) + tangential (p1,p2)). + OPENCV_RADTAN_THIN_PRISM_9 = 3, // polynomial radial + tangential + thin-prism (s1..s4)). + OPENCV_THIN_PRISM_12 = 4, // rational radial + tangential + thin-prism (s1..s4)). +}; + +/// @brief Unscented Transform hyperparameters. +/// +/// This kernel implements the canonical 3D UT with a fixed \(2D+1\) sigma point set (7 points). +/// The parameters here control the standard UT scaling / weighting. +struct UTParams { + float alpha = 0.1f; // Blending parameter for UT + float beta = 2.0f; // Scaling parameter for UT + float kappa = 0.0f; // Additional scaling parameter for UT + float inImageMargin = 0.1f; // Margin for in-image check + bool requireAllSigmaPointsInImage = + true; // Require all sigma points to be in image to consider a Gaussian valid +}; + +/// @brief Project 3D Gaussians to 2D screen space pixel coordinates for rendering using the +/// Unscented Transform (UT) algorithm. +/// +/// This function transforms 3D Gaussians to 2D screen space by applying camera projections. 
+/// It computes the 2D means, depths, 2D covariance matrices (conics), and potentially compensation +/// factors to accurately represent the 3D Gaussians in 2D for later rasterization. +/// +/// The origin of the 2D pixel coordinates is the top-left corner of the image, with positive x-axis +/// pointing to the right and positive y-axis pointing downwards. +/// +/// @attention The output radii of 3D Gaussians that are discarded (due to clipping or projection +/// too small) are set to zero, but the other output values of discarded Gaussians are uninitialized +/// (undefined). +/// +/// The UT algorithm is a non-parametric method for approximating the mean and covariance of a +/// probability distribution. It is used to project 3D Gaussians to 2D screen space by applying +/// camera projections. +/// +/// High-level algorithm: +/// 1. **Generate sigma points** in world space for each 3D Gaussian (fixed 7-point UT in 3D). +/// 2. **Project** each sigma point to pixels using the selected `CameraModel` and rolling-shutter +/// policy. +/// 3. **Reconstruct** the 2D mean and covariance from the projected sigma points + UT weights. +/// 4. **Stabilize** covariance by adding a small blur term (`eps2d`) and compute the conic form. +/// 5. **Cull** gaussians that are out-of-range (near/far) or too small (min radius), and write +/// outputs for the survivors. +/// +/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) +/// +/// @param[in] means 3D positions of Gaussians [N, 3] where N is number of Gaussians +/// @param[in] quats Quaternion rotations of Gaussians [N, 4] in format (w, x, y, z) +/// @param[in] logScales Log-scale factors of Gaussians [N, 3] (natural log), representing extent in +/// each dimension +/// @param[in] worldToCamMatricesStart Camera view matrices at the start of the frame. Shape [C, 4, +/// 4] where C is number of cameras +/// @param[in] worldToCamMatricesEnd Camera view matrices at the end of the frame. Shape [C, 4, 4] +/// where C is number of cameras +/// @param[in] projectionMatrices Camera intrinsic matrices [C, 3, 3] +/// @param[in] rollingShutterType Type of rolling shutter effect to apply +/// @param[in] utParams Unscented Transform parameters +/// @param[in] cameraModel Camera model for projection. +/// @param[in] distortionCoeffs Distortion coefficients for each camera. +/// - CameraModel::PINHOLE: ignored (use [C,0] or [C,K] tensor). +/// - CameraModel::ORTHOGRAPHIC: ignored (use [C,0] or [C,K] tensor). +/// - CameraModel::OPENCV_*: expects [C,12] coefficients in the following order: +/// [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4] +/// where k1..k6 are radial (rational), p1,p2 are tangential, and s1..s4 are thin-prism. 
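+/// For example (illustrative only; `k1,k2,k3,p1,p2` are hypothetical calibration values), a
+/// single OPENCV_RADTAN_5 camera still uses the full 12-wide layout, with unused slots zeroed:
+/// @code
+/// //                         [k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]
+/// auto distortionCoeffs = torch::zeros({C, 12}, means.options());
+/// distortionCoeffs.index_put_({0, 0}, k1);
+/// distortionCoeffs.index_put_({0, 1}, k2);
+/// distortionCoeffs.index_put_({0, 2}, k3);
+/// distortionCoeffs.index_put_({0, 6}, p1);
+/// distortionCoeffs.index_put_({0, 7}, p2);
+/// @endcode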
+/// @param[in] imageWidth Width of the output image in pixels +/// @param[in] imageHeight Height of the output image in pixels +/// @param[in] eps2d 2D projection epsilon for numerical stability +/// @param[in] nearPlane Near clipping plane distance +/// @param[in] farPlane Far clipping plane distance +/// @param[in] minRadius2d Minimum 2D radius threshold; Gaussians with projected radius <= this +/// value are clipped/discarded +/// @param[in] calcCompensations Whether to calculate view-dependent compensation factors +/// +/// @return std::tuple containing: +/// - Radii of 2D Gaussians [C, N] +/// - 2D projected Gaussian centers [C, N, 2] +/// - Depths of Gaussians [C, N] +/// - Covariance matrices in conic form [C, N, 3] representing (a, b, c) in ax² + 2bxy + cy² +/// - Compensation factors [C, N] (if calc_compensations is true, otherwise empty tensor) +template +std::tuple +dispatchGaussianProjectionForwardUT( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const CameraModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C, 12] for OPENCV_*, or [C, 0] for PINHOLE/ORTHO + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh b/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh index 65894499..87694cd9 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh +++ b/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh @@ -44,6 +44,113 @@ binSearch(const T *arr, const uint32_t len, const T val) { return low - 1; } +/// @brief Converts a 3x3 rotation matrix to a quaternion. +/// +/// This function converts a 3x3 rotation matrix to the equivalent quaternion \([w,x,y,z]\) and +/// normalizes the result. +/// +/// The implementation uses the standard **branch-based** algorithm for numerical robustness: +/// - If \(\mathrm{trace}(R) > 0\), it uses the closed-form trace formula +/// \(w = \tfrac{1}{2}\sqrt{1 + \mathrm{trace}(R)}\) and derives \((x,y,z)\) from the +/// off-diagonals. +/// - Otherwise it selects the largest diagonal element and computes the quaternion from that +/// branch (x-dominant / y-dominant / z-dominant cases). +/// +/// Degenerate inputs (e.g. non-rotation matrices, NaNs) are guarded against to avoid division by +/// near-zero intermediates; in such cases the function falls back to the identity quaternion. +/// +/// @param R Input 3x3 rotation matrix. +/// @return nanovdb::math::Vec4 Quaternion equivalent to the rotation matrix. +template +__host__ __device__ nanovdb::math::Vec4 +rotationMatrixToQuaternion(const nanovdb::math::Mat3 &R) { + T trace = R[0][0] + R[1][1] + R[2][2]; + T x, y, z, w; + + // Guard against division by ~0 in the branch formulas below. + // This can happen for degenerate / NaN inputs where `t` underflows to 0 or is clamped to 0, + // causing `s = 2*sqrt(t)` to be 0, while the numerators remain finite -> inf/NaN. + const T s_min = (sizeof(T) == sizeof(float)) ? 
T(1e-8) : T(1e-12);
+
+    if (trace > 0) {
+        T t = trace + T(1);
+        t   = (t > T(0)) ? t : T(0);
+        T s = sqrt(t) * T(2); // S=4*qw
+        if (!(s > s_min)) {
+            // Degenerate input; fall back to identity.
+            w = T(1);
+            x = y = z = T(0);
+        } else {
+            w = T(0.25) * s;
+            x = (R[2][1] - R[1][2]) / s;
+            y = (R[0][2] - R[2][0]) / s;
+            z = (R[1][0] - R[0][1]) / s;
+        }
+    } else if ((R[0][0] > R[1][1]) && (R[0][0] > R[2][2])) {
+        T t = T(1) + R[0][0] - R[1][1] - R[2][2];
+        t   = (t > T(0)) ? t : T(0);
+        T s = sqrt(t) * T(2); // S=4*qx
+        if (!(s > s_min)) {
+            w = T(1);
+            x = y = z = T(0);
+        } else {
+            w = (R[2][1] - R[1][2]) / s;
+            x = T(0.25) * s;
+            y = (R[0][1] + R[1][0]) / s;
+            z = (R[0][2] + R[2][0]) / s;
+        }
+    } else if (R[1][1] > R[2][2]) {
+        T t = T(1) + R[1][1] - R[0][0] - R[2][2];
+        t   = (t > T(0)) ? t : T(0);
+        T s = sqrt(t) * T(2); // S=4*qy
+        if (!(s > s_min)) {
+            w = T(1);
+            x = y = z = T(0);
+        } else {
+            w = (R[0][2] - R[2][0]) / s;
+            x = (R[0][1] + R[1][0]) / s;
+            y = T(0.25) * s;
+            z = (R[1][2] + R[2][1]) / s;
+        }
+    } else {
+        T t = T(1) + R[2][2] - R[0][0] - R[1][1];
+        t   = (t > T(0)) ? t : T(0);
+        T s = sqrt(t) * T(2); // S=4*qz
+        if (!(s > s_min)) {
+            w = T(1);
+            x = y = z = T(0);
+        } else {
+            w = (R[1][0] - R[0][1]) / s;
+            x = (R[0][2] + R[2][0]) / s;
+            y = (R[1][2] + R[2][1]) / s;
+            z = T(0.25) * s;
+        }
+    }
+
+    // Normalize to guard against accumulated FP error / slightly non-orthonormal inputs.
+    const T norm2 = (w * w + x * x + y * y + z * z);
+    if (norm2 > T(0)) {
+        const T invNorm = T(1) / sqrt(norm2);
+        w *= invNorm;
+        x *= invNorm;
+        y *= invNorm;
+        z *= invNorm;
+    } else {
+        // Degenerate input; fall back to identity.
+        w = T(1);
+        x = y = z = T(0);
+    }
+
+    // Optional convention: keep a consistent sign (q and -q represent the same rotation).
+    if (w < T(0)) {
+        w = -w;
+        x = -x;
+        y = -y;
+        z = -z;
+    }
+    return nanovdb::math::Vec4<T>(w, x, y, z);
+}
+
 /// @brief Converts a quaternion to a 3x3 rotation matrix
 ///
 /// This function takes a quaternion [w,x,y,z] and converts it to the equivalent
@@ -86,6 +193,61 @@ quaternionToRotationMatrix(nanovdb::math::Vec4<T> const &quat) {
     );
 }
 
+/// @brief Normalizes a quaternion to unit length
+///
+/// This function normalizes a quaternion to unit length. If the quaternion is zero, it is set to
+/// the identity quaternion.
+///
+/// @param q Input quaternion [w,x,y,z]
+/// @return nanovdb::math::Vec4<T> Normalized quaternion
+template <typename T>
+inline __host__ __device__ nanovdb::math::Vec4<T>
+normalizeQuaternionSafe(nanovdb::math::Vec4<T> q) {
+    const T n2 = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
+    if (n2 > T(0)) {
+        const T invN = T(1) / sqrt(n2);
+        q[0] *= invN;
+        q[1] *= invN;
+        q[2] *= invN;
+        q[3] *= invN;
+    } else {
+        q[0] = T(1);
+        q[1] = q[2] = q[3] = T(0);
+    }
+    return q;
+}
+
+/// @brief Interpolates between two quaternions using normalized linear interpolation along the
+/// shortest path
+///
+/// This function interpolates between two quaternions using normalized linear interpolation along
+/// the shortest path.
+///
+/// @param q0 First quaternion [w,x,y,z]
+/// @param q1 Second quaternion [w,x,y,z]
+/// @param u Interpolation factor in [0,1]
+/// @return nanovdb::math::Vec4<T> Interpolated quaternion
+template <typename T>
+inline __host__ __device__ nanovdb::math::Vec4<T>
+nlerpQuaternionShortestPath(const nanovdb::math::Vec4<T> &q0,
+                            nanovdb::math::Vec4<T> q1,
+                            const T u) {
+    // Ensure shortest arc (q and -q represent the same rotation).
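+    // Note (illustrative): the sign of the 4D dot product below tells us which of +q1 / -q1 lies
+    // closer to q0; flipping q1 when the dot is negative makes the linear blend follow the shorter
+    // arc. For small rotation deltas (e.g. start/end shutter poses), nlerp stays very close to
+    // slerp at a lower cost.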
+ T dot = q0[0] * q1[0] + q0[1] * q1[1] + q0[2] * q1[2] + q0[3] * q1[3]; + if (dot < T(0)) { + q1[0] = -q1[0]; + q1[1] = -q1[1]; + q1[2] = -q1[2]; + q1[3] = -q1[3]; + } + + const T s = T(1) - u; + return normalizeQuaternionSafe(nanovdb::math::Vec4(s * q0[0] + u * q1[0], + s * q0[1] + u * q1[1], + s * q0[2] + u * q1[2], + s * q0[3] + u * q1[3])); +} + /// @brief Computes the vector-Jacobian product for quaternion to rotation matrix transformation /// /// This function computes the gradient of the loss with respect to a quaternion (dL/dq) diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 8fa54c5d..4164f5c0 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -150,6 +150,8 @@ ConfigureTest(GaussianSphericalHarmonicsForwardTest "GaussianSphericalHarmonicsF ConfigureTest(GaussianSphericalHarmonicsBackwardTest "GaussianSphericalHarmonicsBackwardTest.cpp") ConfigureTest(GaussianProjectionForwardTest "GaussianProjectionForwardTest.cpp") ConfigureTest(GaussianProjectionBackwardTest "GaussianProjectionBackwardTest.cpp") +ConfigureTest(GaussianProjectionUTTest "GaussianProjectionUTTest.cpp") +ConfigureTest(GaussianUtilsTest "GaussianUtilsTest.cu") ConfigureTest(GaussianRasterizeTopContributorsTest "GaussianRasterizeTopContributorsTest.cpp") ConfigureTest(GaussianRasterizeContributingGaussianIdsTest "GaussianRasterizeContributingGaussianIdsTest.cpp") ConfigureTest(GaussianMCMCAddNoiseTest "GaussianMCMCAddNoiseTest.cpp") diff --git a/src/tests/GaussianProjectionUTTest.cpp b/src/tests/GaussianProjectionUTTest.cpp new file mode 100644 index 00000000..cd73d7d4 --- /dev/null +++ b/src/tests/GaussianProjectionUTTest.cpp @@ -0,0 +1,1584 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace fvdb::detail::ops { + +namespace { + +inline std::tuple +projectPointWithOpenCVDistortion(const float x, + const float y, + const float z, + const float fx, + const float fy, + const float cx, + const float cy, + const std::vector &radial, // k1,k2,k3 or k1..k6 + const std::vector &tangential, // p1,p2 (or empty) + const std::vector &thinPrism // s1..s4 (or empty) +) { + const float xn = x / z; + const float yn = y / z; + + const float x2 = xn * xn; + const float y2 = yn * yn; + const float xy = xn * yn; + + const float r2 = x2 + y2; + const float r4 = r2 * r2; + const float r6 = r4 * r2; + + float radial_dist = 1.0f; + if (radial.size() == 3) { + const float k1 = radial[0]; + const float k2 = radial[1]; + const float k3 = radial[2]; + radial_dist = 1.0f + k1 * r2 + k2 * r4 + k3 * r6; + } else if (radial.size() == 6) { + const float k1 = radial[0]; + const float k2 = radial[1]; + const float k3 = radial[2]; + const float k4 = radial[3]; + const float k5 = radial[4]; + const float k6 = radial[5]; + const float num = 1.0f + r2 * (k1 + r2 * (k2 + r2 * k3)); + const float den = 1.0f + r2 * (k4 + r2 * (k5 + r2 * k6)); + radial_dist = (den != 0.0f) ? (num / den) : 0.0f; + } else if (!radial.empty()) { + return {std::nanf(""), std::nanf("")}; + } + + float xd = xn * radial_dist; + float yd = yn * radial_dist; + + const float p1 = tangential.size() >= 1 ? tangential[0] : 0.0f; + const float p2 = tangential.size() >= 2 ? 
tangential[1] : 0.0f; + // OpenCV tangential: + // x += 2*p1*x*y + p2*(r^2 + 2*x^2) + // y += p1*(r^2 + 2*y^2) + 2*p2*x*y + xd += 2.0f * p1 * xy + p2 * (r2 + 2.0f * x2); + yd += p1 * (r2 + 2.0f * y2) + 2.0f * p2 * xy; + + if (!thinPrism.empty()) { + if (thinPrism.size() != 4) { + return {std::nanf(""), std::nanf("")}; + } + const float s1 = thinPrism[0]; + const float s2 = thinPrism[1]; + const float s3 = thinPrism[2]; + const float s4 = thinPrism[3]; + xd += s1 * r2 + s2 * r4; + yd += s3 * r2 + s4 * r4; + } + + const float u = fx * xd + cx; + const float v = fy * yd + cy; + return {u, v}; +} + +} // namespace + +// We keep these UT tests purely analytic/synthetic: they should not depend on any external test +// data downloads. +struct GaussianProjectionUTTestFixture : public ::testing::Test { + void + SetUp() override { + torch::manual_seed(0); + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available; skipping GaussianProjectionUT tests."; + } + } + + torch::Tensor means; // [N, 3] + torch::Tensor quats; // [N, 4] + torch::Tensor logScales; // [N, 3] + torch::Tensor worldToCamMatricesStart; // [C, 4, 4] + torch::Tensor worldToCamMatricesEnd; // [C, 4, 4] + torch::Tensor projectionMatrices; // [C, 3, 3] + CameraModel cameraModel = CameraModel::PINHOLE; + torch::Tensor distortionCoeffs; // [C, 12] for OPENCV, or [C, 0] for NONE + + int64_t imageWidth; + int64_t imageHeight; + float eps2d; + float nearPlane; + float farPlane; + float minRadius2d; + + UTParams utParams; +}; + +TEST_F(GaussianProjectionUTTestFixture, CenteredGaussian_NoDistortion_AnalyticMeanAndConic) { + const int64_t C = 1; + + // World-space Gaussian mean at optical axis (x=y=0), so projected mean should be exactly + // (cx,cy) + const float z = 5.0f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + // Quaternion is stored as [w,x,y,z] in fvdb kernels + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + + const float sx = 0.2f, sy = 0.3f, sz = 0.4f; + logScales = torch::log(torch::tensor({{sx, sy, sz}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 100.0f, fy = 200.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; // matches defaults + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; // interpreted as fraction of image dims + utParams.requireAllSigmaPointsInImage = true; + + // CUDA + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + 
logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto means2d_cpu = means2d.cpu(); + auto depths_cpu = depths.cpu(); + auto conics_cpu = conics.cpu(); + auto radii_cpu = radii.cpu(); + + EXPECT_GT(radii_cpu[0][0].item(), 0); + EXPECT_NEAR(depths_cpu[0][0].item(), z, 1e-4f); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), cx, 1e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), cy, 1e-3f); + + // For u = fx * x/z + cx, v = fy * y/z + cy and mean on optical axis (x=y=0), + // the projected covariance is exactly: + // cov_u = (fx*sx/z)^2, cov_v = (fy*sy/z)^2, off-diagonals = 0. + const float cov_u = (fx * sx / z) * (fx * sx / z); + const float cov_v = (fy * sy / z) * (fy * sy / z); + const float cov_u_blur = cov_u + eps2d; + const float cov_v_blur = cov_v + eps2d; + + const float expected_a = 1.0f / cov_u_blur; + const float expected_b = 0.0f; + const float expected_c = 1.0f / cov_v_blur; + + EXPECT_NEAR(conics_cpu[0][0][0].item(), expected_a, 1e-3f); + EXPECT_NEAR(conics_cpu[0][0][1].item(), expected_b, 1e-3f); + EXPECT_NEAR(conics_cpu[0][0][2].item(), expected_c, 1e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, NonlinearUTCovariance_ProducesFinitePositiveConic) { + const int64_t C = 1; + + // Construct a case where nonlinear projection + negative UT covariance weights can yield an + // indefinite covariance if not handled carefully. The expected behavior is that outputs remain + // finite and the resulting conic (inverse covariance) is positive on the diagonal. + means = torch::tensor({{0.2f, 0.1f, 2.0f}}, torch::kFloat32); + + // Rotate about +Y to couple X and Z under the UT sigma points. + // Quaternion is [w,x,y,z]. + const float angle_rad = static_cast(M_PI) / 3.0f; // 60 degrees + const float w = std::cos(0.5f * angle_rad); + const float y = std::sin(0.5f * angle_rad); + quats = torch::tensor({{w, 0.0f, y, 0.0f}}, torch::kFloat32); + + // Relatively large scales to produce noticeable nonlinear effects. 
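+    // Rough sanity check of why this stresses the covariance reconstruction: under the standard
+    // Merwe scaled-UT weights (assumed here purely for illustration), with n=3, alpha=0.2 (set
+    // below) and kappa=0 we get lambda = alpha^2*(n+kappa) - n = -2.88, so the zeroth covariance
+    // weight is lambda/(n+lambda) + (1 - alpha^2 + beta) = -24 + 2.96 ~= -21, strongly negative.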
+ logScales = torch::log(torch::tensor({{1.0f, 0.7f, 0.8f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 700.0f, fy = 700.0f, cx = 640.0f, cy = 360.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + { + auto K = projectionMatrices.accessor(); + K[0][0][0] = fx; + K[0][1][1] = fy; + K[0][0][2] = cx; + K[0][1][2] = cy; + K[0][2][2] = 1.0f; + } + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 1280; + imageHeight = 720; + eps2d = 0.3f; + nearPlane = 0.05f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.alpha = 0.2f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 1.0f; // very lenient bounds + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto conics_cpu = conics.cpu(); + auto means2d_cpu = means2d.cpu(); + + ASSERT_GT(radii_cpu[0][0].item(), 0); + + const float a = conics_cpu[0][0][0].item(); + const float b = conics_cpu[0][0][1].item(); + const float c_out = conics_cpu[0][0][2].item(); + + EXPECT_TRUE(std::isfinite(a)); + EXPECT_TRUE(std::isfinite(b)); + EXPECT_TRUE(std::isfinite(c_out)); + EXPECT_GT(a, 0.0f); + EXPECT_GT(c_out, 0.0f); + + EXPECT_TRUE(std::isfinite(means2d_cpu[0][0][0].item())); + EXPECT_TRUE(std::isfinite(means2d_cpu[0][0][1].item())); +} + +TEST_F(GaussianProjectionUTTestFixture, UTParams_InvalidAlpha_ThrowsOnHost) { + const int64_t C = 1; + + means = torch::tensor({{0.0f, 0.0f, 5.0f}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{0.2f, 0.3f, 0.4f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + K[0][0][0] = 100.0f; + K[0][1][1] = 200.0f; + K[0][0][2] = 320.0f; + K[0][1][2] = 240.0f; + K[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.alpha = 0.0f; // invalid + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = 
projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + EXPECT_THROW((dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false)), + c10::Error); +} + +TEST_F(GaussianProjectionUTTestFixture, UTParams_InvalidKappa_ThrowsOnHost) { + const int64_t C = 1; + + means = torch::tensor({{0.0f, 0.0f, 5.0f}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{0.2f, 0.3f, 0.4f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + K[0][0][0] = 100.0f; + K[0][1][1] = 200.0f; + K[0][0][2] = 320.0f; + K[0][1][2] = 240.0f; + K[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.alpha = 0.1f; + // For the 3D UT (D=3), require D + kappa > 0. Setting kappa=-3 makes denom=0. + utParams.kappa = -3.0f; // invalid + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + EXPECT_THROW((dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false)), + c10::Error); +} + +TEST_F(GaussianProjectionUTTestFixture, DepthNearCameraPlane_BelowZEps_HardRejects) { + const int64_t C = 1; + + // Use a very small positive depth below the perspective z-epsilon (1e-6). This should be + // treated as invalid (BehindCamera) rather than projected with a clamped z. + const float z = 5e-7f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + // Extremely small Gaussian so all sigma points stay near the mean and below z_eps. 
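+    // With the default alpha = 0.1 (see the "matches defaults" note earlier in this file), the
+    // sigma-point spread is roughly gamma = sqrt(alpha^2 * 3) ~= 0.173, so offsets are about
+    // 0.173 * 1e-8 ~= 2e-9 and every sigma point keeps z ~= 5e-7 +/- 2e-9, below the 1e-6 z-eps.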
+ logScales = torch::log(torch::tensor({{1e-8f, 1e-8f, 1e-8f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + K[0][0][0] = 100.0f; + K[0][1][1] = 200.0f; + K[0][0][2] = 320.0f; + K[0][1][2] = 240.0f; + K[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.0f; // don't near-plane cull; we want to test z-eps rejection + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = false; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + EXPECT_EQ(radii_cpu[0][0].item(), 0); +} + +TEST_F(GaussianProjectionUTTestFixture, DepthNearCameraPlane_AboveZEps_Projects) { + const int64_t C = 1; + + // Depth just above z-eps should project normally. 
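+    // z = 2e-6 sits above the 1e-6 z-epsilon, and the tiny scales keep every sigma point above it
+    // as well; with x = y = 0 the expected projected pixel is simply (cx, cy).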
+ const float z = 2e-6f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-8f, 1e-8f, 1e-8f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 100.0f, fy = 200.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + K[0][0][0] = fx; + K[0][1][1] = fy; + K[0][0][2] = cx; + K[0][1][2] = cy; + K[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.0f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), cx, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), cy, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, Orthographic_NoDistortion_AnalyticMeanAndDepth) { + const int64_t C = 1; + + const float x = 1.0f, y = -2.0f, z = 10.0f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{0.2f, 0.3f, 0.4f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 123.0f, fy = 77.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::ORTHOGRAPHIC; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, 
conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto means2d_cpu = means2d.cpu(); + auto depths_cpu = depths.cpu(); + auto radii_cpu = radii.cpu(); + + EXPECT_GT(radii_cpu[0][0].item(), 0); + EXPECT_NEAR(depths_cpu[0][0].item(), z, 1e-4f); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), fx * x + cx, 1e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), fy * y + cy, 1e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, OffAxisTinyGaussian_NoDistortion_MeanMatchesPinhole) { + const int64_t C = 1; + + const float x = 1.0f, y = -2.0f, z = 10.0f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + // Extremely small Gaussian so UT mean should match the point projection closely + // (off-axis + perspective nonlinearity can otherwise introduce a tiny UT mean shift). + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 123.0f, fy = 77.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto means2d_cpu = means2d.cpu(); + const float expected_u = fx * (x / z) + cx; + const float expected_v = fy * (y / z) + cy; + // UT projects sigma points through a nonlinear pinhole model; even for a very small Gaussian + // there can be a tiny second-order mean shift. Keep a slightly relaxed tolerance here. 
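+    // Back-of-envelope estimate (second-order Taylor of u = fx*x/z): the UT mean shift is on the
+    // order of fx * x * sigma_z^2 / z^3 ~= 123 * 1 * 1e-12 / 1e3 ~= 1e-13 px here, so the 2e-3 px
+    // tolerance mostly absorbs float32 rounding rather than any genuine UT mean shift.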
+ EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 2e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 2e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, MultiCamera_RadTanDistortion_PerCameraParams) { + const int64_t C = 2; + + const float x = 0.2f, y = -0.1f, z = 2.0f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + // Different intrinsics per camera. + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + // Cam0 + const float fx0 = 500.0f, fy0 = 450.0f, cx0 = 400.0f, cy0 = 300.0f; + K[0][0][0] = fx0; + K[0][1][1] = fy0; + K[0][0][2] = cx0; + K[0][1][2] = cy0; + K[0][2][2] = 1.0f; + // Cam1 + const float fx1 = 300.0f, fy1 = 350.0f, cx1 = 320.0f, cy1 = 240.0f; + K[1][0][0] = fx1; + K[1][1][1] = fy1; + K[1][0][2] = cx1; + K[1][1][2] = cy1; + K[1][2][2] = 1.0f; + + // Different distortion coeffs per camera. + cameraModel = CameraModel::OPENCV_RADTAN_5; + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto dc = distortionCoeffs.accessor(); + // Cam0: non-trivial coefficients + const float k1_0 = 0.10f, k2_0 = -0.05f, k3_0 = 0.01f, p1_0 = 0.001f, p2_0 = -0.0005f; + dc[0][0] = k1_0; + dc[0][1] = k2_0; + dc[0][2] = k3_0; + dc[0][6] = p1_0; + dc[0][7] = p2_0; + // Cam1: different coefficients + const float k1_1 = -0.02f, k2_1 = 0.03f, k3_1 = 0.0f, p1_1 = -0.0008f, p2_1 = 0.0002f; + dc[1][0] = k1_1; + dc[1][1] = k2_1; + dc[1][2] = k3_1; + dc[1][6] = p1_1; + dc[1][7] = p2_1; + + imageWidth = 800; + imageHeight = 600; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + EXPECT_GT(radii_cpu[1][0].item(), 0); + + const auto [u0, v0] = projectPointWithOpenCVDistortion( + x, y, z, fx0, fy0, cx0, cy0, {k1_0, k2_0, k3_0}, {p1_0, p2_0}, {}); + const auto [u1, v1] = projectPointWithOpenCVDistortion( + x, y, z, fx1, fy1, cx1, cy1, {k1_1, k2_1, k3_1}, {p1_1, p2_1}, {}); + + EXPECT_NEAR(means2d_cpu[0][0][0].item(), u0, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), v0, 5e-3f); + EXPECT_NEAR(means2d_cpu[1][0][0].item(), u1, 5e-3f); + EXPECT_NEAR(means2d_cpu[1][0][1].item(), v1, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, MultiCamera_Pinhole_ZeroCoeffTensor_PerCameraIntrinsics) { + const int64_t C = 2; + + // Put the mean on the optical axis so the UT mean should be exactly (cx,cy) 
per camera. + const float z = 5.0f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{0.2f, 0.3f, 0.4f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + // Different intrinsics per camera. + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto K = projectionMatrices.accessor(); + // Cam0 + const float fx0 = 100.0f, fy0 = 200.0f, cx0 = 111.0f, cy0 = 222.0f; + K[0][0][0] = fx0; + K[0][1][1] = fy0; + K[0][0][2] = cx0; + K[0][1][2] = cy0; + K[0][2][2] = 1.0f; + // Cam1 + const float fx1 = 300.0f, fy1 = 400.0f, cx1 = 333.0f, cy1 = 444.0f; + K[1][0][0] = fx1; + K[1][1][1] = fy1; + K[1][0][2] = cx1; + K[1][1][2] = cy1; + K[1][2][2] = 1.0f; + + // Pinhole projection should ignore distortionCoeffs. Use a [C,0] tensor to exercise the + // mNumDistortionCoeffs==0 shared-memory path. + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 800; + imageHeight = 600; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + EXPECT_GT(radii_cpu[1][0].item(), 0); + + // UT projects sigma points through a nonlinear model; tiny mean shifts can occur due to + // floating point and second-order effects. Keep a slightly relaxed tolerance. 
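+    // With the mean on the optical axis, the +/-X and +/-Y sigma points project symmetrically
+    // about (cx, cy) and the +/-Z points project exactly to (cx, cy), so in exact arithmetic the
+    // UT mean is (cx, cy) per camera; the tolerance only needs to cover float32 rounding.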
+ EXPECT_NEAR(means2d_cpu[0][0][0].item(), cx0, 3e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), cy0, 3e-3f); + EXPECT_NEAR(means2d_cpu[1][0][0].item(), cx1, 3e-3f); + EXPECT_NEAR(means2d_cpu[1][0][1].item(), cy1, 3e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, + OffAxisTinyGaussian_RadTanDistortion_MeanMatchesOpenCVPoint) { + const int64_t C = 1; + + const float x = 0.2f, y = -0.1f, z = 2.0f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 500.0f, fy = 450.0f, cx = 400.0f, cy = 300.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + // coefficients chosen to be non-trivial but not extreme + const float k1 = 0.10f; + const float k2 = -0.05f; + const float k3 = 0.01f; + const float p1 = 0.001f; + const float p2 = -0.0005f; + + cameraModel = CameraModel::OPENCV_RADTAN_5; + // [k1..k6,p1,p2,s1..s4] (use polynomial by setting k4..k6 = 0, and no thin-prism by zeroing s*) + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto distortionCoeffsAcc = distortionCoeffs.accessor(); + distortionCoeffsAcc[0][0] = k1; + distortionCoeffsAcc[0][1] = k2; + distortionCoeffsAcc[0][2] = k3; + distortionCoeffsAcc[0][6] = p1; + distortionCoeffsAcc[0][7] = p2; + + imageWidth = 800; + imageHeight = 600; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + + const auto [expected_u, expected_v] = + projectPointWithOpenCVDistortion(x, y, z, fx, fy, cx, cy, {k1, k2, k3}, {p1, p2}, {}); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, + OffAxisTinyGaussian_RationalDistortion_MeanMatchesOpenCVPoint) { + const int64_t C = 1; + + const float x = -0.15f, y = 0.12f, z = 3.0f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + 
+ worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 600.0f, fy = 550.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + const float k1 = 0.08f; + const float k2 = -0.02f; + const float k3 = 0.005f; + const float k4 = 0.01f; + const float k5 = -0.004f; + const float k6 = 0.001f; + const float p1 = -0.0007f; + const float p2 = 0.0003f; + + cameraModel = CameraModel::OPENCV_RATIONAL_8; + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto distortionCoeffsAcc = distortionCoeffs.accessor(); + distortionCoeffsAcc[0][0] = k1; + distortionCoeffsAcc[0][1] = k2; + distortionCoeffsAcc[0][2] = k3; + distortionCoeffsAcc[0][3] = k4; + distortionCoeffsAcc[0][4] = k5; + distortionCoeffsAcc[0][5] = k6; + distortionCoeffsAcc[0][6] = p1; + distortionCoeffsAcc[0][7] = p2; + + imageWidth = 800; + imageHeight = 600; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + + const auto [expected_u, expected_v] = projectPointWithOpenCVDistortion( + x, y, z, fx, fy, cx, cy, {k1, k2, k3, k4, k5, k6}, {p1, p2}, {}); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, + OffAxisTinyGaussian_ThinPrismDistortion_MeanMatchesOpenCVPoint) { + const int64_t C = 1; + + const float x = 0.1f, y = 0.08f, z = 2.5f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 700.0f, fy = 650.0f, cx = 500.0f, cy = 400.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + 
projectionMatricesAcc[0][2][2] = 1.0f; + + const float k1 = 0.05f; + const float k2 = -0.01f; + const float k3 = 0.002f; + const float k4 = 0.005f; + const float k5 = -0.001f; + const float k6 = 0.0005f; + const float p1 = 0.0002f; + const float p2 = -0.0004f; + const float s1 = 0.0008f; + const float s2 = -0.0003f; + const float s3 = 0.0005f; + const float s4 = 0.0001f; + + cameraModel = CameraModel::OPENCV_THIN_PRISM_12; + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto distortionCoeffsAcc = distortionCoeffs.accessor(); + distortionCoeffsAcc[0][0] = k1; + distortionCoeffsAcc[0][1] = k2; + distortionCoeffsAcc[0][2] = k3; + distortionCoeffsAcc[0][3] = k4; + distortionCoeffsAcc[0][4] = k5; + distortionCoeffsAcc[0][5] = k6; + distortionCoeffsAcc[0][6] = p1; + distortionCoeffsAcc[0][7] = p2; + distortionCoeffsAcc[0][8] = s1; + distortionCoeffsAcc[0][9] = s2; + distortionCoeffsAcc[0][10] = s3; + distortionCoeffsAcc[0][11] = s4; + + imageWidth = 1200; + imageHeight = 900; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + + const auto [expected_u, expected_v] = projectPointWithOpenCVDistortion( + x, y, z, fx, fy, cx, cy, {k1, k2, k3, k4, k5, k6}, {p1, p2}, {s1, s2, s3, s4}); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, + OffAxisTinyGaussian_RadTanThinPrismDistortion_MeanMatchesOpenCVPoint) { + const int64_t C = 1; + + const float x = 0.07f, y = -0.11f, z = 2.2f; + means = torch::tensor({{x, y, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 620.0f, fy = 590.0f, cx = 410.0f, cy = 305.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + const float k1 = 0.06f; + const float k2 = -0.015f; + const float k3 = 0.003f; + const float p1 = 0.0004f; + const float p2 = -0.0002f; + const float s1 = 0.0007f; + const float s2 = -0.0003f; + const float s3 = 0.0005f; + const float s4 = 0.0001f; + 
+ cameraModel = CameraModel::OPENCV_RADTAN_THIN_PRISM_9; + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto distortionCoeffsAcc = distortionCoeffs.accessor(); + distortionCoeffsAcc[0][0] = k1; + distortionCoeffsAcc[0][1] = k2; + distortionCoeffsAcc[0][2] = k3; + // k4..k6 must be 0 for this explicit model + distortionCoeffsAcc[0][6] = p1; + distortionCoeffsAcc[0][7] = p2; + distortionCoeffsAcc[0][8] = s1; + distortionCoeffsAcc[0][9] = s2; + distortionCoeffsAcc[0][10] = s3; + distortionCoeffsAcc[0][11] = s4; + + imageWidth = 900; + imageHeight = 700; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + + const auto [expected_u, expected_v] = projectPointWithOpenCVDistortion( + x, y, z, fx, fy, cx, cy, {k1, k2, k3}, {p1, p2}, {s1, s2, s3, s4}); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, RadTanThinPrism_IgnoresK456EvenIfNonZero) { + const int64_t C = 1; + + means = torch::tensor({{0.1f, 0.05f, 2.0f}}, torch::kFloat32).cuda(); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32).cuda(); + logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)).cuda(); + + worldToCamMatricesStart = torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)) + .unsqueeze(0) + .expand({C, 4, 4}) + .cuda(); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + // NOTE: `.accessor<>()` is a host-side view; only use it on CPU tensors, then move to CUDA. 
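+    // (Device-side element access would instead go through packed_accessor32/64, as the kernels
+    // themselves do; here we simply fill the tensors on the CPU and then move them with .cuda().)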
+ projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = 500.0f; + projectionMatricesAcc[0][1][1] = 500.0f; + projectionMatricesAcc[0][0][2] = 320.0f; + projectionMatricesAcc[0][1][2] = 240.0f; + projectionMatricesAcc[0][2][2] = 1.0f; + projectionMatrices = projectionMatrices.cuda(); + + cameraModel = CameraModel::OPENCV_RADTAN_THIN_PRISM_9; + distortionCoeffs = torch::zeros({C, 12}, torch::kFloat32); + auto distortionCoeffsAcc = distortionCoeffs.accessor(); + distortionCoeffsAcc[0][0] = 0.01f; // k1 + distortionCoeffsAcc[0][3] = 0.1f; // k4: should be ignored for RADTAN_THIN_PRISM_9 + distortionCoeffsAcc[0][8] = 0.001f; // s1 + distortionCoeffs = distortionCoeffs.cuda(); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + auto means2d_cpu = means2d.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); + + // Verify we effectively use polynomial radial + thin-prism (ignore k4..k6). + const float fx = 500.0f, fy = 500.0f, cx = 320.0f, cy = 240.0f; + const float x = 0.1f, y = 0.05f, z = 2.0f; + const auto [expected_u, expected_v] = projectPointWithOpenCVDistortion( + x, y, z, fx, fy, cx, cy, {0.01f, 0.0f, 0.0f}, {0.0f, 0.0f}, {0.001f, 0.0f, 0.0f, 0.0f}); + EXPECT_NEAR(means2d_cpu[0][0][0].item(), expected_u, 5e-3f); + EXPECT_NEAR(means2d_cpu[0][0][1].item(), expected_v, 5e-3f); +} + +TEST_F(GaussianProjectionUTTestFixture, + SigmaPointBehindCamera_HardRejectsEvenWhenNotRequiringAllInImage) { + const int64_t C = 1; + + // Place the Gaussian mean just in front of the camera, but give it a large Z scale so one + // UT sigma point crosses behind the camera (z <= 0). The UT kernel should hard-reject such + // Gaussians (new behavior), regardless of requireAllSigmaPointsInImage. + const float z = 0.20f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + + // With UT alpha=0.1, gamma ~= sqrt(0.03) ~= 0.173. Choose sz so z - gamma*sz <= 0. 
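+    // Concretely: gamma * sz ~= 0.173 * 2.0 ~= 0.35 > z = 0.20, so the -Z sigma point lands at
+    // z ~= 0.20 - 0.35 = -0.15, i.e. behind the camera.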
+ const float sx = 1e-3f, sy = 1e-3f, sz = 2.0f; + logScales = torch::log(torch::tensor({{sx, sy, sz}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 500.0f, fy = 500.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.05f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams.alpha = 0.1f; + utParams.beta = 2.0f; + utParams.kappa = 0.0f; + utParams.inImageMargin = 0.1f; + utParams.requireAllSigmaPointsInImage = false; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + // When the UT kernel discards a Gaussian, only radii are defined to be 0; other outputs are + // undefined (may contain garbage). Only assert radii here. + auto radii_cpu = radii.cpu(); + EXPECT_EQ(radii_cpu[0][0].item(), 0); +} + +TEST_F(GaussianProjectionUTTestFixture, + SomeSigmaPointsOutOfBoundsButInFront_NotHardRejectedWhenNotRequiringAllInImage) { + const int64_t C = 1; + + // Centered mean in front of camera; huge X scale so +/-X sigma points project outside the + // image, but remain in front of the camera. With requireAllSigmaPointsInImage=false this should + // still produce a valid Gaussian (radii > 0). + const float z = 5.0f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + + // Choose sx large enough that projected sigma points fall outside image+margin. 
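+    // With sx = 30 and gamma ~= 0.173 (default alpha = 0.1), the +/-X sigma points sit ~5.2 units
+    // off-axis at z = 5 and project to |u - cx| ~= 500 * 5.2 / 5 ~= 520 px, far outside the
+    // 640-px-wide image even with a generous in-image margin.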
+ const float sx = 30.0f, sy = 1e-3f, sz = 1e-3f; + logScales = torch::log(torch::tensor({{sx, sy, sz}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + + const float fx = 500.0f, fy = 500.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = false; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto radii_cpu = radii.cpu(); + EXPECT_GT(radii_cpu[0][0].item(), 0); +} + +TEST_F(GaussianProjectionUTTestFixture, RollingShutterNone_DepthUsesStartPoseNotCenter) { + const int64_t C = 1; + + // If RollingShutterType::NONE, projection uses the start pose. Depth culling and outDepths + // should therefore also use the start pose. This test ensures we don't accidentally use the + // center pose (t=0.5) when start/end differ. + const float z = 5.0f; + means = torch::tensor({{0.0f, 0.0f, z}}, torch::kFloat32); + quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32); + logScales = torch::log(torch::tensor({{0.2f, 0.2f, 0.2f}}, torch::kFloat32)); + + worldToCamMatricesStart = + torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4}); + worldToCamMatricesEnd = worldToCamMatricesStart.clone(); + // End pose translates camera forward in +z, so p_cam.z is larger at t=1.0. 
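+    // If the kernel (incorrectly) used the end pose the depth would be z + 1 = 6, and a mid-frame
+    // (t = 0.5) pose would give z + 0.5 = 5.5; the assertion at the bottom expects exactly z = 5.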
+ auto worldToCamEndAcc = worldToCamMatricesEnd.accessor(); + worldToCamEndAcc[0][2][3] = 1.0f; + + const float fx = 100.0f, fy = 100.0f, cx = 320.0f, cy = 240.0f; + projectionMatrices = torch::zeros({C, 3, 3}, torch::TensorOptions().dtype(torch::kFloat32)); + auto projectionMatricesAcc = projectionMatrices.accessor(); + projectionMatricesAcc[0][0][0] = fx; + projectionMatricesAcc[0][1][1] = fy; + projectionMatricesAcc[0][0][2] = cx; + projectionMatricesAcc[0][1][2] = cy; + projectionMatricesAcc[0][2][2] = 1.0f; + + cameraModel = CameraModel::PINHOLE; + distortionCoeffs = torch::zeros({C, 0}, torch::kFloat32); + + imageWidth = 640; + imageHeight = 480; + eps2d = 0.3f; + nearPlane = 0.1f; + farPlane = 100.0f; + minRadius2d = 0.0f; + + utParams = UTParams{}; + utParams.requireAllSigmaPointsInImage = true; + + means = means.cuda(); + quats = quats.cuda(); + logScales = logScales.cuda(); + worldToCamMatricesStart = worldToCamMatricesStart.cuda(); + worldToCamMatricesEnd = worldToCamMatricesEnd.cuda(); + projectionMatrices = projectionMatrices.cuda(); + distortionCoeffs = distortionCoeffs.cuda(); + + const auto [radii, means2d, depths, conics, compensations] = + dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); + + auto depths_cpu = depths.cpu(); + // Start pose is identity, so depth should be exactly z (not z + 0.5). + EXPECT_NEAR(depths_cpu[0][0].item(), z, 1e-4f); +} + +} // namespace fvdb::detail::ops diff --git a/src/tests/GaussianUtilsTest.cu b/src/tests/GaussianUtilsTest.cu new file mode 100644 index 00000000..cd847dcd --- /dev/null +++ b/src/tests/GaussianUtilsTest.cu @@ -0,0 +1,305 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +// This is a `.cu` test, compiled by NVCC in (at least) two passes: +// - host pass: we use `std::{sqrt,sin,cos}` for the reference helpers (needs ``) +// - device pass: we use `sqrtf/sinf/cosf` intrinsics instead (so avoid pulling in `` there) +#if !defined(__CUDA_ARCH__) +#include +#endif + +namespace fvdb::detail::ops { +namespace { + +using Mat3f = nanovdb::math::Mat3; +using Vec4f = nanovdb::math::Vec4; +using Vec3f = nanovdb::math::Vec3; + +// Minimal math helpers: CUDA intrinsics on device, `std::` on host. 
+__host__ __device__ inline float +mySqrt(float x) { +#if defined(__CUDA_ARCH__) + return sqrtf(x); +#else + return std::sqrt(x); +#endif +} + +__host__ __device__ inline float +mySin(float x) { +#if defined(__CUDA_ARCH__) + return sinf(x); +#else + return std::sin(x); +#endif +} + +__host__ __device__ inline float +myCos(float x) { +#if defined(__CUDA_ARCH__) + return cosf(x); +#else + return std::cos(x); +#endif +} + +inline Mat3f +quatToRotationMatrixHost(const Vec4f &q_wxyz) { + // Normalize the quaternion + float w = q_wxyz[0], x = q_wxyz[1], y = q_wxyz[2], z = q_wxyz[3]; + const float n2 = w * w + x * x + y * y + z * z; + if (n2 > 0.0f) { + const float invN = 1.0f / mySqrt(n2); + w *= invN; + x *= invN; + y *= invN; + z *= invN; + } else { + w = 1.0f; + x = y = z = 0.0f; + } + + const float x2 = x * x, y2 = y * y, z2 = z * z; + const float xy = x * y, xz = x * z, yz = y * z; + const float wx = w * x, wy = w * y, wz = w * z; + + return Mat3f((1.0f - 2.0f * (y2 + z2)), + (2.0f * (xy - wz)), + (2.0f * (xz + wy)), // 1st row + (2.0f * (xy + wz)), + (1.0f - 2.0f * (x2 + z2)), + (2.0f * (yz - wx)), // 2nd row + (2.0f * (xz - wy)), + (2.0f * (yz + wx)), + (1.0f - 2.0f * (x2 + y2)) // 3rd row + ); +} + +inline void +expectMatNear(const Mat3f &A, const Mat3f &B, float tol) { + for (int r = 0; r < 3; ++r) { + for (int c = 0; c < 3; ++c) { + EXPECT_NEAR(A[r][c], B[r][c], tol) << "Mismatch at (" << r << "," << c << ")"; + } + } +} + +inline Vec4f +axisAngleToQuatWxyz(float ax, float ay, float az, float angleRad) { + const float n = mySqrt(ax * ax + ay * ay + az * az); + if (n <= 0.0f) + return Vec4f(1.0f, 0.0f, 0.0f, 0.0f); + ax /= n; + ay /= n; + az /= n; + + const float half = 0.5f * angleRad; + const float s = mySin(half); + const float c = myCos(half); + return Vec4f(c, s * ax, s * ay, s * az); +} + +inline void +expectVecNear(const Vec3f &a, const Vec3f &b, float tol) { + EXPECT_NEAR(a[0], b[0], tol); + EXPECT_NEAR(a[1], b[1], tol); + EXPECT_NEAR(a[2], b[2], tol); +} + +inline void +expectQuatNear(const Vec4f &a, const Vec4f &b, float tol) { + EXPECT_NEAR(a[0], b[0], tol); + EXPECT_NEAR(a[1], b[1], tol); + EXPECT_NEAR(a[2], b[2], tol); + EXPECT_NEAR(a[3], b[3], tol); +} + +inline Vec4f +normalizeQuat(Vec4f q) { + const float n2 = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]; + if (n2 > 0.0f) { + const float invN = 1.0f / mySqrt(n2); + q[0] *= invN; + q[1] *= invN; + q[2] *= invN; + q[3] *= invN; + } else { + q[0] = 1.0f; + q[1] = q[2] = q[3] = 0.0f; + } + return q; +} + +inline float +quatDot(const Vec4f &a, const Vec4f &b) { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]; +} + +inline Vec4f +nlerpRefShortestPath(const Vec4f &q0, Vec4f q1, float u) { + float dot = quatDot(q0, q1); + if (dot < 0.0f) { + q1[0] = -q1[0]; + q1[1] = -q1[1]; + q1[2] = -q1[2]; + q1[3] = -q1[3]; + dot = -dot; + } + (void)dot; // suppress unused variable warning + const float s = 1.0f - u; + return normalizeQuat(Vec4f(s * q0[0] + u * q1[0], + s * q0[1] + u * q1[1], + s * q0[2] + u * q1[2], + s * q0[3] + u * q1[3])); +} + +} // namespace + +TEST(GaussianUtilsTest, RotationMatrixToQuaternion_Identity) { + const Mat3f R(1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f); + const Vec4f q = rotationMatrixToQuaternion(R); + + EXPECT_NEAR(q[0], 1.0f, 1e-6f); + EXPECT_NEAR(q[1], 0.0f, 1e-6f); + EXPECT_NEAR(q[2], 0.0f, 1e-6f); + EXPECT_NEAR(q[3], 0.0f, 1e-6f); + + const float n = mySqrt(q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]); + EXPECT_NEAR(n, 1.0f, 1e-6f); + EXPECT_GE(q[0], 0.0f); 
// sign convention +} + +TEST(GaussianUtilsTest, RotationMatrixToQuaternion_RoundTrip_KnownAxes) { + const float pi = 3.14159265358979323846f; + const Vec4f qs[] = { + axisAngleToQuatWxyz(1.0f, 0.0f, 0.0f, 0.5f * pi), // +90° about X + axisAngleToQuatWxyz(0.0f, 1.0f, 0.0f, 0.5f * pi), // +90° about Y + axisAngleToQuatWxyz(0.0f, 0.0f, 1.0f, 0.5f * pi), // +90° about Z + axisAngleToQuatWxyz(0.0f, 1.0f, 0.0f, 1.0f * pi), // 180° about Y (w=0 edge case) + }; + + for (const auto &q_in: qs) { + const Mat3f R_in = quatToRotationMatrixHost(q_in); + const Vec4f q_out = rotationMatrixToQuaternion(R_in); + const Mat3f R_out = quatToRotationMatrixHost(q_out); + + expectMatNear(R_in, R_out, 2e-5f); + EXPECT_GE(q_out[0], 0.0f); + } +} + +TEST(GaussianUtilsTest, RotationMatrixToQuaternion_ProducesPositiveWForEquivalentRotation) { + const float pi = 3.14159265358979323846f; + const Vec4f q = axisAngleToQuatWxyz(0.0f, 0.0f, 1.0f, pi / 3.0f); + const Vec4f q_neg = Vec4f(-q[0], -q[1], -q[2], -q[3]); + + const Mat3f R = quatToRotationMatrixHost(q_neg); + const Vec4f q_out = rotationMatrixToQuaternion(R); + EXPECT_GE(q_out[0], 0.0f); + + const Mat3f R_out = quatToRotationMatrixHost(q_out); + expectMatNear(R, R_out, 2e-5f); +} + +TEST(GaussianUtilsTest, RotationMatrixToQuaternion_RoundTrip_DeterministicSamples) { + const float pi = 3.14159265358979323846f; + const struct Sample { + float ax, ay, az, ang; + } samples[] = { + {0.3f, 0.7f, -0.2f, 0.1f * pi}, + {-0.8f, 0.1f, 0.5f, 0.25f * pi}, + {0.1f, -0.2f, 0.9f, 0.8f * pi}, + {0.9f, 0.4f, 0.1f, 1.1f * pi}, + {-0.4f, -0.6f, 0.2f, 0.6f * pi}, + {0.2f, -0.9f, -0.3f, 1.7f * pi}, + {-0.1f, 0.5f, 0.8f, 0.33f * pi}, + {0.6f, -0.3f, 0.7f, 1.9f * pi}, + }; + + for (const auto &s: samples) { + const Vec4f q_in = axisAngleToQuatWxyz(s.ax, s.ay, s.az, s.ang); + const Mat3f R_in = quatToRotationMatrixHost(q_in); + const Vec4f q_out = rotationMatrixToQuaternion(R_in); + const Mat3f R_out = quatToRotationMatrixHost(q_out); + + expectMatNear(R_in, R_out, 2e-5f); + EXPECT_GE(q_out[0], 0.0f); + } +} + +TEST(GaussianUtilsTest, RotationMatrixToQuaternion_DegenerateInput_ReturnsFiniteIdentity) { + // Construct a deliberately degenerate/NaN matrix which previously could trigger s=0 with a + // finite numerator, producing inf/NaN and bypassing the "degenerate input" fallback. +#if !defined(__CUDA_ARCH__) + const float nan = std::nanf(""); +#else + const float nan = nanf(""); +#endif + + // Force comparisons/trace paths to be ill-defined (NaN), but keep some off-diagonals finite. + // This makes `t` clamp to 0 -> s=0, while (R10-R01) is finite and non-zero. + const Mat3f R(nan, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f); + + const Vec4f q = rotationMatrixToQuaternion(R); + +#if !defined(__CUDA_ARCH__) + EXPECT_TRUE(std::isfinite(q[0])); + EXPECT_TRUE(std::isfinite(q[1])); + EXPECT_TRUE(std::isfinite(q[2])); + EXPECT_TRUE(std::isfinite(q[3])); +#endif + + // We choose identity as the explicit safe fallback for degenerate inputs. 
+    EXPECT_NEAR(q[0], 1.0f, 1e-6f);
+    EXPECT_NEAR(q[1], 0.0f, 1e-6f);
+    EXPECT_NEAR(q[2], 0.0f, 1e-6f);
+    EXPECT_NEAR(q[3], 0.0f, 1e-6f);
+}
+
+TEST(GaussianUtilsTest, NlerpQuaternionShortestPath_MatchesReference) {
+    const float pi = 3.14159265358979323846f;
+    const Vec4f q_start = axisAngleToQuatWxyz(1.0f, 0.0f, 0.0f, pi / 3.0f);        // 60deg about X
+    const Vec4f q_end = axisAngleToQuatWxyz(0.0f, 1.0f, 0.0f, 2.0f * pi / 3.0f);   // 120deg about Y
+
+    const Mat3f R_start = quatToRotationMatrixHost(q_start);
+    const Mat3f R_end = quatToRotationMatrixHost(q_end);
+
+    const float u = 0.25f;
+
+    const Vec4f q0 = rotationMatrixToQuaternion(R_start);
+    const Vec4f q1 = rotationMatrixToQuaternion(R_end);
+    const Vec4f q_ref = nlerpRefShortestPath(q0, q1, u);
+    const Vec4f q_interp = nlerpQuaternionShortestPath(q0, q1, u);
+
+    expectQuatNear(q_interp, q_ref, 2e-6f);
+}
+
+TEST(GaussianUtilsTest, PoseRt_WorldToCamAndBack_RoundTrip) {
+    const float pi = 3.14159265358979323846f;
+
+    // Non-identity rotation + non-zero translation to catch ordering bugs.
+    const Vec4f q = axisAngleToQuatWxyz(0.2f, 0.9f, -0.4f, 0.37f * pi);
+    const Vec3f t(1.25f, -2.5f, 0.75f);
+    const Mat3f R = quatToRotationMatrixHost(q);
+
+    const Vec3f p_world(0.3f, -1.1f, 2.7f);
+    const Vec3f p_cam = R * p_world + t;
+    const Vec3f p_world_rt = R.transpose() * (p_cam - t);
+    expectVecNear(p_world_rt, p_world, 2e-5f);
+
+    // Also verify the opposite direction for completeness.
+    const Vec3f p_cam_in(-0.2f, 0.4f, 1.8f);
+    const Vec3f p_world_from_cam = R.transpose() * (p_cam_in - t);
+    const Vec3f p_cam_rt = R * p_world_from_cam + t;
+    expectVecNear(p_cam_rt, p_cam_in, 2e-5f);
+}
+
+} // namespace fvdb::detail::ops
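
For reference, the w >= 0 sign convention and the degenerate-input fallback exercised by the tests above can be illustrated with a small host-side sketch of a trace-based (Shepperd-style) matrix-to-quaternion conversion. This is only a sketch, not the library's rotationMatrixToQuaternion (which may branch differently); the name rotationMatrixToQuaternionSketch is hypothetical, and the sketch reuses the Vec4f/Mat3f aliases and the normalizeQuat helper defined in the test file.

// Hypothetical sketch, not the library's implementation: trace-based
// (Shepperd-style) conversion with the w >= 0 sign convention and an
// identity fallback for degenerate input. Assumes the Vec4f/Mat3f aliases
// and normalizeQuat from the test file above.
#include <cmath>

inline Vec4f
rotationMatrixToQuaternionSketch(const Mat3f &R) {
    const float trace = R[0][0] + R[1][1] + R[2][2];
    float w, x, y, z;
    if (trace > 0.0f) {
        const float s = 2.0f * std::sqrt(trace + 1.0f); // s = 4w
        w = 0.25f * s;
        x = (R[2][1] - R[1][2]) / s;
        y = (R[0][2] - R[2][0]) / s;
        z = (R[1][0] - R[0][1]) / s;
    } else if (R[0][0] > R[1][1] && R[0][0] > R[2][2]) {
        const float s = 2.0f * std::sqrt(1.0f + R[0][0] - R[1][1] - R[2][2]); // s = 4x
        w = (R[2][1] - R[1][2]) / s;
        x = 0.25f * s;
        y = (R[0][1] + R[1][0]) / s;
        z = (R[0][2] + R[2][0]) / s;
    } else if (R[1][1] > R[2][2]) {
        const float s = 2.0f * std::sqrt(1.0f + R[1][1] - R[0][0] - R[2][2]); // s = 4y
        w = (R[0][2] - R[2][0]) / s;
        x = (R[0][1] + R[1][0]) / s;
        y = 0.25f * s;
        z = (R[1][2] + R[2][1]) / s;
    } else {
        const float s = 2.0f * std::sqrt(1.0f + R[2][2] - R[0][0] - R[1][1]); // s = 4z
        w = (R[1][0] - R[0][1]) / s;
        x = (R[0][2] + R[2][0]) / s;
        y = (R[1][2] + R[2][1]) / s;
        z = 0.25f * s;
    }
    // Degenerate input (NaN entries, non-orthonormal R) can leave components
    // non-finite; fall back to the identity quaternion in that case.
    if (!(std::isfinite(w) && std::isfinite(x) && std::isfinite(y) && std::isfinite(z))) {
        return Vec4f(1.0f, 0.0f, 0.0f, 0.0f);
    }
    // q and -q encode the same rotation; pick the representative with w >= 0.
    if (w < 0.0f) {
        w = -w;
        x = -x;
        y = -y;
        z = -z;
    }
    return normalizeQuat(Vec4f(w, x, y, z));
}

With the NaN matrix used in RotationMatrixToQuaternion_DegenerateInput_ReturnsFiniteIdentity, every branch comparison evaluates to false, so in this sketch the final branch runs with a NaN scale factor; the isfinite guard is what turns that into the identity quaternion instead of propagating inf/NaN.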