Skip to content

Commit 9445c0c

Browse files
committed
Merge branch 'pr-24'
Merge PR apple#24: improve gaussian utilities and CLI prediction � :wq
2 parents 1eaa046 + 71af9eb commit 9445c0c

File tree

2 files changed

+255
-1
lines changed

2 files changed

+255
-1
lines changed

src/sharp/cli/predict.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
Gaussians3D,
2727
SceneMetaData,
2828
save_ply,
29+
save_splat,
30+
save_sog,
2931
unproject_gaussians,
3032
)
3133

@@ -66,6 +68,15 @@
6668
default=False,
6769
help="Whether to render trajectory for checkpoint.",
6870
)
71+
@click.option(
72+
"-f",
73+
"--format",
74+
"export_formats",
75+
type=click.Choice(["ply", "splat", "sog"], case_sensitive=False),
76+
multiple=True,
77+
default=["ply"],
78+
help="Output format(s). Can specify multiple: -f ply -f splat -f sog",
79+
)
6980
@click.option(
7081
"--device",
7182
type=str,
@@ -78,12 +89,16 @@ def predict_cli(
7889
output_path: Path,
7990
checkpoint_path: Path,
8091
with_rendering: bool,
92+
export_formats: tuple[str, ...],
8193
device: str,
8294
verbose: bool,
8395
):
8496
"""Predict Gaussians from input images."""
8597
logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
8698

99+
# Normalize export formats to lowercase
100+
export_formats = tuple(fmt.lower() for fmt in export_formats)
101+
87102
extensions = io.get_supported_image_extensions()
88103

89104
image_paths = []
@@ -145,7 +160,16 @@ def predict_cli(
145160
gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))
146161

