@@ -3138,6 +3138,103 @@ def test_degree_0_basic(self):
31383138 # since there's no view dependence
31393139 self .assertTrue (torch .allclose (result [0 ], result [1 ], atol = 1e-6 ))
31403140
3141+ def test_optional_parameters_omitted (self ):
3142+ """Test that optional parameters (shN, view_directions, radii) can be omitted.
3143+
3144+ This is a regression test for an issue where passing undefined tensors
3145+ to save_for_backward caused 'tensor does not have a device' errors.
3146+ """
3147+ N = 50
3148+ D = 3
3149+ C = 2
3150+
3151+ sh0 = torch .randn (N , 1 , D , device = self .device )
3152+
3153+ # Test degree 0 with no optional parameters
3154+ result = evaluate_spherical_harmonics (
3155+ sh_degree = 0 ,
3156+ num_cameras = C ,
3157+ sh0 = sh0 ,
3158+ )
3159+ self .assertEqual (result .shape , (C , N , D ))
3160+ self .assertFalse (torch .isnan (result ).any ())
3161+
3162+ # Test degree 0 with explicit None for optional parameters
3163+ result = evaluate_spherical_harmonics (
3164+ sh_degree = 0 ,
3165+ num_cameras = C ,
3166+ sh0 = sh0 ,
3167+ shN = None ,
3168+ view_directions = None ,
3169+ radii = None ,
3170+ )
3171+ self .assertEqual (result .shape , (C , N , D ))
3172+ self .assertFalse (torch .isnan (result ).any ())
3173+
3174+ # Test higher degree without radii
3175+ shN = torch .randn (N , 15 , D , device = self .device )
3176+ view_dirs = torch .randn (C , N , 3 , device = self .device )
3177+
3178+ result = evaluate_spherical_harmonics (
3179+ sh_degree = 3 ,
3180+ num_cameras = C ,
3181+ sh0 = sh0 ,
3182+ shN = shN ,
3183+ view_directions = view_dirs ,
3184+ # radii intentionally omitted
3185+ )
3186+ self .assertEqual (result .shape , (C , N , D ))
3187+ self .assertFalse (torch .isnan (result ).any ())
3188+
3189+ # Test higher degree with explicit radii=None
3190+ result = evaluate_spherical_harmonics (
3191+ sh_degree = 3 ,
3192+ num_cameras = C ,
3193+ sh0 = sh0 ,
3194+ shN = shN ,
3195+ view_directions = view_dirs ,
3196+ radii = None ,
3197+ )
3198+ self .assertEqual (result .shape , (C , N , D ))
3199+ self .assertFalse (torch .isnan (result ).any ())
3200+
3201+ def test_optional_parameters_gradient_flow (self ):
3202+ """Test gradient flow works when optional parameters are omitted.
3203+
3204+ Regression test to ensure backward pass works without radii.
3205+ """
3206+ N = 20
3207+ D = 3
3208+ C = 2
3209+
3210+ sh0 = torch .randn (N , 1 , D , device = self .device , requires_grad = True )
3211+ shN = torch .randn (N , 15 , D , device = self .device , requires_grad = True )
3212+ view_dirs = torch .randn (C , N , 3 , device = self .device )
3213+
3214+ # Forward without radii
3215+ result = evaluate_spherical_harmonics (
3216+ sh_degree = 3 ,
3217+ num_cameras = C ,
3218+ sh0 = sh0 ,
3219+ shN = shN ,
3220+ view_directions = view_dirs ,
3221+ # radii intentionally omitted
3222+ )
3223+
3224+ # Backward
3225+ loss = result .sum ()
3226+ loss .backward ()
3227+
3228+ # Check gradients exist and are valid
3229+ self .assertIsNotNone (sh0 .grad )
3230+ self .assertIsNotNone (shN .grad )
3231+ assert sh0 .grad is not None # for type narrowing
3232+ assert shN .grad is not None # for type narrowing
3233+ self .assertTrue (torch .any (sh0 .grad != 0 ))
3234+ self .assertTrue (torch .any (shN .grad != 0 ))
3235+ self .assertFalse (torch .isnan (sh0 .grad ).any ())
3236+ self .assertFalse (torch .isnan (shN .grad ).any ())
3237+
31413238 def test_degree_0_matches_expected (self ):
31423239 """Test that degree 0 SH evaluation produces expected output."""
31433240 N = 10
0 commit comments