Skip to content

Commit ba96359

Browse files
committed
camera model
Signed-off-by: Francis Williams <francis@fwilliams.info>
1 parent f8359b7 commit ba96359

File tree

3 files changed

+64
-65
lines changed

3 files changed

+64
-65
lines changed

src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ template <typename T> class OpenCVCameraModel {
4040
using Mat3 = nanovdb::math::Mat3<T>;
4141

4242
// Internal "math mode" for OpenCV distortion evaluation.
43-
// Kept inside the struct to avoid confusion with the public API `DistortionModel`.
43+
// Kept inside the class to avoid confusion with the public API `CameraModel`.
4444
enum class Model : uint8_t {
4545
NONE = 0, // no distortion
4646
RADTAN_5 = 5, // k1,k2,p1,p2,k3 (polynomial radial up to r^6)
@@ -49,20 +49,18 @@ template <typename T> class OpenCVCameraModel {
4949
THIN_PRISM_12 = 12, // RATIONAL_8 + s1,s2,s3,s4
5050
};
5151

52-
// Construct this camera model from the public DistortionModel enum.
52+
// Construct this camera model from the public CameraModel enum.
5353
//
5454
// For OpenCV models, coefficients are read from a per-camera packed layout:
5555
// [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]
5656
//
5757
// Preconditions are asserted on-device (trap) rather than returning false.
5858
__device__ __forceinline__
59-
OpenCVCameraModel(const DistortionModel distortionModel,
60-
const Mat3 *K_in,
61-
const T *distortionCoeffs)
59+
OpenCVCameraModel(const CameraModel cameraModel, const Mat3 *K_in, const T *distortionCoeffs)
6260
: K(K_in) {
6361
deviceAssertOrTrap(K != nullptr);
6462

65-
if (distortionModel == DistortionModel::NONE) {
63+
if (cameraModel == CameraModel::PINHOLE) {
6664
radial = tangential = thinPrism = nullptr;
6765
numRadial = numTangential = numThinPrism = 0;
6866
model = Model::NONE;
@@ -75,7 +73,7 @@ template <typename T> class OpenCVCameraModel {
7573
const T *tang_in = distortionCoeffs + 6; // p1,p2
7674
const T *thin_in = distortionCoeffs + 8; // s1..s4
7775

78-
if (distortionModel == DistortionModel::OPENCV_RADTAN_5) {
76+
if (cameraModel == CameraModel::OPENCV_RADTAN_5) {
7977
radial = radial_in;
8078
numRadial = 3; // k1,k2,k3 (polynomial)
8179
tangential = tang_in;
@@ -85,7 +83,7 @@ template <typename T> class OpenCVCameraModel {
8583
model = Model::RADTAN_5;
8684
return;
8785
}
88-
if (distortionModel == DistortionModel::OPENCV_RATIONAL_8) {
86+
if (cameraModel == CameraModel::OPENCV_RATIONAL_8) {
8987
radial = radial_in;
9088
numRadial = 6; // k1..k6 (rational)
9189
tangential = tang_in;
@@ -95,7 +93,7 @@ template <typename T> class OpenCVCameraModel {
9593
model = Model::RATIONAL_8;
9694
return;
9795
}
98-
if (distortionModel == DistortionModel::OPENCV_THIN_PRISM_12) {
96+
if (cameraModel == CameraModel::OPENCV_THIN_PRISM_12) {
9997
radial = radial_in;
10098
numRadial = 6; // k1..k6 (rational)
10199
tangential = tang_in;
@@ -105,7 +103,7 @@ template <typename T> class OpenCVCameraModel {
105103
model = Model::THIN_PRISM_12;
106104
return;
107105
}
108-
if (distortionModel == DistortionModel::OPENCV_RADTAN_THIN_PRISM_9) {
106+
if (cameraModel == CameraModel::OPENCV_RADTAN_THIN_PRISM_9) {
109107
// Polynomial radial + thin-prism; ignore k4..k6 by construction.
110108
radial = radial_in;
111109
numRadial = 3; // k1,k2,k3 (polynomial)
@@ -117,7 +115,7 @@ template <typename T> class OpenCVCameraModel {
117115
return;
118116
}
119117

120-
// Unknown distortion model: should be unreachable if host validation is correct.
118+
// Unknown camera model: should be unreachable if host validation is correct.
121119
deviceAssertOrTrap(false);
122120
}
123121

@@ -449,7 +447,7 @@ template <typename ScalarType> struct ProjectionForwardUT {
449447
const ScalarType mRadiusClip;
450448
const RollingShutterType mRollingShutterType;
451449
const UTParams mUTParams;
452-
const DistortionModel mDistortionModel;
450+
const CameraModel mCameraModel;
453451
const int64_t
454452
mNumDistortionCoeffs; // Number of distortion coeffs per camera (e.g. 12 for OPENCV)
455453

@@ -487,7 +485,7 @@ template <typename ScalarType> struct ProjectionForwardUT {
487485
const ScalarType minRadius2d,
488486
const RollingShutterType rollingShutterType,
489487
const UTParams &utParams,
490-
const DistortionModel distortionModel,
488+
const CameraModel cameraModel,
491489
const bool calcCompensations,
492490
const torch::Tensor &means, // [N, 3]
493491
const torch::Tensor &quats, // [N, 4]
@@ -506,7 +504,7 @@ template <typename ScalarType> struct ProjectionForwardUT {
506504
mImageWidth(static_cast<int32_t>(imageWidth)),
507505
mImageHeight(static_cast<int32_t>(imageHeight)), mEps2d(eps2d), mNearPlane(nearPlane),
508506
mFarPlane(farPlane), mRadiusClip(minRadius2d), mRollingShutterType(rollingShutterType),
509-
mUTParams(utParams), mDistortionModel(distortionModel),
507+
mUTParams(utParams), mCameraModel(cameraModel),
510508
mNumDistortionCoeffs(distortionCoeffs.size(1)),
511509
mMeansAcc(means.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>()),
512510
mQuatsAcc(quats.packed_accessor64<ScalarType, 2, torch::RestrictPtrTraits>()),
@@ -622,7 +620,7 @@ template <typename ScalarType> struct ProjectionForwardUT {
622620
const ScalarType *distortionCoeffs = &distortionCoeffsShared[camId * mNumDistortionCoeffs];
623621

624622
// Define the camera model (projection and distortion) using the shared memory pointers
625-
OpenCVCameraModel<ScalarType> camera(mDistortionModel, &projectionMatrix, distortionCoeffs);
623+
OpenCVCameraModel<ScalarType> camera(mCameraModel, &projectionMatrix, distortionCoeffs);
626624

627625
// Define the world-to-camera transforms using the shared memory pointers at the start and
628626
// end of the shutter period
@@ -797,7 +795,7 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
797795
const torch::Tensor &projectionMatrices, // [C, 3, 3]
798796
const RollingShutterType rollingShutterType,
799797
const UTParams &utParams,
800-
const DistortionModel distortionModel,
798+
const CameraModel cameraModel,
801799
const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV_*, [C,0] for PINHOLE
802800
const int64_t imageWidth,
803801
const int64_t imageHeight,
@@ -819,17 +817,17 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
819817
TORCH_CHECK_VALUE(projectionMatrices.is_cuda(), "projectionMatrices must be a CUDA tensor");
820818
TORCH_CHECK_VALUE(distortionCoeffs.is_cuda(), "distortionCoeffs must be a CUDA tensor");
821819
TORCH_CHECK_VALUE(distortionCoeffs.dim() == 2, "distortionCoeffs must be 2D");
822-
if (distortionModel == DistortionModel::NONE) {
820+
if (cameraModel == CameraModel::PINHOLE) {
823821
// Accept any number of coefficient columns K (including 0); the coefficients are ignored.
824-
} else if (distortionModel == DistortionModel::OPENCV_RADTAN_5 ||
825-
distortionModel == DistortionModel::OPENCV_RATIONAL_8 ||
826-
distortionModel == DistortionModel::OPENCV_RADTAN_THIN_PRISM_9 ||
827-
distortionModel == DistortionModel::OPENCV_THIN_PRISM_12) {
822+
} else if (cameraModel == CameraModel::OPENCV_RADTAN_5 ||
823+
cameraModel == CameraModel::OPENCV_RATIONAL_8 ||
824+
cameraModel == CameraModel::OPENCV_RADTAN_THIN_PRISM_9 ||
825+
cameraModel == CameraModel::OPENCV_THIN_PRISM_12) {
828826
TORCH_CHECK_VALUE(distortionCoeffs.size(1) == 12,
829-
"For DistortionModel::OPENCV_* , distortionCoeffs must have shape [C,12] "
827+
"For CameraModel::OPENCV_* , distortionCoeffs must have shape [C,12] "
830828
"as [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]");
831829
} else {
832-
TORCH_CHECK_VALUE(false, "Unknown DistortionModel for GaussianProjectionForwardUT");
830+
TORCH_CHECK_VALUE(false, "Unknown CameraModel for GaussianProjectionForwardUT");
833831
}
834832

835833
// This kernel implements only the canonical 3D UT with 2D+1 sigma points (7).
@@ -876,7 +874,7 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
876874
minRadius2d,
877875
rollingShutterType,
878876
utParams,
879-
distortionModel,
877+
cameraModel,
880878
calcCompensations,
881879
means,
882880
quats,
@@ -909,7 +907,7 @@ dispatchGaussianProjectionForwardUT<torch::kCPU>(
909907
const torch::Tensor &projectionMatrices, // [C, 3, 3]
910908
const RollingShutterType rollingShutterType,
911909
const UTParams &utParams,
912-
const DistortionModel distortionModel,
910+
const CameraModel cameraModel,
913911
const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV_*, [C,0] for PINHOLE
914912
const int64_t imageWidth,
915913
const int64_t imageHeight,
@@ -933,7 +931,7 @@ dispatchGaussianProjectionForwardUT<torch::kPrivateUse1>(
933931
const torch::Tensor &projectionMatrices, // [C, 3, 3]
934932
const RollingShutterType rollingShutterType,
935933
const UTParams &utParams,
936-
const DistortionModel distortionModel,
934+
const CameraModel cameraModel,
937935
const torch::Tensor &distortionCoeffs, // [C,12] for OPENCV_*, [C,0] for PINHOLE
938936
const int64_t imageWidth,
939937
const int64_t imageHeight,

src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,22 @@ namespace ops {
1515

1616
enum class RollingShutterType { NONE = 0, VERTICAL = 1, HORIZONTAL = 2 };
1717

18-
// Distortion model for camera projection in the UT kernel.
18+
// Camera model for projection in the UT kernel.
19+
//
20+
// Today, all supported camera models are pinhole intrinsics + optional OpenCV-style distortion.
1921
//
2022
// Distortion coefficients are supplied as a single tensor `distortionCoeffs` and interpreted
21-
// according to this enum.
22-
enum class DistortionModel : int32_t {
23-
NONE = 0,
23+
// according to this enum for the OpenCV variants.
24+
enum class CameraModel : int32_t {
25+
// Pinhole intrinsics only (no distortion).
26+
PINHOLE = 0,
2427

25-
// OpenCV variants (all use the same [C,12] coefficient layout):
26-
// [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]
27-
//
28-
// The enum exists mostly for clarity + runtime validation of coefficient usage.
29-
OPENCV_RADTAN_5 = 1, // polynomial radial (k1,k2,k3) + tangential (p1,p2)
30-
OPENCV_RATIONAL_8 = 2, // rational radial (k1..k6) + tangential (p1,p2)
31-
OPENCV_RADTAN_THIN_PRISM_9 = 3, // polynomial radial + tangential + thin-prism (s1..s4)
32-
OPENCV_THIN_PRISM_12 = 4, // rational radial + tangential + thin-prism (s1..s4)
28+
// OpenCV variants which are just pinhole intrinsics + optional distortion (all of them use the
29+
// same [C,12] distortion coefficients layout: [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]).
30+
OPENCV_RADTAN_5 = 1, // polynomial radial (k1,k2,k3) + tangential (p1,p2).
31+
OPENCV_RATIONAL_8 = 2, // rational radial (k1..k6) + tangential (p1,p2).
32+
OPENCV_RADTAN_THIN_PRISM_9 = 3, // polynomial radial + tangential + thin-prism (s1..s4).
33+
OPENCV_THIN_PRISM_12 = 4, // rational radial + tangential + thin-prism (s1..s4).
3334
};
3435

3536
struct UTParams {
@@ -79,10 +80,10 @@ struct UTParams {
7980
/// @param[in] projectionMatrices Camera intrinsic matrices [C, 3, 3]
8081
/// @param[in] rollingShutterType Type of rolling shutter effect to apply
8182
/// @param[in] utParams Unscented Transform parameters
82-
/// @param[in] distortionModel Distortion model used to interpret `distortionCoeffs`.
83+
/// @param[in] cameraModel Camera model used to interpret `distortionCoeffs`.
8384
/// @param[in] distortionCoeffs Distortion coefficients for each camera.
84-
/// - DistortionModel::NONE: ignored (use [C,0] or [C,K] tensor).
85-
/// - DistortionModel::OPENCV_*: expects [C,12] coefficients in the following order:
85+
/// - CameraModel::PINHOLE: ignored (use [C,0] or [C,K] tensor).
86+
/// - CameraModel::OPENCV_*: expects [C,12] coefficients in the following order:
8687
/// [k1,k2,k3,k4,k5,k6,p1,p2,s1,s2,s3,s4]
8788
/// where k1..k6 are radial (rational), p1,p2 are tangential, and s1..s4 are thin-prism.
8889
/// @param[in] imageWidth Width of the output image in pixels
@@ -112,7 +113,7 @@ dispatchGaussianProjectionForwardUT(
112113
const torch::Tensor &projectionMatrices, // [C, 3, 3]
113114
const RollingShutterType rollingShutterType,
114115
const UTParams &utParams,
115-
const DistortionModel distortionModel,
116+
const CameraModel cameraModel,
116117
const torch::Tensor &distortionCoeffs, // [C, 12] for OPENCV_*, or [C, 0] for PINHOLE
117118
const int64_t imageWidth,
118119
const int64_t imageHeight,

0 commit comments

Comments
 (0)