Skip to content

Commit a8888e7

Browse files
Move util functions out of splatfacto (#3538)
* Move util functions out of splatfacto
  Nothing else currently uses some of the SH utils, but it might make sense to get them out of splatfacto. I also moved the k-nearest-neighbors helper to utils since it doesn't depend on the model class.
* fix assert
* fix sh test
* convert to using degrees, not levels
1 parent e8bf472 commit a8888e7

File tree

8 files changed

+207
-181
lines changed

8 files changed

+207
-181
lines changed

nerfstudio/field_components/encodings.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,9 @@
2828

2929
from nerfstudio.field_components.base_field_component import FieldComponent
3030
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
31-
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
31+
from nerfstudio.utils.math import expected_sin, generate_polyhedron_basis
3232
from nerfstudio.utils.printing import print_tcnn_speed_warning
33+
from nerfstudio.utils.spherical_harmonics import MAX_SH_DEGREE, components_from_spherical_harmonics
3334

3435

3536
class Encoding(FieldComponent):
@@ -756,14 +757,16 @@ class SHEncoding(Encoding):
756757
"""Spherical harmonic encoding
757758
758759
Args:
759-
levels: Number of spherical harmonic levels to encode.
760+
levels: Number of spherical harmonic levels to encode. (level = sh degree + 1)
760761
"""
761762

762763
def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
763764
super().__init__(in_dim=3)
764765

765-
if levels <= 0 or levels > 4:
766-
raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")
766+
if levels <= 0 or levels > MAX_SH_DEGREE + 1:
767+
raise ValueError(
768+
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}"
769+
)
767770

768771
self.levels = levels
769772

@@ -778,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
778781
)
779782

780783
@classmethod
781-
def get_tcnn_encoding_config(cls, levels) -> dict:
784+
def get_tcnn_encoding_config(cls, levels: int) -> dict:
782785
"""Get the encoding configuration for tcnn if implemented"""
783786
encoding_config = {
784787
"otype": "SphericalHarmonics",
@@ -792,7 +795,7 @@ def get_out_dim(self) -> int:
792795
@torch.no_grad()
793796
def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
794797
"""Forward pass using pytorch. Significantly slower than TCNN implementation."""
795-
return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)
798+
return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor)
796799

797800
def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
798801
if self.tcnn_encoding is not None:

nerfstudio/model_components/renderers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838

3939
from nerfstudio.cameras.rays import RaySamples
4040
from nerfstudio.utils import colors
41-
from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize
41+
from nerfstudio.utils.math import safe_normalize
42+
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics
4243

4344
BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]]
4445
BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None
@@ -268,7 +269,7 @@ def forward(
268269
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)
269270

270271
levels = int(math.sqrt(sh.shape[-1]))
271-
components = components_from_spherical_harmonics(levels=levels, directions=directions)
272+
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)
272273

273274
rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
274275
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]

nerfstudio/models/splatfacto.py

Lines changed: 3 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919

2020
from __future__ import annotations
2121

22-
import math
2322
from dataclasses import dataclass, field
2423
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
2524

26-
import numpy as np
2725
import torch
2826
from gsplat.strategy import DefaultStrategy
2927

