1+ from typing import Literal
2+
13import numpy as np
24import pytest
35import torch
6+ from gsplat .cuda ._torch_impl import (
7+ _eval_sh_bases_fast as gsplat_eval_sh_bases ,
8+ _spherical_harmonics as gsplat_spherical_harmonics ,
9+ )
410from scipy .spatial .transform import Rotation as ScR
511
612from nerfstudio .utils .spherical_harmonics import (
@@ -23,7 +29,7 @@ def test_spherical_harmonics_components(degree):
2329
2430
2531@pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
26- def test_spherical_harmonics_rotation (sh_degree ):
32+ def test_spherical_harmonics_rotation_nerfacto (sh_degree ):
2733 """Test if rotating both the view direction and SH coefficients by the same rotation
2834 produces the same color output as the original.
2935
@@ -43,7 +49,7 @@ def test_spherical_harmonics_rotation(sh_degree):
4349 color_original = (sh_coeffs * y_lm [..., None , :]).sum (dim = - 1 )
4450
4551 rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
46- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
52+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention = "+y,+z,+x" )
4753 dirs_rotated = (rot_matrix @ dirs .T ).T
4854 y_lm_rotated = components_from_spherical_harmonics (sh_degree , dirs_rotated )
4955 color_rotated = (sh_coeffs_rotated * y_lm_rotated [..., None , :]).sum (dim = - 1 )
@@ -52,7 +58,42 @@ def test_spherical_harmonics_rotation(sh_degree):
5258
5359
5460@pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
55- def test_spherical_harmonics_rotation_properties (sh_degree ):
61+ def test_spherical_harmonics_rotation_splatfacto (sh_degree ):
62+ """Test if rotating both the view direction and SH coefficients by the same rotation
63+ produces the same color output as the original.
64+
65+ In other words, for any rotation R:
66+ color(dir, coeffs) = color(R @ dir, rotate_sh(R, coeffs))
67+ """
68+ torch .manual_seed (0 )
69+ np .random .seed (0 )
70+
71+ N = 1000
72+ num_coeffs = (sh_degree + 1 ) ** 2
73+ sh_coeffs = torch .rand (N , 3 , num_coeffs )
74+ dirs = torch .rand (N , 3 )
75+ dirs = dirs / torch .linalg .norm (dirs , dim = - 1 , keepdim = True )
76+
77+ assert dirs .shape == (N , 3 )
78+ y_lm = gsplat_eval_sh_bases (num_coeffs , dirs )
79+ color_original = (sh_coeffs * y_lm [..., None , :]).sum (dim = - 1 )
80+
81+ rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
82+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention = "-y,+z,-x" )
83+ dirs_rotated = (rot_matrix @ dirs .T ).T
84+ assert dirs_rotated .shape == (N , 3 )
85+ y_lm_rotated = gsplat_eval_sh_bases (num_coeffs , dirs_rotated )
86+ color_rotated = (sh_coeffs_rotated * y_lm_rotated [..., None , :]).sum (dim = - 1 )
87+
88+ torch .testing .assert_close (
89+ gsplat_spherical_harmonics (sh_degree , coeffs = sh_coeffs .swapaxes (- 1 , - 2 ), dirs = dirs ),
90+ gsplat_spherical_harmonics (sh_degree , coeffs = sh_coeffs_rotated .swapaxes (- 1 , - 2 ), dirs = dirs_rotated ),
91+ )
92+
93+
94+ @pytest .mark .parametrize ("sh_degree" , list (range (0 , 4 )))
95+ @pytest .mark .parametrize ("component_convention" , ["+y,+z,+x" , "-y,+z,-x" ])
96+ def test_spherical_harmonics_rotation_properties (sh_degree : int , component_convention : Literal ["+y,+z,+x" , "-y,+z,-x" ]):
5697 """Test properties of the SH rotation"""
5798 torch .manual_seed (0 )
5899 np .random .seed (0 )
@@ -61,7 +102,7 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
61102 num_coeffs = (sh_degree + 1 ) ** 2
62103 sh_coeffs = torch .rand (N , 3 , num_coeffs )
63104 rot_matrix = torch .tensor (ScR .random ().as_matrix (), dtype = torch .float32 )
64- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
105+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , component_convention )
65106
66107 # Norm preserving
67108 norm_original = torch .norm (sh_coeffs , dim = - 1 )
@@ -73,5 +114,5 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
73114
74115 # Identity rotation
75116 rot_matrix = torch .eye (3 )
76- sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs )
117+ sh_coeffs_rotated = rotate_spherical_harmonics (rot_matrix , sh_coeffs , ordering )
77118 torch .testing .assert_close (sh_coeffs , sh_coeffs_rotated , rtol = 0 , atol = 1e-6 )
0 commit comments