147162
LOGGER.info("Saving 3DGS to %s", output_path)
148-
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
163+
for fmt in export_formats:
164+
if fmt == "ply":
165+
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
166+
LOGGER.info("Saved PLY: %s", output_path / f"{image_path.stem}.ply")
167+
elif fmt == "splat":
168+
save_splat(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.splat")
169+
LOGGER.info("Saved SPLAT: %s", output_path / f"{image_path.stem}.splat")
170+
elif fmt == "sog":
171+
save_sog(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.sog")
172+
LOGGER.info("Saved SOG: %s", output_path / f"{image_path.stem}.sog")
149173

150174
if with_rendering:
151175
output_video_path = (output_path / image_path.stem).with_suffix(".mp4")

src/sharp/utils/gaussians.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,233 @@ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
481481

482482
plydata.write(path)
483483
return plydata
484+
485+
486+
@torch.no_grad()
487+
def save_splat(
488+
gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
489+
) -> None:
490+
"""Save Gaussians to .splat format (compact binary format for web viewers).
491+
492+
The .splat format is a simple binary format used by web-based 3DGS viewers.
493+
Each Gaussian is stored as 32 bytes:
494+
- 12 bytes: xyz position (3 x float32)
495+
- 12 bytes: scales (3 x float32)
496+
- 4 bytes: RGBA color (4 x uint8)
497+
- 4 bytes: quaternion rotation (4 x uint8, encoded as (q * 128 + 128))
498+
499+
Gaussians are sorted by size * opacity (descending) for progressive rendering.
500+
"""
501+
xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
502+
scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
503+
quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
504+
colors_rgb = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
505+
opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()
506+
507+
# Sort by size * opacity (descending) for progressive rendering
508+
sort_idx = np.argsort(-(scales.prod(axis=1) * opacities))
509+
510+
# Normalize quaternions
511+
quats = quats / np.linalg.norm(quats, axis=1, keepdims=True)
512+
513+
with open(path, "wb") as f:
514+
for i in sort_idx:
515+
f.write(xyz[i].astype(np.float32).tobytes())
516+
f.write(scales[i].astype(np.float32).tobytes())
517+
rgba = np.concatenate([colors_rgb[i], [opacities[i]]])
518+
f.write((rgba * 255).clip(0, 255).astype(np.uint8).tobytes())
519+
f.write((quats[i] * 128 + 128).clip(0, 255).astype(np.uint8).tobytes())
520+
521+
522+
@torch.no_grad()
523+
def save_sog(
524+
gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
525+
) -> None:
526+
"""Save Gaussians to SOG format (Spatially Ordered Gaussians).
527+
528+
SOG is a highly compressed format using quantization and WebP images.
529+
Typically 15-20x smaller than PLY. The format stores data in a ZIP archive
530+
containing WebP images for positions, rotations, scales, and colors.
531+
532+
Reference: https://github.com/aras-p/sog-format
533+
"""
534+
import io
535+
import json
536+
import math
537+
import zipfile
538+
539+
from PIL import Image
540+
541+
xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
542+
scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
543+
quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
544+
colors_linear = gaussians.colors.flatten(0, 1).cpu().numpy()
545+
opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()
546+
547+
num_gaussians = len(xyz)
548+
549+
# Compute image dimensions (roughly square)
550+
img_width = int(math.ceil(math.sqrt(num_gaussians)))
551+
img_height = int(math.ceil(num_gaussians / img_width))
552+
total_pixels = img_width * img_height
553+
554+
# Pad arrays to fill image
555+
def pad_array(arr: np.ndarray, total: int) -> np.ndarray:
556+
if len(arr) < total:
557+
pad_shape = (total - len(arr),) + arr.shape[1:]
558+
return np.concatenate([arr, np.zeros(pad_shape, dtype=arr.dtype)])
559+
return arr
560+
561+
xyz = pad_array(xyz, total_pixels)
562+
scales = pad_array(scales, total_pixels)
563+
quats = pad_array(quats, total_pixels)
564+
colors_linear = pad_array(colors_linear, total_pixels)
565+
opacities = pad_array(opacities, total_pixels)
566+
567+
# Normalize quaternions
568+
quats = quats / (np.linalg.norm(quats, axis=1, keepdims=True) + 1e-8)
569+
570+
# === 1. Encode positions (16-bit per axis with symmetric log transform) ===
571+
def symlog(x: np.ndarray) -> np.ndarray:
572+
return np.sign(x) * np.log1p(np.abs(x))
573+
574+
xyz_log = symlog(xyz)
575+
mins = xyz_log.min(axis=0)
576+
maxs = xyz_log.max(axis=0)
577+
578+
# Avoid division by zero
579+
ranges = maxs - mins
580+
ranges = np.where(ranges < 1e-8, 1.0, ranges)
581+
582+
# Quantize to 16-bit
583+
xyz_norm = (xyz_log - mins) / ranges
584+
xyz_q16 = (xyz_norm * 65535).clip(0, 65535).astype(np.uint16)
585+
586+
means_l = (xyz_q16 & 0xFF).astype(np.uint8)
587+
means_u = (xyz_q16 >> 8).astype(np.uint8)
588+
589+
# === 2. Encode quaternions (smallest-three, 26-bit) ===
590+
def encode_quaternion(q: np.ndarray) -> np.ndarray:
591+
"""Encode quaternion using smallest-three method."""
592+
# Find largest component
593+
abs_q = np.abs(q)
594+
mode = np.argmax(abs_q, axis=1)
595+
596+
# Ensure the largest component is positive
597+
signs = np.sign(q[np.arange(len(q)), mode])
598+
q = q * signs[:, None]
599+
600+
# Extract the three smallest components
601+
result = np.zeros((len(q), 4), dtype=np.uint8)
602+
sqrt2_inv = 1.0 / math.sqrt(2)
603+
604+
for i in range(len(q)):
605+
m = mode[i]
606+
# Get indices of the three kept components
607+
kept = [j for j in range(4) if j != m]
608+
vals = q[i, kept]
609+
# Quantize from [-sqrt2/2, sqrt2/2] to [0, 255]
610+
encoded = ((vals * sqrt2_inv + 0.5) * 255).clip(0, 255).astype(np.uint8)
611+
result[i, :3] = encoded
612+
result[i, 3] = 252 + m # Mode in alpha channel
613+
614+
return result
615+
616+
quats_encoded = encode_quaternion(quats)
617+
618+
# === 3. Build scale codebook (256 entries) ===
619+
# SOG stores scales in LOG space - the renderer does exp(codebook[idx])
620+
scales_log = np.log(np.maximum(scales, 1e-10))
621+
scales_log_flat = scales_log.flatten()
622+
623+
# Use percentiles for codebook (in log space)
624+
percentiles = np.linspace(0, 100, 256)
625+
scale_codebook = np.percentile(scales_log_flat, percentiles).astype(np.float32)
626+
627+
# Quantize values to nearest codebook entry
628+
def quantize_to_codebook(values: np.ndarray, codebook: np.ndarray) -> np.ndarray:
629+
indices = np.searchsorted(codebook, values)
630+
indices = np.clip(indices, 0, len(codebook) - 1)
631+
# Check if previous index is closer
632+
prev_indices = np.clip(indices - 1, 0, len(codebook) - 1)
633+
dist_curr = np.abs(values - codebook[indices])
634+
dist_prev = np.abs(values - codebook[prev_indices])
635+
use_prev = (dist_prev < dist_curr) & (indices > 0)
636+
indices = np.where(use_prev, prev_indices, indices)
637+
return indices.astype(np.uint8)
638+
639+
scales_q = np.stack(
640+
[
641+
quantize_to_codebook(scales_log[:, 0], scale_codebook),
642+
quantize_to_codebook(scales_log[:, 1], scale_codebook),
643+
quantize_to_codebook(scales_log[:, 2], scale_codebook),
644+
],
645+
axis=1,
646+
)
647+
648+
# === 4. Build SH0 codebook and encode colors ===
649+
SH_C0 = 0.28209479177387814
650+
sh0_coeffs = (colors_linear - 0.5) / SH_C0
651+
sh0_flat = sh0_coeffs.flatten()
652+
653+
sh0_percentiles = np.linspace(0, 100, 256)
654+
sh0_codebook = np.percentile(sh0_flat, sh0_percentiles).astype(np.float32)
655+
656+
sh0_r = quantize_to_codebook(sh0_coeffs[:, 0], sh0_codebook)
657+
sh0_g = quantize_to_codebook(sh0_coeffs[:, 1], sh0_codebook)
658+
sh0_b = quantize_to_codebook(sh0_coeffs[:, 2], sh0_codebook)
659+
sh0_a = (opacities * 255).clip(0, 255).astype(np.uint8)
660+
661+
# === 5. Create images ===
662+
def create_image(data: np.ndarray, width: int, height: int) -> Image.Image:
663+
data = data.reshape(height, width, -1)
664+
if data.shape[2] == 3:
665+
return Image.fromarray(data, mode="RGB")
666+
elif data.shape[2] == 4:
667+
return Image.fromarray(data, mode="RGBA")
668+
else:
669+
raise ValueError(f"Unexpected channel count: {data.shape[2]}")
670+
671+
means_l_img = create_image(means_l, img_width, img_height)
672+
means_u_img = create_image(means_u, img_width, img_height)
673+
quats_img = create_image(quats_encoded, img_width, img_height)
674+
scales_img = create_image(scales_q, img_width, img_height)
675+
676+
sh0_data = np.stack([sh0_r, sh0_g, sh0_b, sh0_a], axis=1)
677+
sh0_img = create_image(sh0_data, img_width, img_height)
678+
679+
# === 6. Create meta.json ===
680+
meta = {
681+
"version": 2,
682+
"count": num_gaussians,
683+
"antialias": False,
684+
"means": {
685+
"mins": mins.tolist(),
686+
"maxs": maxs.tolist(),
687+
"files": ["means_l.webp", "means_u.webp"],
688+
},
689+
"scales": {"codebook": scale_codebook.tolist(), "files": ["scales.webp"]},
690+
"quats": {"files": ["quats.webp"]},
691+
"sh0": {"codebook": sh0_codebook.tolist(), "files": ["sh0.webp"]},
692+
}
693+
694+
# === 7. Save as ZIP archive ===
695+
path = Path(path)
696+
if path.suffix.lower() != ".sog":
697+
path = path.with_suffix(".sog")
698+
699+
with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
700+
# Save images as lossless WebP
701+
for name, img in [
702+
("means_l.webp", means_l_img),
703+
("means_u.webp", means_u_img),
704+
("quats.webp", quats_img),
705+
("scales.webp", scales_img),
706+
("sh0.webp", sh0_img),
707+
]:
708+
buf = io.BytesIO()
709+
img.save(buf, format="WEBP", lossless=True)
710+
zf.writestr(name, buf.getvalue())
711+
712+
# Save meta.json
713+
zf.writestr("meta.json", json.dumps(meta, indent=2))

0 commit comments

Comments
 (0)