@@ -42,70 +40,10 @@
4240
from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
4341
from nerfstudio.models.base_model import Model, ModelConfig
4442
from nerfstudio.utils.colors import get_color
43+
from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor
4544
from nerfstudio.utils.misc import torch_compile
4645
from nerfstudio.utils.rich_utils import CONSOLE
47-
48-
49-
def num_sh_bases(degree: int) -> int:
50-
"""
51-
Returns the number of spherical harmonic bases for a given degree.
52-
"""
53-
assert degree <= 4, "We don't support degree greater than 4."
54-
return (degree + 1) ** 2
55-
56-
57-
def quat_to_rotmat(quat):
58-
assert quat.shape[-1] == 4, quat.shape
59-
w, x, y, z = torch.unbind(quat, dim=-1)
60-
mat = torch.stack(
61-
[
62-
1 - 2 * (y**2 + z**2),
63-
2 * (x * y - w * z),
64-
2 * (x * z + w * y),
65-
2 * (x * y + w * z),
66-
1 - 2 * (x**2 + z**2),
67-
2 * (y * z - w * x),
68-
2 * (x * z - w * y),
69-
2 * (y * z + w * x),
70-
1 - 2 * (x**2 + y**2),
71-
],
72-
dim=-1,
73-
)
74-
return mat.reshape(quat.shape[:-1] + (3, 3))
75-
76-
77-
def random_quat_tensor(N):
78-
"""
79-
Defines a random quaternion tensor of shape (N, 4)
80-
"""
81-
u = torch.rand(N)
82-
v = torch.rand(N)
83-
w = torch.rand(N)
84-
return torch.stack(
85-
[
86-
torch.sqrt(1 - u) * torch.sin(2 * math.pi * v),
87-
torch.sqrt(1 - u) * torch.cos(2 * math.pi * v),
88-
torch.sqrt(u) * torch.sin(2 * math.pi * w),
89-
torch.sqrt(u) * torch.cos(2 * math.pi * w),
90-
],
91-
dim=-1,
92-
)
93-
94-
95-
def RGB2SH(rgb):
96-
"""
97-
Converts from RGB values [0,1] to the 0th spherical harmonic coefficient
98-
"""
99-
C0 = 0.28209479177387814
100-
return (rgb - 0.5) / C0
101-
102-
103-
def SH2RGB(sh):
104-
"""
105-
Converts from the 0th spherical harmonic coefficient to RGB values [0,1]
106-
"""
107-
C0 = 0.28209479177387814
108-
return sh * C0 + 0.5
46+
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases
10947

11048

11149
def resize_image(image: torch.Tensor, d: int):
@@ -243,8 +181,7 @@ def populate_modules(self):
243181
means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
244182
else:
245183
means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
246-
distances, _ = self.k_nearest_sklearn(means.data, 3)
247-
distances = torch.from_numpy(distances)
184+
distances, _ = k_nearest_sklearn(means.data, 3)
248185
# find the average of the three nearest neighbors for each point and use that as the scale
249186
avg_dist = distances.mean(dim=-1, keepdim=True)
250187
scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
@@ -392,26 +329,6 @@ def load_state_dict(self, dict, **kwargs): # type: ignore
392329
self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device))
393330
super().load_state_dict(dict, **kwargs)
394331

395-
def k_nearest_sklearn(self, x: torch.Tensor, k: int):
396-
"""
397-
Find k-nearest neighbors using sklearn's NearestNeighbors.
398-
x: The data tensor of shape [num_samples, num_features]
399-
k: The number of neighbors to retrieve
400-
"""
401-
# Convert tensor to numpy array
402-
x_np = x.cpu().numpy()
403-
404-
# Build the nearest neighbors model
405-
from sklearn.neighbors import NearestNeighbors
406-
407-
nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)
408-
409-
# Find the k-nearest neighbors
410-
distances, indices = nn_model.kneighbors(x_np)
411-
412-
# Exclude the point itself from the result and return
413-
return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)
414-
415332
def set_crop(self, crop_box: Optional[OrientedBox]):
416333
self.crop_box = crop_box
417334

nerfstudio/utils/math.py

Lines changed: 63 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,78 +20,12 @@
2020
from typing import Literal, Tuple
2121

2222
import torch
23-
from jaxtyping import Bool, Float
23+
from jaxtyping import Bool, Float, Int
2424
from torch import Tensor
2525

2626
from nerfstudio.data.scene_box import OrientedBox
2727

2828

