Skip to content

Commit 4b598a9

Browse files
committed
Made radii required to get around backwards issue; more consistent with eval SH use in GaussianSplat3d
Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
1 parent 976a02a commit 4b598a9

File tree

5 files changed

+43
-113
lines changed

5 files changed

+43
-113
lines changed

fvdb/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def evaluate_spherical_harmonics(
9696
sh_degree: int,
9797
num_cameras: int,
9898
sh0: torch.Tensor,
99+
radii: torch.Tensor,
99100
shN: torch.Tensor | None = None,
100101
view_directions: torch.Tensor | None = None,
101-
radii: torch.Tensor | None = None,
102102
) -> torch.Tensor: ...
103103

104104
__all__ = [

fvdb/_fvdb_cpp.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,9 +1125,9 @@ def evaluate_spherical_harmonics(
11251125
sh_degree: int,
11261126
num_cameras: int,
11271127
sh0: torch.Tensor,
1128+
radii: torch.Tensor,
11281129
shN: Optional[torch.Tensor] = ...,
11291130
view_directions: Optional[torch.Tensor] = ...,
1130-
radii: Optional[torch.Tensor] = ...,
11311131
) -> torch.Tensor: ...
11321132
@overload
11331133
def jcat(grid_batches: list[GridBatch]) -> GridBatch: ...

src/fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,12 @@ EvaluateSphericalHarmonics::backward(EvaluateSphericalHarmonics::AutogradContext
5858
Variable viewDirs = saved.at(0);
5959
Variable shNCoeffs = saved.at(1);
6060

61-
const int shDegreeToUse = static_cast<int>(ctx->saved_data["shDegreeToUse"].toInt());
62-
const int numCameras = static_cast<int>(ctx->saved_data["numCameras"].toInt());
63-
const int numGaussians = static_cast<int>(ctx->saved_data["numGaussians"].toInt());
64-
const bool hasRadii = ctx->saved_data["hasRadii"].toBool();
65-
const bool computeDLossDViewDirs = ctx->needs_input_grad(1);
61+
const int shDegreeToUse = static_cast<int>(ctx->saved_data["shDegreeToUse"].toInt());
62+
const int numCameras = static_cast<int>(ctx->saved_data["numCameras"].toInt());
63+
const int numGaussians = static_cast<int>(ctx->saved_data["numGaussians"].toInt());
64+
const bool hasRadii = ctx->saved_data["hasRadii"].toBool();
65+
// Only compute viewDirs gradients if viewDirs is defined and requires grad
66+
const bool computeDLossDViewDirs = viewDirs.defined() && viewDirs.requires_grad();
6667

6768
// Radii is only saved if it was defined in forward
6869
Variable radii = hasRadii ? saved.at(2) : torch::Tensor();

src/python/GaussianSplatBinding.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,9 @@ bind_gaussian_splat3d(py::module &m) {
399399
[](int64_t shDegree,
400400
int64_t numCameras,
401401
const torch::Tensor &sh0,
402+
const torch::Tensor &radii,
402403
const std::optional<torch::Tensor> &shN,
403-
const std::optional<torch::Tensor> &viewDirections,
404-
const std::optional<torch::Tensor> &radii) {
404+
const std::optional<torch::Tensor> &viewDirections) {
405405
return fvdb::detail::autograd::EvaluateSphericalHarmonics::apply(
406406
shDegree, numCameras, viewDirections, sh0, shN, radii)[0];
407407
},
@@ -419,21 +419,21 @@ view directions for view-dependent appearance.
419419
num_cameras: Number of camera views (C). The output will have shape [C, N, D].
420420
sh0: DC term coefficients with shape [N, 1, D] where N is the number of
421421
points and D is the number of feature channels.
422+
radii: Projected radii with shape [C, N] (int32). Points with radii <= 0
423+
will output zeros (used to skip invisible gaussians). Pass a tensor
424+
of ones to evaluate all points.
422425
shN: Higher-order SH coefficients with shape [N, K-1, D] where
423426
K = (sh_degree+1)^2. Required when sh_degree > 0. Pass None for degree 0.
424427
view_directions: Unnormalized view directions with shape [C, N, 3].
425428
Required when sh_degree > 0. Pass None for degree 0.
426-
radii: Optional projected radii with shape [C, N] (int32). When provided,
427-
points with radii <= 0 will output zeros (used to skip invisible
428-
gaussians as an optimization). Pass None to evaluate all points.
429429
430430
Returns:
431431
Tensor of shape [C, N, D] containing the evaluated features/colors.
432432
)doc",
433433
py::arg("sh_degree"),
434434
py::arg("num_cameras"),
435435
py::arg("sh0"),
436+
py::arg("radii"),
436437
py::arg("shN") = std::nullopt,
437-
py::arg("view_directions") = std::nullopt,
438-
py::arg("radii") = std::nullopt);
438+
py::arg("view_directions") = std::nullopt);
439439
}

