Skip to content

Commit 61d1352

Browse files
committed
logscales
Signed-off-by: Francis Williams <francis@fwilliams.info>
1 parent 9936c0d commit 61d1352

File tree

3 files changed

+31
-28
lines changed

3 files changed

+31
-28
lines changed

src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
676676
dispatchGaussianProjectionForwardUT<torch::kCUDA>(
677677
const torch::Tensor &means, // [N, 3]
678678
const torch::Tensor &quats, // [N, 4]
679-
const torch::Tensor &scales, // [N, 3]
679+
const torch::Tensor &logScales, // [N, 3]
680680
const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4]
681681
const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4]
682682
const torch::Tensor &projectionMatrices, // [C, 3, 3]
@@ -696,7 +696,7 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
696696

697697
TORCH_CHECK_VALUE(means.is_cuda(), "means must be a CUDA tensor");
698698
TORCH_CHECK_VALUE(quats.is_cuda(), "quats must be a CUDA tensor");
699-
TORCH_CHECK_VALUE(scales.is_cuda(), "scales must be a CUDA tensor");
699+
TORCH_CHECK_VALUE(logScales.is_cuda(), "logScales must be a CUDA tensor");
700700
TORCH_CHECK_VALUE(worldToCamMatricesStart.is_cuda(),
701701
"worldToCamMatricesStart must be a CUDA tensor");
702702
TORCH_CHECK_VALUE(worldToCamMatricesEnd.is_cuda(),
@@ -784,7 +784,7 @@ dispatchGaussianProjectionForwardUT<torch::kCUDA>(
784784
calcCompensations,
785785
means,
786786
quats,
787-
torch::log(scales),
787+
logScales,
788788
worldToCamMatricesStart,
789789
worldToCamMatricesEnd,
790790
projectionMatrices,
@@ -807,7 +807,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
807807
dispatchGaussianProjectionForwardUT<torch::kCPU>(
808808
const torch::Tensor &means, // [N, 3]
809809
const torch::Tensor &quats, // [N, 4]
810-
const torch::Tensor &scales, // [N, 3]
810+
const torch::Tensor &logScales, // [N, 3]
811811
const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4]
812812
const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4]
813813
const torch::Tensor &projectionMatrices, // [C, 3, 3]
@@ -823,6 +823,7 @@ dispatchGaussianProjectionForwardUT<torch::kCPU>(
823823
const float minRadius2d,
824824
const bool calcCompensations,
825825
const bool ortho) {
826+
(void)logScales;
826827
(void)distortionModel;
827828
(void)distortionCoeffs;
828829
TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianProjectionForwardUT not implemented on the CPU");
@@ -833,7 +834,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
833834
dispatchGaussianProjectionForwardUT<torch::kPrivateUse1>(
834835
const torch::Tensor &means, // [N, 3]
835836
const torch::Tensor &quats, // [N, 4]
836-
const torch::Tensor &scales, // [N, 3]
837+
const torch::Tensor &logScales, // [N, 3]
837838
const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4]
838839
const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4]
839840
const torch::Tensor &projectionMatrices, // [C, 3, 3]
@@ -849,6 +850,7 @@ dispatchGaussianProjectionForwardUT<torch::kPrivateUse1>(
849850
const float minRadius2d,
850851
const bool calcCompensations,
851852
const bool ortho) {
853+
(void)logScales;
852854
(void)distortionModel;
853855
(void)distortionCoeffs;
854856
TORCH_CHECK_NOT_IMPLEMENTED(false,

src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ struct UTParams {
6161
///
6262
/// @param[in] means 3D positions of Gaussians [N, 3] where N is number of Gaussians
6363
/// @param[in] quats Quaternion rotations of Gaussians [N, 4] in format (w, x, y, z)
64-
/// @param[in] scales Scale factors of Gaussians [N, 3] representing extent in each dimension
64+
/// @param[in] logScales Log-scale factors of Gaussians [N, 3] (natural log), representing extent in
65+
/// each dimension
6566
/// @param[in] worldToCamMatricesStart Camera view matrices at the start of the frame. Shape [C, 4,
6667
/// 4] where C is number of cameras
6768
/// @param[in] worldToCamMatricesEnd Camera view matrices at the end of the frame. Shape [C, 4, 4]
@@ -95,7 +96,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
9596
dispatchGaussianProjectionForwardUT(
9697
const torch::Tensor &means, // [N, 3]
9798
const torch::Tensor &quats, // [N, 4]
98-
const torch::Tensor &scales, // [N, 3]
99+
const torch::Tensor &logScales, // [N, 3]
99100
const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4]
100101
const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4]
101102
const torch::Tensor &projectionMatrices, // [C, 3, 3]

src/tests/GaussianProjectionUTTest.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ struct GaussianProjectionUTTestFixture : public ::testing::Test {
104104

105105
torch::Tensor means; // [N, 3]
106106
torch::Tensor quats; // [N, 4]
107-
torch::Tensor scales; // [N, 3]
107+
torch::Tensor logScales; // [N, 3]
108108
torch::Tensor worldToCamMatricesStart; // [C, 4, 4]
109109
torch::Tensor worldToCamMatricesEnd; // [C, 4, 4]
110110
torch::Tensor projectionMatrices; // [C, 3, 3]
@@ -132,7 +132,7 @@ TEST_F(GaussianProjectionUTTestFixture, CenteredGaussian_NoDistortion_AnalyticMe
132132
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
133133

134134
const float sx = 0.2f, sy = 0.3f, sz = 0.4f;
135-
scales = torch::tensor({{sx, sy, sz}}, torch::kFloat32);
135+
logScales = torch::log(torch::tensor({{sx, sy, sz}}, torch::kFloat32));
136136

137137
worldToCamMatricesStart =
138138
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -165,7 +165,7 @@ TEST_F(GaussianProjectionUTTestFixture, CenteredGaussian_NoDistortion_AnalyticMe
165165
// CUDA
166166
means = means.cuda();
167167
quats = quats.cuda();
168-
scales = scales.cuda();
168+
logScales = logScales.cuda();
169169
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
170170
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
171171
projectionMatrices = projectionMatrices.cuda();
@@ -174,7 +174,7 @@ TEST_F(GaussianProjectionUTTestFixture, CenteredGaussian_NoDistortion_AnalyticMe
174174
const auto [radii, means2d, depths, conics, compensations] =
175175
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
176176
quats,
177-
scales,
177+
logScales,
178178
worldToCamMatricesStart,
179179
worldToCamMatricesEnd,
180180
projectionMatrices,
@@ -226,7 +226,7 @@ TEST_F(GaussianProjectionUTTestFixture, OffAxisTinyGaussian_NoDistortion_MeanMat
226226
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
227227
// Extremely small Gaussian so UT mean should match the point projection closely
228228
// (off-axis + perspective nonlinearity can otherwise introduce a tiny UT mean shift).
229-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32);
229+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32));
230230

231231
worldToCamMatricesStart =
232232
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -258,7 +258,7 @@ TEST_F(GaussianProjectionUTTestFixture, OffAxisTinyGaussian_NoDistortion_MeanMat
258258

259259
means = means.cuda();
260260
quats = quats.cuda();
261-
scales = scales.cuda();
261+
logScales = logScales.cuda();
262262
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
263263
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
264264
projectionMatrices = projectionMatrices.cuda();
@@ -267,7 +267,7 @@ TEST_F(GaussianProjectionUTTestFixture, OffAxisTinyGaussian_NoDistortion_MeanMat
267267
const auto [radii, means2d, depths, conics, compensations] =
268268
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
269269
quats,
270-
scales,
270+
logScales,
271271
worldToCamMatricesStart,
272272
worldToCamMatricesEnd,
273273
projectionMatrices,
@@ -300,7 +300,7 @@ TEST_F(GaussianProjectionUTTestFixture,
300300
const float x = 0.2f, y = -0.1f, z = 2.0f;
301301
means = torch::tensor({{x, y, z}}, torch::kFloat32);
302302
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
303-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32);
303+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32));
304304

305305
worldToCamMatricesStart =
306306
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -345,7 +345,7 @@ TEST_F(GaussianProjectionUTTestFixture,
345345

346346
means = means.cuda();
347347
quats = quats.cuda();
348-
scales = scales.cuda();
348+
logScales = logScales.cuda();
349349
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
350350
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
351351
projectionMatrices = projectionMatrices.cuda();
@@ -354,7 +354,7 @@ TEST_F(GaussianProjectionUTTestFixture,
354354
const auto [radii, means2d, depths, conics, compensations] =
355355
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
356356
quats,
357-
scales,
357+
logScales,
358358
worldToCamMatricesStart,
359359
worldToCamMatricesEnd,
360360
projectionMatrices,
@@ -388,7 +388,7 @@ TEST_F(GaussianProjectionUTTestFixture,
388388
const float x = -0.15f, y = 0.12f, z = 3.0f;
389389
means = torch::tensor({{x, y, z}}, torch::kFloat32);
390390
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
391-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32);
391+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32));
392392

393393
worldToCamMatricesStart =
394394
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -437,7 +437,7 @@ TEST_F(GaussianProjectionUTTestFixture,
437437

438438
means = means.cuda();
439439
quats = quats.cuda();
440-
scales = scales.cuda();
440+
logScales = logScales.cuda();
441441
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
442442
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
443443
projectionMatrices = projectionMatrices.cuda();
@@ -446,7 +446,7 @@ TEST_F(GaussianProjectionUTTestFixture,
446446
const auto [radii, means2d, depths, conics, compensations] =
447447
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
448448
quats,
449-
scales,
449+
logScales,
450450
worldToCamMatricesStart,
451451
worldToCamMatricesEnd,
452452
projectionMatrices,
@@ -480,7 +480,7 @@ TEST_F(GaussianProjectionUTTestFixture,
480480
const float x = 0.1f, y = 0.08f, z = 2.5f;
481481
means = torch::tensor({{x, y, z}}, torch::kFloat32);
482482
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
483-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32);
483+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32));
484484

485485
worldToCamMatricesStart =
486486
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -537,7 +537,7 @@ TEST_F(GaussianProjectionUTTestFixture,
537537

538538
means = means.cuda();
539539
quats = quats.cuda();
540-
scales = scales.cuda();
540+
logScales = logScales.cuda();
541541
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
542542
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
543543
projectionMatrices = projectionMatrices.cuda();
@@ -546,7 +546,7 @@ TEST_F(GaussianProjectionUTTestFixture,
546546
const auto [radii, means2d, depths, conics, compensations] =
547547
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
548548
quats,
549-
scales,
549+
logScales,
550550
worldToCamMatricesStart,
551551
worldToCamMatricesEnd,
552552
projectionMatrices,
@@ -580,7 +580,7 @@ TEST_F(GaussianProjectionUTTestFixture,
580580
const float x = 0.07f, y = -0.11f, z = 2.2f;
581581
means = torch::tensor({{x, y, z}}, torch::kFloat32);
582582
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32);
583-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32);
583+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32));
584584

585585
worldToCamMatricesStart =
586586
torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(0).expand({C, 4, 4});
@@ -632,7 +632,7 @@ TEST_F(GaussianProjectionUTTestFixture,
632632

633633
means = means.cuda();
634634
quats = quats.cuda();
635-
scales = scales.cuda();
635+
logScales = logScales.cuda();
636636
worldToCamMatricesStart = worldToCamMatricesStart.cuda();
637637
worldToCamMatricesEnd = worldToCamMatricesEnd.cuda();
638638
projectionMatrices = projectionMatrices.cuda();
@@ -641,7 +641,7 @@ TEST_F(GaussianProjectionUTTestFixture,
641641
const auto [radii, means2d, depths, conics, compensations] =
642642
dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
643643
quats,
644-
scales,
644+
logScales,
645645
worldToCamMatricesStart,
646646
worldToCamMatricesEnd,
647647
projectionMatrices,
@@ -673,7 +673,7 @@ TEST_F(GaussianProjectionUTTestFixture, RadTanThinPrism_RejectsNonZeroK456) {
673673

674674
means = torch::tensor({{0.1f, 0.05f, 2.0f}}, torch::kFloat32).cuda();
675675
quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, torch::kFloat32).cuda();
676-
scales = torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32).cuda();
676+
logScales = torch::log(torch::tensor({{1e-6f, 1e-6f, 1e-6f}}, torch::kFloat32)).cuda();
677677

678678
worldToCamMatricesStart = torch::eye(4, torch::TensorOptions().dtype(torch::kFloat32))
679679
.unsqueeze(0)
@@ -705,7 +705,7 @@ TEST_F(GaussianProjectionUTTestFixture, RadTanThinPrism_RejectsNonZeroK456) {
705705

706706
EXPECT_THROW((dispatchGaussianProjectionForwardUT<torch::kCUDA>(means,
707707
quats,
708-
scales,
708+
logScales,
709709
worldToCamMatricesStart,
710710
worldToCamMatricesEnd,
711711
projectionMatrices,

0 commit comments

Comments
 (0)