Skip to content

Commit 69d65b5

Browse files
committed
Account for nerfacto vs splatfacto spherical harmonics differences
1 parent 43abb0f commit 69d65b5

File tree

4 files changed

+81
-12
lines changed

4 files changed

+81
-12
lines changed

nerfstudio/scripts/exporter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,11 @@ def main(self) -> None:
636636
dim_sh_all = shs_rest.shape[-1] + 1
637637
shs_coeffs_all = torch.zeros((n, 3, dim_sh_all), device=shs_rest.device)
638638
shs_coeffs_all[:, :, 1:] = shs_rest
639-
# TODO: check output rotation
640-
output_rotation = pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3]
641-
shs_rest = rotate_spherical_harmonics(output_rotation, shs_coeffs_all)[:, :, 1:]
639+
shs_rest = rotate_spherical_harmonics(
640+
pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3].T,
641+
shs_coeffs_all,
642+
component_convention="-y,+z,-x",
643+
)[:, :, 1:]
642644

643645
shs_rest = shs_rest.cpu().numpy().reshape((n, -1))
644646

nerfstudio/utils/spherical_harmonics.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
"""Sphecal Harmonics utils."""
1616

1717
import math
18+
from typing import Literal
1819

1920
import torch
2021
from e3nn.o3 import Irreps
2122
from jaxtyping import Float
2223
from torch import Tensor
24+
from typing_extensions import assert_never
2325

2426
MAX_SH_DEGREE = 4
2527

@@ -117,12 +119,16 @@ def SH2RGB(sh):
117119
def rotate_spherical_harmonics(
118120
rotation_matrix: Float[Tensor, "3 3"],
119121
coeffs: Float[Tensor, "*batch dim_sh"],
122+
component_convention: Literal["-y,+z,-x", "+y,+z,+x"],
120123
) -> Float[Tensor, "*batch dim_sh"]:
121124
"""Rotates real spherical harmonic coefficients using a given 3x3 rotation matrix.
122125
123126
Args:
124127
rotation_matrix : A 3x3 rotation matrix.
125128
coeffs : SH coefficients
129+
component_convention: Component convention for spherical harmonics.
130+
Nerfstudio (nerfacto) uses +y,+z,+x, while gsplat (splatfacto) uses
131+
-y,+z,-x.
126132
127133
Returns:
128134
The rotated SH coefficients
@@ -132,11 +138,31 @@ def rotate_spherical_harmonics(
132138
sh_degree = int(math.sqrt(dim_sh)) - 1
133139

134140
# e3nn uses the xyz ordering instead of the standard yzx used in ns, equivalent to a change of basis
135-
R_yzx_to_xyz = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float32)
136-
R_total = (R_yzx_to_xyz.T @ rotation_matrix @ R_yzx_to_xyz).cpu()
141+
if component_convention == "+y,+z,+x":
142+
R_xyz_from_yzx = torch.tensor(
143+
[
144+
[0, 0, 1],
145+
[1, 0, 0],
146+
[0, 1, 0],
147+
],
148+
dtype=torch.float32,
149+
)
150+
rotation_matrix = (R_xyz_from_yzx.T @ rotation_matrix @ R_xyz_from_yzx).cpu()
151+
elif component_convention == "-y,+z,-x":
152+
R_xyz_from_negyznegx = torch.tensor(
153+
[
154+
[0, 0, -1],
155+
[-1, 0, 0],
156+
[0, 1, 0],
157+
],
158+
dtype=torch.float32,
159+
)
160+
rotation_matrix = (R_xyz_from_negyznegx.T @ rotation_matrix @ R_xyz_from_negyznegx).cpu()
161+
else:
162+
assert_never(component_convention)
137163

138164
irreps = Irreps(" + ".join([f"{i}e" for i in range(sh_degree + 1)])) # Even parity spherical harmonics of degree l
139-
D_matrix = irreps.D_from_matrix(R_total).to(coeffs.device) # Construct Wigner D-matrix
165+
D_matrix = irreps.D_from_matrix(rotation_matrix).to(coeffs.device) # Construct Wigner D-matrix
140166

141167
# Multiply last dimension of coeffs (..., dim_sh) with the Wigner D-matrix (dim_sh, dim_sh)
142168
rotated_coeffs = coeffs @ D_matrix.T

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ include = ["nerfstudio*"]
147147
"*" = ["*.cu", "*.json", "py.typed", "setup.bash", "setup.zsh"]
148148

149149
[tool.pytest.ini_options]
150-
addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
150+
# addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
151151
testpaths = [
152152
"tests",
153153
]

tests/utils/test_spherical_harmonics.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
from typing import Literal
2+
13
import numpy as np
24
import pytest
35
import torch
6+
from gsplat.cuda._torch_impl import (
7+
_eval_sh_bases_fast as gsplat_eval_sh_bases,
8+
_spherical_harmonics as gsplat_spherical_harmonics,
9+
)
410
from scipy.spatial.transform import Rotation as ScR
511

612
from nerfstudio.utils.spherical_harmonics import (
@@ -23,7 +29,7 @@ def test_spherical_harmonics_components(degree):
2329

2430

2531
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
26-
def test_spherical_harmonics_rotation(sh_degree):
32+
def test_spherical_harmonics_rotation_nerfacto(sh_degree):
2733
"""Test if rotating both the view direction and SH coefficients by the same rotation
2834
produces the same color output as the original.
2935
@@ -43,7 +49,7 @@ def test_spherical_harmonics_rotation(sh_degree):
4349
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)
4450