tests/unit/test_gaussian_splat_3d.py

Lines changed: 28 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -3126,115 +3126,20 @@ def test_degree_0_basic(self):
31263126
C = 2 # number of cameras
31273127

31283128
sh0 = torch.randn(N, 1, D, device=self.device)
3129+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
31293130

31303131
result = evaluate_spherical_harmonics(
31313132
sh_degree=0,
31323133
num_cameras=C,
31333134
sh0=sh0,
3135+
radii=radii,
31343136
)
31353137

31363138
self.assertEqual(result.shape, (C, N, D))
31373139
# For degree 0, the result should be the same for all cameras
31383140
# since there's no view dependence
31393141
self.assertTrue(torch.allclose(result[0], result[1], atol=1e-6))
31403142

3141-
def test_optional_parameters_omitted(self):
3142-
"""Test that optional parameters (shN, view_directions, radii) can be omitted.
3143-
3144-
This is a regression test for an issue where passing undefined tensors
3145-
to save_for_backward caused 'tensor does not have a device' errors.
3146-
"""
3147-
N = 50
3148-
D = 3
3149-
C = 2
3150-
3151-
sh0 = torch.randn(N, 1, D, device=self.device)
3152-
3153-
# Test degree 0 with no optional parameters
3154-
result = evaluate_spherical_harmonics(
3155-
sh_degree=0,
3156-
num_cameras=C,
3157-
sh0=sh0,
3158-
)
3159-
self.assertEqual(result.shape, (C, N, D))
3160-
self.assertFalse(torch.isnan(result).any())
3161-
3162-
# Test degree 0 with explicit None for optional parameters
3163-
result = evaluate_spherical_harmonics(
3164-
sh_degree=0,
3165-
num_cameras=C,
3166-
sh0=sh0,
3167-
shN=None,
3168-
view_directions=None,
3169-
radii=None,
3170-
)
3171-
self.assertEqual(result.shape, (C, N, D))
3172-
self.assertFalse(torch.isnan(result).any())
3173-
3174-
# Test higher degree without radii
3175-
shN = torch.randn(N, 15, D, device=self.device)
3176-
view_dirs = torch.randn(C, N, 3, device=self.device)
3177-
3178-
result = evaluate_spherical_harmonics(
3179-
sh_degree=3,
3180-
num_cameras=C,
3181-
sh0=sh0,
3182-
shN=shN,
3183-
view_directions=view_dirs,
3184-
# radii intentionally omitted
3185-
)
3186-
self.assertEqual(result.shape, (C, N, D))
3187-
self.assertFalse(torch.isnan(result).any())
3188-
3189-
# Test higher degree with explicit radii=None
3190-
result = evaluate_spherical_harmonics(
3191-
sh_degree=3,
3192-
num_cameras=C,
3193-
sh0=sh0,
3194-
shN=shN,
3195-
view_directions=view_dirs,
3196-
radii=None,
3197-
)
3198-
self.assertEqual(result.shape, (C, N, D))
3199-
self.assertFalse(torch.isnan(result).any())
3200-
3201-
def test_optional_parameters_gradient_flow(self):
3202-
"""Test gradient flow works when optional parameters are omitted.
3203-
3204-
Regression test to ensure backward pass works without radii.
3205-
"""
3206-
N = 20
3207-
D = 3
3208-
C = 2
3209-
3210-
sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True)
3211-
shN = torch.randn(N, 15, D, device=self.device, requires_grad=True)
3212-
view_dirs = torch.randn(C, N, 3, device=self.device)
3213-
3214-
# Forward without radii
3215-
result = evaluate_spherical_harmonics(
3216-
sh_degree=3,
3217-
num_cameras=C,
3218-
sh0=sh0,
3219-
shN=shN,
3220-
view_directions=view_dirs,
3221-
# radii intentionally omitted
3222-
)
3223-
3224-
# Backward
3225-
loss = result.sum()
3226-
loss.backward()
3227-
3228-
# Check gradients exist and are valid
3229-
self.assertIsNotNone(sh0.grad)
3230-
self.assertIsNotNone(shN.grad)
3231-
assert sh0.grad is not None # for type narrowing
3232-
assert shN.grad is not None # for type narrowing
3233-
self.assertTrue(torch.any(sh0.grad != 0))
3234-
self.assertTrue(torch.any(shN.grad != 0))
3235-
self.assertFalse(torch.isnan(sh0.grad).any())
3236-
self.assertFalse(torch.isnan(shN.grad).any())
3237-
32383143
def test_degree_0_matches_expected(self):
32393144
"""Test that degree 0 SH evaluation produces expected output."""
32403145
N = 10
@@ -3243,11 +3148,13 @@ def test_degree_0_matches_expected(self):
32433148

