Skip to content

Commit b44cc41

Browse files
convert to using degrees, not levels
1 parent 872668d commit b44cc41

File tree

5 files changed

+21
-21
lines changed

5 files changed

+21
-21
lines changed

nerfstudio/field_components/encodings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -757,15 +757,15 @@ class SHEncoding(Encoding):
757757
"""Spherical harmonic encoding
758758
759759
Args:
760-
levels: Number of spherical harmonic levels to encode.
760+
levels: Number of spherical harmonic levels to encode. (level = sh degree + 1)
761761
"""
762762

763763
def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
764764
super().__init__(in_dim=3)
765765

766-
if levels <= 0 or levels > MAX_SH_DEGREE:
766+
if levels <= 0 or levels > MAX_SH_DEGREE + 1:
767767
raise ValueError(
768-
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE} levels, requested {levels}"
768+
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}"
769769
)
770770

771771
self.levels = levels
@@ -781,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
781781
)
782782

783783
@classmethod
784-
def get_tcnn_encoding_config(cls, levels) -> dict:
784+
def get_tcnn_encoding_config(cls, levels: int) -> dict:
785785
"""Get the encoding configuration for tcnn if implemented"""
786786
encoding_config = {
787787
"otype": "SphericalHarmonics",
@@ -795,7 +795,7 @@ def get_out_dim(self) -> int:
795795
@torch.no_grad()
796796
def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
797797
"""Forward pass using pytorch. Significantly slower than TCNN implementation."""
798-
return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)
798+
return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor)
799799

800800
def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
801801
if self.tcnn_encoding is not None:

nerfstudio/model_components/renderers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def forward(
269269
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)
270270

271271
levels = int(math.sqrt(sh.shape[-1]))
272-
components = components_from_spherical_harmonics(levels=levels, directions=directions)
272+
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)
273273

274274
rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
275275
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]

nerfstudio/utils/spherical_harmonics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@
2222

2323

2424
def components_from_spherical_harmonics(
25-
levels: int, directions: Float[Tensor, "*batch 3"]
25+
degree: int, directions: Float[Tensor, "*batch 3"]
2626
) -> Float[Tensor, "*batch components"]:
2727
"""
2828
Returns value for each component of spherical harmonics.
2929
3030
Args:
31-
levels: Number of spherical harmonic levels to compute.
31+
degree: Number of spherical harmonic degrees to compute.
3232
directions: Spherical harmonic coefficients
3333
"""
34-
num_components = levels**2
34+
num_components = num_sh_bases(degree)
3535
components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)
3636

37-
assert 1 <= levels <= MAX_SH_DEGREE, f"SH levels must be in [1,{MAX_SH_DEGREE}], got {levels}"
37+
assert 0 <= degree <= MAX_SH_DEGREE, f"SH degree must be in [0, {MAX_SH_DEGREE}], got {degree}"
3838
assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"
3939

4040
x = directions[..., 0]
@@ -49,21 +49,21 @@ def components_from_spherical_harmonics(
4949
components[..., 0] = 0.28209479177387814
5050

5151
# l1
52-
if levels > 1:
52+
if degree > 0:
5353
components[..., 1] = 0.4886025119029199 * y
5454
components[..., 2] = 0.4886025119029199 * z
5555
components[..., 3] = 0.4886025119029199 * x
5656

5757
# l2
58-
if levels > 2:
58+
if degree > 1:
5959
components[..., 4] = 1.0925484305920792 * x * y
6060
components[..., 5] = 1.0925484305920792 * y * z
6161
components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
6262
components[..., 7] = 1.0925484305920792 * x * z
6363
components[..., 8] = 0.5462742152960396 * (xx - yy)
6464

6565
# l3
66-
if levels > 3:
66+
if degree > 2:
6767
components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
6868
components[..., 10] = 2.890611442640554 * x * y * z
6969
components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
@@ -73,7 +73,7 @@ def components_from_spherical_harmonics(
7373
components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)
7474

7575
# l4
76-
if levels > 4:
76+
if degree > 3:
7777
components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
7878
components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
7979
components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)

tests/field_components/test_encodings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ def test_tensor_cp_encoder():
125125
def test_tensor_sh_encoder():
126126
"""Test Spherical Harmonic encoder"""
127127

128-
levels = 4
128+
levels = 5
129129
out_dim = levels**2
130130

131131
with pytest.raises(ValueError):
132-
encoder = encodings.SHEncoding(levels=5)
132+
encoder = encodings.SHEncoding(levels=6)
133133

134134
encoder = encodings.SHEncoding(levels=levels)
135135
assert encoder.get_out_dim() == out_dim
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import pytest
22
import torch
33

4-
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics
4+
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics, num_sh_bases
55

66

7-
@pytest.mark.parametrize("components", list(range(1, 5)))
8-
def test_spherical_harmonics(components):
7+
@pytest.mark.parametrize("degree", list(range(0, 5)))
8+
def test_spherical_harmonics(degree):
99
torch.manual_seed(0)
1010
N = 1000000
1111

1212
dx = torch.normal(0, 1, size=(N, 3))
1313
dx = dx / torch.linalg.norm(dx, dim=-1, keepdim=True)
14-
sh = components_from_spherical_harmonics(components, dx)
14+
sh = components_from_spherical_harmonics(degree, dx)
1515
matrix = (sh.T @ sh) / N * 4 * torch.pi
16-
torch.testing.assert_close(matrix, torch.eye(components**2), rtol=0, atol=1.5e-2)
16+
torch.testing.assert_close(matrix, torch.eye(num_sh_bases(degree)), rtol=0, atol=1.5e-2)

0 commit comments

Comments
 (0)