1+ from typing import Literal
2+
13import numpy as np
24import pytest
35import torch
4- from scipy . spatial . transform import Rotation as ScR
5-
6+ from gsplat . cuda . _torch_impl import _eval_sh_bases_fast as gsplat_eval_sh_bases
7+ from gsplat . cuda . _torch_impl import _spherical_harmonics as gsplat_spherical_harmonics
68from nerfstudio .utils .spherical_harmonics import (
79 components_from_spherical_harmonics ,
810 num_sh_bases ,
911 rotate_spherical_harmonics ,
1012)
13+ from scipy .spatial .transform import Rotation as ScR
1114
1215
1316@pytest .mark .parametrize ("degree" , list (range (0 , 5 )))
@@ -23,7 +26,7 @@ def test_spherical_harmonics_components(degree):
2326
2427
2528@pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
26- def test_spherical_harmonics_rotation (sh_degree ):
29+ def test_spherical_harmonics_rotation_nerfacto (sh_degree ):
2730 """Test if rotating both the view direction and SH coefficients by the same rotation
2831 produces the same color output as the original.
2932
@@ -43,7 +46,7 @@ def test_spherical_harmonics_rotation(sh_degree):
4346 color_original = (sh_coeffs * y_lm [..., None , :]).sum (dim = - 1 )
4447
4548 rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
46- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
49+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention = "+y,+z,+x" )
4750 dirs_rotated = (rot_matrix @ dirs .T ).T
4851 y_lm_rotated = components_from_spherical_harmonics (sh_degree , dirs_rotated )
4952 color_rotated = (sh_coeffs_rotated * y_lm_rotated [..., None , :]).sum (dim = - 1 )
@@ -52,7 +55,42 @@ def test_spherical_harmonics_rotation(sh_degree):
5255
5356
5457@pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
55- def test_spherical_harmonics_rotation_properties (sh_degree ):
58+ def test_spherical_harmonics_rotation_splatfacto (sh_degree ):
59+ """Test if rotating both the view direction and SH coefficients by the same rotation
60+ produces the same color output as the original.
61+
62+ In other words, for any rotation R:
63+ color(dir, coeffs) = color(R @ dir, rotate_sh(R, coeffs))
64+ """
65+ torch .manual_seed (0 )
66+ np .random .seed (0 )
67+
68+ N = 1000
69+ num_coeffs = (sh_degree + 1 ) ** 2
70+ sh_coeffs = torch .rand (N , 3 , num_coeffs )
71+ dirs = torch .rand (N , 3 )
72+ dirs = dirs / torch .linalg .norm (dirs , dim = - 1 , keepdim = True )
73+
74+ assert dirs .shape == (N , 3 )
75+ y_lm = gsplat_eval_sh_bases (num_coeffs , dirs )
76+ color_original = (sh_coeffs * y_lm [..., None , :]).sum (dim = - 1 )
77+
78+ rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
79+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention = "-y,+z,-x" )
80+ dirs_rotated = (rot_matrix @ dirs .T ).T
81+ assert dirs_rotated .shape == (N , 3 )
82+ y_lm_rotated = gsplat_eval_sh_bases (num_coeffs , dirs_rotated )
83+ color_rotated = (sh_coeffs_rotated * y_lm_rotated [..., None , :]).sum (dim = - 1 )
84+
85+ torch .testing .assert_close (
86+ gsplat_spherical_harmonics (sh_degree , coeffs = sh_coeffs .swapaxes (- 1 , - 2 ), dirs = dirs ),
87+ gsplat_spherical_harmonics (sh_degree , coeffs = sh_coeffs_rotated .swapaxes (- 1 , - 2 ), dirs = dirs_rotated ),
88+ )
89+
90+
91+ @pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
92+ @pytest .mark .parametrize ("component_convention" , ["+y,+z,+x" , "-y,+z,-x" ])
93+ def test_spherical_harmonics_rotation_properties (sh_degree : int , component_convention : Literal ["+y,+z,+x" , "-y,+z,-x" ]):
5694 """Test properties of the SH rotation"""
5795 torch .manual_seed (0 )
5896 np .random .seed (0 )
@@ -61,7 +99,7 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
6199 num_coeffs = (sh_degree + 1 ) ** 2
62100 sh_coeffs = torch .rand (N , 3 , num_coeffs )
63101 rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
64- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
102+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention )
65103
66104 # Norm preserving
67105 norm_original = torch .norm (sh_coeffs , dim = - 1 )
@@ -73,5 +111,5 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
73111
74112 # Identity rotation
75113 rot_matrix = torch .eye (3 )
76- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
114+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , ordering )
77115 torch .testing .assert_close (sh_coeffs , sh_coeffs_rotated , rtol = 0 , atol = 1e-6 )
0 commit comments