32443149
# Known sh0 values
32453150
sh0 = torch.ones(N, 1, D, device=self.device)
3151+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
32463152

32473153
result = evaluate_spherical_harmonics(
32483154
sh_degree=0,
32493155
num_cameras=C,
32503156
sh0=sh0,
3157+
radii=radii,
32513158
)
32523159

32533160
# For degree 0: result = 0.2820947917738781 * sh0 + 0.5
@@ -3264,13 +3171,15 @@ def test_degree_1_requires_view_directions(self):
32643171

32653172
sh0 = torch.randn(N, 1, D, device=self.device)
32663173
shN = torch.randn(N, 3, D, device=self.device) # 3 coefficients for degree 1
3174+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
32673175

3268-
# Should raise error when view_directions is not provided for degree > 0
3269-
with self.assertRaises(RuntimeError):
3176+
# Should raise ValueError when view_directions is not provided for degree > 0
3177+
with self.assertRaises(ValueError):
32703178
evaluate_spherical_harmonics(
32713179
sh_degree=1,
32723180
num_cameras=C,
32733181
sh0=sh0,
3182+
radii=radii,
32743183
shN=shN,
32753184
view_directions=None,
32763185
)
@@ -3281,6 +3190,7 @@ def test_degree_1_requires_view_directions(self):
32813190
sh_degree=1,
32823191
num_cameras=C,
32833192
sh0=sh0,
3193+
radii=radii,
32843194
shN=shN,
32853195
view_directions=view_dirs,
32863196
)
@@ -3296,11 +3206,13 @@ def test_degree_3_full(self):
32963206
# Degree 3 has (3+1)^2 = 16 bases, so K-1 = 15 higher order coefficients
32973207
shN = torch.randn(N, 15, D, device=self.device)
32983208
view_dirs = torch.randn(C, N, 3, device=self.device)
3209+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
32993210

33003211
result = evaluate_spherical_harmonics(
33013212
sh_degree=3,
33023213
num_cameras=C,
33033214
sh0=sh0,
3215+
radii=radii,
33043216
shN=shN,
33053217
view_directions=view_dirs,
33063218
)
@@ -3343,11 +3255,15 @@ def test_gradient_flow_sh0(self):
33433255
C = 1
33443256

33453257
sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True)
3258+
# Note: radii must be provided for backward pass to work correctly
3259+
# (matches GaussianSplat3d usage pattern)
3260+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
33463261

