Skip to content

Commit 976a02a

Browse files
committed
Fix an issue where if radii is not provided (it is optional), a 'tensor does not have a device' error is thrown when 'save_for_backward' is called
Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
1 parent 4ff0e7d commit 976a02a

File tree

4 files changed

+114
-7
lines changed

4 files changed

+114
-7
lines changed

src/fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,27 @@ EvaluateSphericalHarmonics::forward(
2020
viewDirections, // [C, N, 3] (optional)
2121
const EvaluateSphericalHarmonics::Variable &sh0Coeffs, // [N, 1, D]
2222
const std::optional<EvaluateSphericalHarmonics::Variable> &shNCoeffs, // [N, K-1, D]
23-
const EvaluateSphericalHarmonics::Variable &radii // [C, N]
23+
const std::optional<EvaluateSphericalHarmonics::Variable> &radii // [C, N] (optional)
2424
) {
2525
FVDB_FUNC_RANGE_WITH_NAME("EvaluateSphericalHarmonics::forward");
2626
const Variable viewDirectionsValue = viewDirections.value_or(torch::Tensor());
2727
const Variable shNCoeffsValue = shNCoeffs.value_or(torch::Tensor());
28+
const Variable radiiValue = radii.value_or(torch::Tensor());
2829
const Variable renderQuantities = FVDB_DISPATCH_KERNEL(sh0Coeffs.device(), [&]() {
2930
return ops::dispatchSphericalHarmonicsForward<DeviceTag>(
30-
shDegreeToUse, numCameras, viewDirectionsValue, sh0Coeffs, shNCoeffsValue, radii);
31+
shDegreeToUse, numCameras, viewDirectionsValue, sh0Coeffs, shNCoeffsValue, radiiValue);
3132
});
32-
ctx->save_for_backward({viewDirectionsValue, shNCoeffsValue, radii});
33+
// only save radii if defined to avoid device access issues
34+
const bool hasRadii = radii.has_value() && radii.value().defined();
35+
if (hasRadii) {
36+
ctx->save_for_backward({viewDirectionsValue, shNCoeffsValue, radiiValue});
37+
} else {
38+
ctx->save_for_backward({viewDirectionsValue, shNCoeffsValue});
39+
}
3340
ctx->saved_data["shDegreeToUse"] = static_cast<int64_t>(shDegreeToUse);
3441
ctx->saved_data["numCameras"] = static_cast<int64_t>(numCameras);
3542
ctx->saved_data["numGaussians"] = static_cast<int64_t>(sh0Coeffs.size(0));
43+
ctx->saved_data["hasRadii"] = hasRadii;
3644
return {renderQuantities};
3745
}
3846

