Skip to content

Commit a739636

Browse files
committed
Account for nerfacto vs splatfacto spherical harmonics differences
1 parent 43abb0f commit a739636

File tree

4 files changed

+82
-17
lines changed

4 files changed

+82
-17
lines changed

nerfstudio/scripts/exporter.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@
3232
import open3d as o3d
3333
import torch
3434
import tyro
35-
from scipy.spatial.transform import Rotation as ScR
36-
from typing_extensions import Annotated, Literal
37-
3835
from nerfstudio.cameras.rays import RayBundle
3936
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
4037
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
@@ -50,6 +47,8 @@
5047
from nerfstudio.utils.eval_utils import eval_setup
5148
from nerfstudio.utils.rich_utils import CONSOLE
5249
from nerfstudio.utils.spherical_harmonics import rotate_spherical_harmonics
50+
from scipy.spatial.transform import Rotation as ScR
51+
from typing_extensions import Annotated, Literal
5352

5453

5554
@dataclass
@@ -636,9 +635,11 @@ def main(self) -> None:
636635
dim_sh_all = shs_rest.shape[-1] + 1
637636
shs_coeffs_all = torch.zeros((n, 3, dim_sh_all), device=shs_rest.device)
638637
shs_coeffs_all[:, :, 1:] = shs_rest
639-
# TODO: check output rotation
640-
output_rotation = pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3]
641-
shs_rest = rotate_spherical_harmonics(output_rotation, shs_coeffs_all)[:, :, 1:]
638+
shs_rest = rotate_spherical_harmonics(
639+
pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3].T,
640+
shs_coeffs_all,
641+
component_convention="-y,+z,-x",
642+
)[:, :, 1:]
642643

643644
shs_rest = shs_rest.cpu().numpy().reshape((n, -1))
644645

nerfstudio/utils/spherical_harmonics.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
"""Sphecal Harmonics utils."""
1616

1717
import math
18+
from typing import Literal
1819

1920
import torch
2021
from e3nn.o3 import Irreps
2122
from jaxtyping import Float
2223
from torch import Tensor
24+
from typing_extensions import assert_never
2325

2426
MAX_SH_DEGREE = 4
2527

@@ -117,12 +119,16 @@ def SH2RGB(sh):
117119
def rotate_spherical_harmonics(
118120
rotation_matrix: Float[Tensor, "3 3"],
119121
coeffs: Float[Tensor, "*batch dim_sh"],
122+
component_convention: Literal["-y,+z,-x", "+y,+z,+x"],
120123
) -> Float[Tensor, "*batch dim_sh"]:
121124
"""Rotates real spherical harmonic coefficients using a given 3x3 rotation matrix.
122125
123126
Args:
124127
rotation_matrix : A 3x3 rotation matrix.
125128
coeffs : SH coefficients
129+
component_convention: Component convention for spherical harmonics.
130+
Nerfstudio (nerfacto) uses +y,+z,+x, while gsplat (splatfacto) uses
131+
-y,+z,-x.
126132
127133
Returns:
128134
The rotated SH coefficients
@@ -132,11 +138,31 @@ def rotate_spherical_harmonics(
132138
sh_degree = int(math.sqrt(dim_sh)) - 1
133139

134140
# e3nn uses the xyz ordering instead of the standard yzx used in ns, equivalent to a change of basis
135-
R_yzx_to_xyz = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float32)
136-
R_total = (R_yzx_to_xyz.T @ rotation_matrix @ R_yzx_to_xyz).cpu()
141+
if component_convention == "+y,+z,+x":
142+
R_xyz_from_yzx = torch.tensor(
143+
[
144+
[0, 0, 1],
145+
[1, 0, 0],
146+
[0, 1, 0],
147+
],
148+
dtype=torch.float32,
149+
)
150+
rotation_matrix = (R_xyz_from_yzx.T @ rotation_matrix @ R_xyz_from_yzx).cpu()
151+
elif component_convention == "-y,+z,-x":
152+
R_xyz_from_negyznegx = torch.tensor(
153+
[
154+
[0, 0, -1],
155+
[-1, 0, 0],
156+
[0, 1, 0],
157+
],
158+
dtype=torch.float32,
159+
)
160+
rotation_matrix = (R_xyz_from_negyznegx.T @ rotation_matrix @ R_xyz_from_negyznegx).cpu()
161+
else:
162+
assert_never(component_convention)
137163

138164
irreps = Irreps(" + ".join([f"{i}e" for i in range(sh_degree + 1)])) # Even parity spherical harmonics of degree l
139-
D_matrix = irreps.D_from_matrix(R_total).to(coeffs.device) # Construct Wigner D-matrix
165+
D_matrix = irreps.D_from_matrix(rotation_matrix).to(coeffs.device) # Construct Wigner D-matrix
140166

141167
# Multiply last dimension of coeffs (..., dim_sh) with the Wigner D-matrix (dim_sh, dim_sh)
142168
rotated_coeffs = coeffs @ D_matrix.T

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ include = ["nerfstudio*"]
147147
"*" = ["*.cu", "*.json", "py.typed", "setup.bash", "setup.zsh"]
148148

149149
[tool.pytest.ini_options]
150-
addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
150+
# addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
151151
testpaths = [
152152
"tests",
153153
]

tests/utils/test_spherical_harmonics.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from typing import Literal
2+
13
import numpy as np
24
import pytest
35
import torch
4-
from scipy.spatial.transform import Rotation as ScR
5-
6+
from gsplat.cuda._torch_impl import _eval_sh_bases_fast as gsplat_eval_sh_bases
7+
from gsplat.cuda._torch_impl import _spherical_harmonics as gsplat_spherical_harmonics
68
from nerfstudio.utils.spherical_harmonics import (
79
components_from_spherical_harmonics,
810
num_sh_bases,
911
rotate_spherical_harmonics,
1012
)
13+
from scipy.spatial.transform import Rotation as ScR
1114

1215

1316
@pytest.mark.parametrize("degree", list(range(0, 5)))
@@ -23,7 +26,7 @@ def test_spherical_harmonics_components(degree):
2326

2427

2528
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
26-
def test_spherical_harmonics_rotation(sh_degree):
29+
def test_spherical_harmonics_rotation_nerfacto(sh_degree):
2730
"""Test if rotating both the view direction and SH coefficients by the same rotation
2831
produces the same color output as the original.
2932
@@ -43,7 +46,7 @@ def test_spherical_harmonics_rotation(sh_degree):
4346
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)
4447