33473262
result = evaluate_spherical_harmonics(
33483263
sh_degree=0,
33493264
num_cameras=C,
33503265
sh0=sh0,
3266+
radii=radii,
33513267
)
33523268

33533269
loss = result.sum()
@@ -3365,11 +3281,14 @@ def test_gradient_flow_shN(self):
33653281
sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True)
33663282
shN = torch.randn(N, 15, D, device=self.device, requires_grad=True)
33673283
view_dirs = torch.randn(C, N, 3, device=self.device)
3284+
# Note: radii must be provided for backward pass to work correctly
3285+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
33683286

33693287
result = evaluate_spherical_harmonics(
33703288
sh_degree=3,
33713289
num_cameras=C,
33723290
sh0=sh0,
3291+
radii=radii,
33733292
shN=shN,
33743293
view_directions=view_dirs,
33753294
)
@@ -3391,11 +3310,13 @@ def test_gradient_flow_view_directions(self):
33913310
sh0 = torch.randn(N, 1, D, device=self.device)
33923311
shN = torch.randn(N, 15, D, device=self.device)
33933312
view_dirs = torch.randn(C, N, 3, device=self.device, requires_grad=True)
3313+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
33943314

33953315
result = evaluate_spherical_harmonics(
33963316
sh_degree=3,
33973317
num_cameras=C,
33983318
sh0=sh0,
3319+
radii=radii,
33993320
shN=shN,
34003321
view_directions=view_dirs,
34013322
)
@@ -3415,11 +3336,13 @@ def test_single_gaussian(self):
34153336
sh0 = torch.randn(N, 1, D, device=self.device)
34163337
shN = torch.randn(N, 15, D, device=self.device)
34173338
view_dirs = torch.randn(C, N, 3, device=self.device)
3339+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
34183340

34193341
result = evaluate_spherical_harmonics(
34203342
sh_degree=3,
34213343
num_cameras=C,
34223344
sh0=sh0,
3345+
radii=radii,
34233346
shN=shN,
34243347
view_directions=view_dirs,
34253348
)
@@ -3435,11 +3358,13 @@ def test_many_channels(self):
34353358
sh0 = torch.randn(N, 1, D, device=self.device)
34363359
shN = torch.randn(N, 15, D, device=self.device)
34373360
view_dirs = torch.randn(C, N, 3, device=self.device)
3361+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
34383362

34393363
result = evaluate_spherical_harmonics(
34403364
sh_degree=3,
34413365
num_cameras=C,
34423366
sh0=sh0,
3367+
radii=radii,
34433368
shN=shN,
34443369
view_directions=view_dirs,
34453370
)
@@ -3457,11 +3382,13 @@ def test_different_sh_degrees(self, sh_degree):
34573382
K = (sh_degree + 1) ** 2
34583383
shN = torch.randn(N, K - 1, D, device=self.device)
34593384
view_dirs = torch.randn(C, N, 3, device=self.device)
3385+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
34603386

34613387
result = evaluate_spherical_harmonics(
34623388
sh_degree=sh_degree,
34633389
num_cameras=C,
34643390
sh0=sh0,
3391+
radii=radii,
34653392
shN=shN,
34663393
view_directions=view_dirs,
34673394
)
@@ -3476,6 +3403,7 @@ def test_view_directions_not_prenormalized(self):
34763403

34773404
sh0 = torch.randn(N, 1, D, device=self.device)
34783405
shN = torch.randn(N, 15, D, device=self.device)
3406+
radii = torch.ones(C, N, dtype=torch.int32, device=self.device)
34793407

34803408
# Unnormalized view directions (varying magnitudes)
34813409
view_dirs = torch.randn(C, N, 3, device=self.device) * 10.0
@@ -3485,6 +3413,7 @@ def test_view_directions_not_prenormalized(self):
34853413
sh_degree=3,
34863414
num_cameras=C,
34873415
sh0=sh0,
3416+
radii=radii,
34883417
shN=shN,
34893418
view_directions=view_dirs,
34903419
)

0 commit comments

Comments
 (0)