@@ -3126,115 +3126,20 @@ def test_degree_0_basic(self):
31263126 C = 2 # number of cameras
31273127
31283128 sh0 = torch .randn (N , 1 , D , device = self .device )
3129+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
31293130
31303131 result = evaluate_spherical_harmonics (
31313132 sh_degree = 0 ,
31323133 num_cameras = C ,
31333134 sh0 = sh0 ,
3135+ radii = radii ,
31343136 )
31353137
31363138 self .assertEqual (result .shape , (C , N , D ))
31373139 # For degree 0, the result should be the same for all cameras
31383140 # since there's no view dependence
31393141 self .assertTrue (torch .allclose (result [0 ], result [1 ], atol = 1e-6 ))
31403142
3141- def test_optional_parameters_omitted (self ):
3142- """Test that optional parameters (shN, view_directions, radii) can be omitted.
3143-
3144- This is a regression test for an issue where passing undefined tensors
3145- to save_for_backward caused 'tensor does not have a device' errors.
3146- """
3147- N = 50
3148- D = 3
3149- C = 2
3150-
3151- sh0 = torch .randn (N , 1 , D , device = self .device )
3152-
3153- # Test degree 0 with no optional parameters
3154- result = evaluate_spherical_harmonics (
3155- sh_degree = 0 ,
3156- num_cameras = C ,
3157- sh0 = sh0 ,
3158- )
3159- self .assertEqual (result .shape , (C , N , D ))
3160- self .assertFalse (torch .isnan (result ).any ())
3161-
3162- # Test degree 0 with explicit None for optional parameters
3163- result = evaluate_spherical_harmonics (
3164- sh_degree = 0 ,
3165- num_cameras = C ,
3166- sh0 = sh0 ,
3167- shN = None ,
3168- view_directions = None ,
3169- radii = None ,
3170- )
3171- self .assertEqual (result .shape , (C , N , D ))
3172- self .assertFalse (torch .isnan (result ).any ())
3173-
3174- # Test higher degree without radii
3175- shN = torch .randn (N , 15 , D , device = self .device )
3176- view_dirs = torch .randn (C , N , 3 , device = self .device )
3177-
3178- result = evaluate_spherical_harmonics (
3179- sh_degree = 3 ,
3180- num_cameras = C ,
3181- sh0 = sh0 ,
3182- shN = shN ,
3183- view_directions = view_dirs ,
3184- # radii intentionally omitted
3185- )
3186- self .assertEqual (result .shape , (C , N , D ))
3187- self .assertFalse (torch .isnan (result ).any ())
3188-
3189- # Test higher degree with explicit radii=None
3190- result = evaluate_spherical_harmonics (
3191- sh_degree = 3 ,
3192- num_cameras = C ,
3193- sh0 = sh0 ,
3194- shN = shN ,
3195- view_directions = view_dirs ,
3196- radii = None ,
3197- )
3198- self .assertEqual (result .shape , (C , N , D ))
3199- self .assertFalse (torch .isnan (result ).any ())
3200-
3201- def test_optional_parameters_gradient_flow (self ):
3202- """Test gradient flow works when optional parameters are omitted.
3203-
3204- Regression test to ensure backward pass works without radii.
3205- """
3206- N = 20
3207- D = 3
3208- C = 2
3209-
3210- sh0 = torch .randn (N , 1 , D , device = self .device , requires_grad = True )
3211- shN = torch .randn (N , 15 , D , device = self .device , requires_grad = True )
3212- view_dirs = torch .randn (C , N , 3 , device = self .device )
3213-
3214- # Forward without radii
3215- result = evaluate_spherical_harmonics (
3216- sh_degree = 3 ,
3217- num_cameras = C ,
3218- sh0 = sh0 ,
3219- shN = shN ,
3220- view_directions = view_dirs ,
3221- # radii intentionally omitted
3222- )
3223-
3224- # Backward
3225- loss = result .sum ()
3226- loss .backward ()
3227-
3228- # Check gradients exist and are valid
3229- self .assertIsNotNone (sh0 .grad )
3230- self .assertIsNotNone (shN .grad )
3231- assert sh0 .grad is not None # for type narrowing
3232- assert shN .grad is not None # for type narrowing
3233- self .assertTrue (torch .any (sh0 .grad != 0 ))
3234- self .assertTrue (torch .any (shN .grad != 0 ))
3235- self .assertFalse (torch .isnan (sh0 .grad ).any ())
3236- self .assertFalse (torch .isnan (shN .grad ).any ())
3237-
32383143 def test_degree_0_matches_expected (self ):
32393144 """Test that degree 0 SH evaluation produces expected output."""
32403145 N = 10
@@ -3243,11 +3148,13 @@ def test_degree_0_matches_expected(self):
32433148
32443149 # Known sh0 values
32453150 sh0 = torch .ones (N , 1 , D , device = self .device )
3151+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
32463152
32473153 result = evaluate_spherical_harmonics (
32483154 sh_degree = 0 ,
32493155 num_cameras = C ,
32503156 sh0 = sh0 ,
3157+ radii = radii ,
32513158 )
32523159
32533160 # For degree 0: result = 0.2820947917738781 * sh0 + 0.5
@@ -3264,13 +3171,15 @@ def test_degree_1_requires_view_directions(self):
32643171
32653172 sh0 = torch .randn (N , 1 , D , device = self .device )
32663173 shN = torch .randn (N , 3 , D , device = self .device ) # 3 coefficients for degree 1
3174+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
32673175
3268- # Should raise error when view_directions is not provided for degree > 0
3269- with self .assertRaises (RuntimeError ):
3176+ # Should raise ValueError when view_directions is not provided for degree > 0
3177+ with self .assertRaises (ValueError ):
32703178 evaluate_spherical_harmonics (
32713179 sh_degree = 1 ,
32723180 num_cameras = C ,
32733181 sh0 = sh0 ,
3182+ radii = radii ,
32743183 shN = shN ,
32753184 view_directions = None ,
32763185 )
@@ -3281,6 +3190,7 @@ def test_degree_1_requires_view_directions(self):
32813190 sh_degree = 1 ,
32823191 num_cameras = C ,
32833192 sh0 = sh0 ,
3193+ radii = radii ,
32843194 shN = shN ,
32853195 view_directions = view_dirs ,
32863196 )
@@ -3296,11 +3206,13 @@ def test_degree_3_full(self):
32963206 # Degree 3 has (3+1)^2 = 16 bases, so K-1 = 15 higher order coefficients
32973207 shN = torch .randn (N , 15 , D , device = self .device )
32983208 view_dirs = torch .randn (C , N , 3 , device = self .device )
3209+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
32993210
33003211 result = evaluate_spherical_harmonics (
33013212 sh_degree = 3 ,
33023213 num_cameras = C ,
33033214 sh0 = sh0 ,
3215+ radii = radii ,
33043216 shN = shN ,
33053217 view_directions = view_dirs ,
33063218 )
@@ -3343,11 +3255,15 @@ def test_gradient_flow_sh0(self):
33433255 C = 1
33443256
33453257 sh0 = torch .randn (N , 1 , D , device = self .device , requires_grad = True )
3258+ # Note: radii must be provided for backward pass to work correctly
3259+ # (matches GaussianSplat3d usage pattern)
3260+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
33463261
33473262 result = evaluate_spherical_harmonics (
33483263 sh_degree = 0 ,
33493264 num_cameras = C ,
33503265 sh0 = sh0 ,
3266+ radii = radii ,
33513267 )
33523268
33533269 loss = result .sum ()
@@ -3365,11 +3281,14 @@ def test_gradient_flow_shN(self):
33653281 sh0 = torch .randn (N , 1 , D , device = self .device , requires_grad = True )
33663282 shN = torch .randn (N , 15 , D , device = self .device , requires_grad = True )
33673283 view_dirs = torch .randn (C , N , 3 , device = self .device )
3284+ # Note: radii must be provided for backward pass to work correctly
3285+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
33683286
33693287 result = evaluate_spherical_harmonics (
33703288 sh_degree = 3 ,
33713289 num_cameras = C ,
33723290 sh0 = sh0 ,
3291+ radii = radii ,
33733292 shN = shN ,
33743293 view_directions = view_dirs ,
33753294 )
@@ -3391,11 +3310,13 @@ def test_gradient_flow_view_directions(self):
33913310 sh0 = torch .randn (N , 1 , D , device = self .device )
33923311 shN = torch .randn (N , 15 , D , device = self .device )
33933312 view_dirs = torch .randn (C , N , 3 , device = self .device , requires_grad = True )
3313+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
33943314
33953315 result = evaluate_spherical_harmonics (
33963316 sh_degree = 3 ,
33973317 num_cameras = C ,
33983318 sh0 = sh0 ,
3319+ radii = radii ,
33993320 shN = shN ,
34003321 view_directions = view_dirs ,
34013322 )
@@ -3415,11 +3336,13 @@ def test_single_gaussian(self):
34153336 sh0 = torch .randn (N , 1 , D , device = self .device )
34163337 shN = torch .randn (N , 15 , D , device = self .device )
34173338 view_dirs = torch .randn (C , N , 3 , device = self .device )
3339+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
34183340
34193341 result = evaluate_spherical_harmonics (
34203342 sh_degree = 3 ,
34213343 num_cameras = C ,
34223344 sh0 = sh0 ,
3345+ radii = radii ,
34233346 shN = shN ,
34243347 view_directions = view_dirs ,
34253348 )
@@ -3435,11 +3358,13 @@ def test_many_channels(self):
34353358 sh0 = torch .randn (N , 1 , D , device = self .device )
34363359 shN = torch .randn (N , 15 , D , device = self .device )
34373360 view_dirs = torch .randn (C , N , 3 , device = self .device )
3361+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
34383362
34393363 result = evaluate_spherical_harmonics (
34403364 sh_degree = 3 ,
34413365 num_cameras = C ,
34423366 sh0 = sh0 ,
3367+ radii = radii ,
34433368 shN = shN ,
34443369 view_directions = view_dirs ,
34453370 )
@@ -3457,11 +3382,13 @@ def test_different_sh_degrees(self, sh_degree):
34573382 K = (sh_degree + 1 ) ** 2
34583383 shN = torch .randn (N , K - 1 , D , device = self .device )
34593384 view_dirs = torch .randn (C , N , 3 , device = self .device )
3385+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
34603386
34613387 result = evaluate_spherical_harmonics (
34623388 sh_degree = sh_degree ,
34633389 num_cameras = C ,
34643390 sh0 = sh0 ,
3391+ radii = radii ,
34653392 shN = shN ,
34663393 view_directions = view_dirs ,
34673394 )
@@ -3476,6 +3403,7 @@ def test_view_directions_not_prenormalized(self):
34763403
34773404 sh0 = torch .randn (N , 1 , D , device = self .device )
34783405 shN = torch .randn (N , 15 , D , device = self .device )
3406+ radii = torch .ones (C , N , dtype = torch .int32 , device = self .device )
34793407
34803408 # Unnormalized view directions (varying magnitudes)
34813409 view_dirs = torch .randn (C , N , 3 , device = self .device ) * 10.0
@@ -3485,6 +3413,7 @@ def test_view_directions_not_prenormalized(self):
34853413 sh_degree = 3 ,
34863414 num_cameras = C ,
34873415 sh0 = sh0 ,
3416+ radii = radii ,
34883417 shN = shN ,
34893418 view_directions = view_dirs ,
34903419 )
0 commit comments