Skip to content

Commit 96ca8b0

Browse files
jb-yebell-oneJianbo Yebrentyi
authored
Update exporter.py to export sh_degree 0 case #3371 (#3374)
* Update exporter.py for sh_degree 0 Change to write sh coefficients instead of color values * Add flag for use_sh0_renderer Add sh0 renderer case for model.config.sh_degree == 0 * fix ruff * add warning if use_sh0_renderer is used when higher order of SH is available * fix rgb export for color-only training * use ply_color_mode * better handling ply_color_mode=='rgb' when sh_degree>0 * clean RGB2SH * fix issues * update description --------- Co-authored-by: bell-one <[email protected]> Co-authored-by: Jianbo Ye <[email protected]> Co-authored-by: Brent Yi <[email protected]>
1 parent 22dae34 commit 96ca8b0

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

nerfstudio/models/splatfacto.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,10 @@ def colors(self):
314314

315315
@property
316316
def shs_0(self):
317-
return self.features_dc
317+
if self.config.sh_degree > 0:
318+
return self.features_dc
319+
else:
320+
return RGB2SH(torch.sigmoid(self.features_dc))
318321

319322
@property
320323
def shs_rest(self):

nerfstudio/scripts/exporter.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,9 @@ class ExportGaussianSplat(Exporter):
485485
"""Rotation of the oriented bounding box. Expressed as RPY Euler angles in radians"""
486486
obb_scale: Optional[Tuple[float, float, float]] = None
487487
"""Scale of the oriented bounding box along each axis."""
488+
ply_color_mode: Literal["sh_coeffs", "rgb"] = "sh_coeffs"
489+
"""If "rgb", export colors as red/green/blue fields. Otherwise, export colors as
490+
spherical harmonics coefficients."""
488491

489492
@staticmethod
490493
def write_ply(
@@ -504,7 +507,7 @@ def write_ply(
504507
"""
505508

506509
# Ensure count matches the length of all tensors
507-
if not all(len(tensor) == count for tensor in map_to_tensors.values()):
510+
if not all(tensor.size == count for tensor in map_to_tensors.values()):
508511
raise ValueError("Count does not match the length of all tensors")
509512

510513
# Type check for numpy arrays of type float or uint8 and non-empty
@@ -552,7 +555,6 @@ def main(self) -> None:
552555

553556
filename = self.output_dir / "splat.ply"
554557

555-
count = 0
556558
map_to_tensors = OrderedDict()
557559

558560
with torch.no_grad():
@@ -566,19 +568,28 @@ def main(self) -> None:
566568
map_to_tensors["ny"] = np.zeros(n, dtype=np.float32)
567569
map_to_tensors["nz"] = np.zeros(n, dtype=np.float32)
568570

569-
if model.config.sh_degree > 0:
571+
if self.ply_color_mode == "rgb":
572+
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
573+
colors = (colors * 255).astype(np.uint8)
574+
map_to_tensors["red"] = colors[:, 0]
575+
map_to_tensors["green"] = colors[:, 1]
576+
map_to_tensors["blue"] = colors[:, 2]
577+
elif self.ply_color_mode == "sh_coeffs":
570578
shs_0 = model.shs_0.contiguous().cpu().numpy()
571579
for i in range(shs_0.shape[1]):
572580
map_to_tensors[f"f_dc_{i}"] = shs_0[:, i, None]
573581

574-
# transpose(1, 2) was needed to match the sh order in Inria version
575-
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
576-
shs_rest = shs_rest.reshape((n, -1))
577-
for i in range(shs_rest.shape[-1]):
578-
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]
579-
else:
580-
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
581-
map_to_tensors["colors"] = (colors * 255).astype(np.uint8)
582+
if model.config.sh_degree > 0:
583+
if self.ply_color_mode == "rgb":
584+
CONSOLE.print(
585+
"Warning: model has higher level of spherical harmonics, ignoring them and only export rgb."
586+
)
587+
elif self.ply_color_mode == "sh_coeffs":
588+
# transpose(1, 2) was needed to match the sh order in Inria version
589+
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
590+
shs_rest = shs_rest.reshape((n, -1))
591+
for i in range(shs_rest.shape[-1]):
592+
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]
582593

583594
map_to_tensors["opacity"] = model.opacities.data.cpu().numpy()
584595

0 commit comments

Comments
 (0)