4551
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
46-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
52+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="+y,+z,+x")
4753
dirs_rotated = (rot_matrix @ dirs.T).T
4854
y_lm_rotated = components_from_spherical_harmonics(sh_degree, dirs_rotated)
4955
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)
@@ -52,7 +58,42 @@ def test_spherical_harmonics_rotation(sh_degree):
5258

5359

5460
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
55-
def test_spherical_harmonics_rotation_properties(sh_degree):
61+
def test_spherical_harmonics_rotation_splatfacto(sh_degree):
62+
"""Test if rotating both the view direction and SH coefficients by the same rotation
63+
produces the same color output as the original.
64+
65+
In other words, for any rotation R:
66+
color(dir, coeffs) = color(R @ dir, rotate_sh(R, coeffs))
67+
"""
68+
torch.manual_seed(0)
69+
np.random.seed(0)
70+
71+
N = 1000
72+
num_coeffs = (sh_degree + 1) ** 2
73+
sh_coeffs = torch.rand(N, 3, num_coeffs)
74+
dirs = torch.rand(N, 3)
75+
dirs = dirs / torch.linalg.norm(dirs, dim=-1, keepdim=True)
76+
77+
assert dirs.shape == (N, 3)
78+
y_lm = gsplat_eval_sh_bases(num_coeffs, dirs)
79+
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)
80+
81+
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
82+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="-y,+z,-x")
83+
dirs_rotated = (rot_matrix @ dirs.T).T
84+
assert dirs_rotated.shape == (N, 3)
85+
y_lm_rotated = gsplat_eval_sh_bases(num_coeffs, dirs_rotated)
86+
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)
87+
88+
torch.testing.assert_close(
89+
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs.swapaxes(-1, -2), dirs=dirs),
90+
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs_rotated.swapaxes(-1, -2), dirs=dirs_rotated),
91+
)
92+
93+
94+
@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
95+
@pytest.mark.parametrize("component_convention", ["+y,+z,+x", "-y,+z,-x"])
96+
def test_spherical_harmonics_rotation_properties(sh_degree: int, component_convention: Literal["+y,+z,+x", "-y,+z,-x"]):
5697
"""Test properties of the SH rotation"""
5798
torch.manual_seed(0)
5899
np.random.seed(0)
@@ -61,7 +102,7 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
61102
num_coeffs = (sh_degree + 1) ** 2
62103
sh_coeffs = torch.rand(N, 3, num_coeffs)
63104
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
64-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
105+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention)
65106

66107
# Norm preserving
67108
norm_original = torch.norm(sh_coeffs, dim=-1)
@@ -73,5 +114,5 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
73114

74115
# Identity rotation
75116
rot_matrix = torch.eye(3)
76-
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
117+
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, ordering)
77118
torch.testing.assert_close(sh_coeffs, sh_coeffs_rotated, rtol=0, atol=1e-6)

0 commit comments

Comments
 (0)