@@ -49,13 +57,16 @@ EvaluateSphericalHarmonics::backward(EvaluateSphericalHarmonics::AutogradContext
4957
VariableList saved = ctx->get_saved_variables();
5058
Variable viewDirs = saved.at(0);
5159
Variable shNCoeffs = saved.at(1);
52-
Variable radii = saved.at(2);
5360

5461
const int shDegreeToUse = static_cast<int>(ctx->saved_data["shDegreeToUse"].toInt());
5562
const int numCameras = static_cast<int>(ctx->saved_data["numCameras"].toInt());
5663
const int numGaussians = static_cast<int>(ctx->saved_data["numGaussians"].toInt());
64+
const bool hasRadii = ctx->saved_data["hasRadii"].toBool();
5765
const bool computeDLossDViewDirs = ctx->needs_input_grad(1);
5866

67+
// Radii is only saved if it was defined in forward
68+
Variable radii = hasRadii ? saved.at(2) : torch::Tensor();
69+
5970
auto variables = FVDB_DISPATCH_KERNEL(dLossdColors.device(), [&]() {
6071
return ops::dispatchSphericalHarmonicsBackward<DeviceTag>(shDegreeToUse,
6172
numCameras,

src/fvdb/detail/autograd/EvaluateSphericalHarmonics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct EvaluateSphericalHarmonics : public torch::autograd::Function<EvaluateSph
2222
const std::optional<Variable> viewDirections, // [N, 3] or empty for deg 0
2323
const Variable &sh0Coeffs, // [N, 1, D]
2424
const std::optional<Variable> &shNCoeffs, // [N, K-1, D]
25-
const Variable &radii // [N,]
25+
const std::optional<Variable> &radii // [C, N] (optional)
2626
);
2727

2828
static VariableList backward(AutogradContext *ctx, VariableList gradOutput);

src/python/GaussianSplatBinding.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,8 @@ bind_gaussian_splat3d(py::module &m) {
402402
const std::optional<torch::Tensor> &shN,
403403
const std::optional<torch::Tensor> &viewDirections,
404404
const std::optional<torch::Tensor> &radii) {
405-
torch::Tensor radiiValue = radii.value_or(torch::Tensor());
406405
return fvdb::detail::autograd::EvaluateSphericalHarmonics::apply(
407-
shDegree, numCameras, viewDirections, sh0, shN, radiiValue)[0];
406+
shDegree, numCameras, viewDirections, sh0, shN, radii)[0];
408407
},
409408
R"doc(
410409
Evaluate spherical harmonics to compute view-dependent features/colors.

tests/unit/test_gaussian_splat_3d.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,6 +3138,103 @@ def test_degree_0_basic(self):
31383138
# since there's no view dependence
31393139
self.assertTrue(torch.allclose(result[0], result[1], atol=1e-6))
31403140

3141+
def test_optional_parameters_omitted(self):
3142+
"""Test that optional parameters (shN, view_directions, radii) can be omitted.
3143+
3144+
This is a regression test for an issue where passing undefined tensors
3145+
to save_for_backward caused 'tensor does not have a device' errors.
3146+
"""
3147+
N = 50
3148+
D = 3
3149+
C = 2
3150+
3151+
sh0 = torch.randn(N, 1, D, device=self.device)
3152+
3153+
# Test degree 0 with no optional parameters
3154+
result = evaluate_spherical_harmonics(
3155+
sh_degree=0,
3156+
num_cameras=C,
3157+
sh0=sh0,
3158+
)
3159+
self.assertEqual(result.shape, (C, N, D))
3160+
self.assertFalse(torch.isnan(result).any())
3161+
3162+
# Test degree 0 with explicit None for optional parameters
3163+
result = evaluate_spherical_harmonics(
3164+
sh_degree=0,
3165+
num_cameras=C,
3166+
sh0=sh0,
3167+
shN=None,
3168+
view_directions=None,
3169+
radii=None,
3170+
)
3171+
self.assertEqual(result.shape, (C, N, D))
3172+
self.assertFalse(torch.isnan(result).any())
3173+
3174+
# Test higher degree without radii
3175+
shN = torch.randn(N, 15, D, device=self.device)
3176+
view_dirs = torch.randn(C, N, 3, device=self.device)
3177+
3178+
result = evaluate_spherical_harmonics(
3179+
sh_degree=3,
3180+
num_cameras=C,
3181+
sh0=sh0,
3182+
shN=shN,
3183+
view_directions=view_dirs,
3184+
# radii intentionally omitted
3185+
)
3186+
self.assertEqual(result.shape, (C, N, D))
3187+
self.assertFalse(torch.isnan(result).any())
3188+
3189+
# Test higher degree with explicit radii=None
3190+
result = evaluate_spherical_harmonics(
3191+
sh_degree=3,
3192+
num_cameras=C,
3193+
sh0=sh0,
3194+
shN=shN,
3195+
view_directions=view_dirs,
3196+
radii=None,
3197+
)
3198+
self.assertEqual(result.shape, (C, N, D))
3199+
self.assertFalse(torch.isnan(result).any())
3200+
3201+
def test_optional_parameters_gradient_flow(self):
3202+
"""Test gradient flow works when optional parameters are omitted.
3203+
3204+
Regression test to ensure backward pass works without radii.
3205+
"""
3206+
N = 20
3207+
D = 3
3208+
C = 2
3209+
3210+
sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True)
3211+
shN = torch.randn(N, 15, D, device=self.device, requires_grad=True)
3212+
view_dirs = torch.randn(C, N, 3, device=self.device)
3213+
3214+
# Forward without radii
3215+
result = evaluate_spherical_harmonics(
3216+
sh_degree=3,
3217+
num_cameras=C,
3218+
sh0=sh0,
3219+
shN=shN,
3220+
view_directions=view_dirs,
3221+
# radii intentionally omitted
3222+
)
3223+
3224+
# Backward
3225+
loss = result.sum()
3226+
loss.backward()
3227+
3228+
# Check gradients exist and are valid
3229+
self.assertIsNotNone(sh0.grad)
3230+
self.assertIsNotNone(shN.grad)
3231+
assert sh0.grad is not None # for type narrowing
3232+
assert shN.grad is not None # for type narrowing
3233+
self.assertTrue(torch.any(sh0.grad != 0))
3234+
self.assertTrue(torch.any(shN.grad != 0))
3235+
self.assertFalse(torch.isnan(sh0.grad).any())
3236+
self.assertFalse(torch.isnan(shN.grad).any())
3237+
31413238
def test_degree_0_matches_expected(self):
31423239
"""Test that degree 0 SH evaluation produces expected output."""
31433240
N = 10

0 commit comments

Comments
 (0)