4548
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
46-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
49+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="+y,+z,+x")
4750
dirs_rotated = (rot_matrix @ dirs.T).T
4851
y_lm_rotated = components_from_spherical_harmonics(sh_degree, dirs_rotated)
4952
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)
@@ -52,7 +55,42 @@ def test_spherical_harmonics_rotation(sh_degree):
5255

5356

5457
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
55-
def test_spherical_harmonics_rotation_properties(sh_degree):
58+
def test_spherical_harmonics_rotation_splatfacto(sh_degree):
59+
"""Test if rotating both the view direction and SH coefficients by the same rotation
60+
produces the same color output as the original.
61+
62+
In other words, for any rotation R:
63+
color(dir, coeffs) = color(R @ dir, rotate_sh(R, coeffs))
64+
"""
65+
torch.manual_seed(0)
66+
np.random.seed(0)
67+
68+
N = 1000
69+
num_coeffs = (sh_degree + 1) ** 2
70+
sh_coeffs = torch.rand(N, 3, num_coeffs)
71+
dirs = torch.rand(N, 3)
72+
dirs = dirs / torch.linalg.norm(dirs, dim=-1, keepdim=True)
73+
74+
assert dirs.shape == (N, 3)
75+
y_lm = gsplat_eval_sh_bases(num_coeffs, dirs)
76+
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)
77+
78+
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
79+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="-y,+z,-x")
80+
dirs_rotated = (rot_matrix @ dirs.T).T
81+
assert dirs_rotated.shape == (N, 3)
82+
y_lm_rotated = gsplat_eval_sh_bases(num_coeffs, dirs_rotated)
83+
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)
84+
85+
torch.testing.assert_close(
86+
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs.swapaxes(-1, -2), dirs=dirs),
87+
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs_rotated.swapaxes(-1, -2), dirs=dirs_rotated),
88+
)
89+
90+
91+
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
92+
@pytest.mark.parametrize("component_convention", ["+y,+z,+x", "-y,+z,-x"])
93+
def test_spherical_harmonics_rotation_properties(sh_degree: int, component_convention: Literal["+y,+z,+x", "-y,+z,-x"]):
5694
"""Test properties of the SH rotation"""
5795
torch.manual_seed(0)
5896
np.random.seed(0)
@@ -61,7 +99,7 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
6199
num_coeffs = (sh_degree + 1) ** 2
62100
sh_coeffs = torch.rand(N, 3, num_coeffs)
63101
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
64-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
102+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention)
65103

66104
# Norm preserving
67105
norm_original = torch.norm(sh_coeffs, dim=-1)
@@ -73,5 +111,5 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
73111

74112
# Identity rotation
75113
rot_matrix = torch.eye(3)
76-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
114+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, ordering)
77115
torch.testing.assert_close(sh_coeffs, sh_coeffs_rotated, rtol=0, atol=1e-6)

0 commit comments

Comments
 (0)