29-
def components_from_spherical_harmonics(
30-
levels: int, directions: Float[Tensor, "*batch 3"]
31-
) -> Float[Tensor, "*batch components"]:
32-
"""
33-
Returns value for each component of spherical harmonics.
34-
35-
Args:
36-
levels: Number of spherical harmonic levels to compute.
37-
directions: Spherical harmonic coefficients
38-
"""
39-
num_components = levels**2
40-
components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)
41-
42-
assert 1 <= levels <= 5, f"SH levels must be in [1,4], got {levels}"
43-
assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"
44-
45-
x = directions[..., 0]
46-
y = directions[..., 1]
47-
z = directions[..., 2]
48-
49-
xx = x**2
50-
yy = y**2
51-
zz = z**2
52-
53-
# l0
54-
components[..., 0] = 0.28209479177387814
55-
56-
# l1
57-
if levels > 1:
58-
components[..., 1] = 0.4886025119029199 * y
59-
components[..., 2] = 0.4886025119029199 * z
60-
components[..., 3] = 0.4886025119029199 * x
61-
62-
# l2
63-
if levels > 2:
64-
components[..., 4] = 1.0925484305920792 * x * y
65-
components[..., 5] = 1.0925484305920792 * y * z
66-
components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
67-
components[..., 7] = 1.0925484305920792 * x * z
68-
components[..., 8] = 0.5462742152960396 * (xx - yy)
69-
70-
# l3
71-
if levels > 3:
72-
components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
73-
components[..., 10] = 2.890611442640554 * x * y * z
74-
components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
75-
components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3)
76-
components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1)
77-
components[..., 14] = 1.445305721320277 * z * (xx - yy)
78-
components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)
79-
80-
# l4
81-
if levels > 4:
82-
components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
83-
components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
84-
components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
85-
components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3)
86-
components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3)
87-
components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3)
88-
components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1)
89-
components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy)
90-
components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
91-
92-
return components
93-
94-
9529
@dataclass
9630
class Gaussians:
9731
"""Stores Gaussians
@@ -323,7 +257,9 @@ def masked_reduction(
323257

324258

325259
def normalized_depth_scale_and_shift(
326-
prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"]
260+
prediction: Float[Tensor, "1 32 mult"],
261+
target: Float[Tensor, "1 32 mult"],
262+
mask: Bool[Tensor, "1 32 mult"],
327263
):
328264
"""
329265
More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss
@@ -405,7 +341,10 @@ def _compute_tesselation_weights(v: int) -> Tensor:
405341

406342

407343
def _tesselate_geodesic(
408-
vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4
344+
vertices: Float[Tensor, "N 3"],
345+
faces: Float[Tensor, "M 3"],
346+
v: int,
347+
eps: float = 1e-4,
409348
) -> Tensor:
410349
"""Tesselate the vertices of a geodesic polyhedron.
411350
@@ -518,3 +457,58 @@ def generate_polyhedron_basis(
518457

519458
basis = verts.flip(-1)
520459
return basis
460+
461+
462+
def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]:
463+
"""
464+
Defines a random quaternion tensor.
465+
466+
Args:
467+
N: Number of quaternions to generate
468+
469+
Returns:
470+
a random quaternion tensor of shape (N, 4)
471+
472+
"""
473+
u = torch.rand(N)
474+
v = torch.rand(N)
475+
w = torch.rand(N)
476+
return torch.stack(
477+
[
478+
torch.sqrt(1 - u) * torch.sin(2 * math.pi * v),
479+
torch.sqrt(1 - u) * torch.cos(2 * math.pi * v),
480+
torch.sqrt(u) * torch.sin(2 * math.pi * w),
481+
torch.sqrt(u) * torch.cos(2 * math.pi * w),
482+
],
483+
dim=-1,
484+
)
485+
486+
487+
def k_nearest_sklearn(
488+
x: torch.Tensor, k: int, metric: str = "euclidean"
489+
) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]:
490+
"""
491+
Find k-nearest neighbors using sklearn's NearestNeighbors.
492+
493+
Args:
494+
x: input tensor
495+
k: number of neighbors to find
496+
metric: metric to use for distance computation
497+
498+
Returns:
499+
distances: distances to the k-nearest neighbors
500+
indices: indices of the k-nearest neighbors
501+
"""
502+
# Convert tensor to numpy array
503+
x_np = x.cpu().numpy()
504+
505+
# Build the nearest neighbors model
506+
from sklearn.neighbors import NearestNeighbors
507+
508+
nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np)
509+
510+
# Find the k-nearest neighbors
511+
distances, indices = nn_model.kneighbors(x_np)
512+
513+
# Exclude the point itself from the result and return
514+
return torch.tensor(distances[:, 1:], dtype=torch.float32), torch.tensor(indices[:, 1:], dtype=torch.int64)

0 commit comments

Comments
 (0)