@@ -481,3 +481,233 @@ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
481481
482482 plydata .write (path )
483483 return plydata
484+
485+
@torch.no_grad()
def save_splat(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
    """Save Gaussians to .splat format (compact binary format for web viewers).

    The .splat format is a simple binary format used by web-based 3DGS viewers.
    Each Gaussian is stored as 32 bytes:
    - 12 bytes: xyz position (3 x float32)
    - 12 bytes: scales (3 x float32)
    - 4 bytes: RGBA color (4 x uint8)
    - 4 bytes: quaternion rotation (4 x uint8, encoded as (q * 128 + 128))

    Gaussians are sorted by size * opacity (descending) for progressive rendering.

    Args:
        gaussians: Gaussians to export. The first two dimensions of every
            attribute are flattened into a single Gaussian axis.
        f_px: Not used by this writer.
        image_shape: Not used by this writer.
        path: Destination file; it is created or overwritten.
    """
    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
    colors_rgb = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

    # Sort by size * opacity (descending) for progressive rendering.
    sort_idx = np.argsort(-(scales.prod(axis=1) * opacities))

    # Normalize quaternions. A zero-length quaternion is left as zeros (it then
    # encodes as 128 per component) instead of dividing by zero and producing
    # NaNs whose uint8 cast is undefined.
    norms = np.linalg.norm(quats, axis=1, keepdims=True)
    quats = quats / np.where(norms > 0, norms, 1.0)

    rgba = np.concatenate([colors_rgb, opacities.reshape(-1, 1)], axis=1)

    # One packed 32-byte record per Gaussian. Building the whole buffer with
    # NumPy and writing it once replaces four tiny writes per Gaussian.
    # Floats are stored little-endian explicitly so the output does not depend
    # on the host byte order.
    record = np.dtype(
        [
            ("xyz", "<f4", (3,)),
            ("scales", "<f4", (3,)),
            ("rgba", "u1", (4,)),
            ("rot", "u1", (4,)),
        ]
    )
    out = np.empty(len(sort_idx), dtype=record)
    out["xyz"] = xyz[sort_idx]
    out["scales"] = scales[sort_idx]
    out["rgba"] = (rgba[sort_idx] * 255).clip(0, 255).astype(np.uint8)
    out["rot"] = (quats[sort_idx] * 128 + 128).clip(0, 255).astype(np.uint8)

    with open(path, "wb") as f:
        f.write(out.tobytes())
520+
521+
@torch.no_grad()
def save_sog(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
    """Save Gaussians to SOG format (Spatially Ordered Gaussians).

    SOG is a highly compressed format using quantization and WebP images.
    Typically 15-20x smaller than PLY. The format stores data in a ZIP archive
    containing WebP images for positions, rotations, scales, and colors.

    Reference: https://github.com/aras-p/sog-format

    Args:
        gaussians: Gaussians to export. The first two dimensions of every
            attribute are flattened into a single Gaussian axis.
        f_px: Not used by this writer.
        image_shape: Not used by this writer.
        path: Destination path. A suffix other than ``.sog`` is replaced
            with ``.sog``.

    Raises:
        ValueError: If there are no Gaussians to save.
    """
    import io
    import json
    import math
    import zipfile

    from PIL import Image

    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
    # NOTE(review): colors are stored in linear RGB here, whereas save_splat
    # converts to sRGB before export. Confirm which color space SOG viewers
    # expect; the behavior is left unchanged.
    colors_linear = gaussians.colors.flatten(0, 1).cpu().numpy()
    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

    num_gaussians = len(xyz)
    if num_gaussians == 0:
        raise ValueError("Cannot save an empty set of Gaussians to SOG.")

    # Compute image dimensions (roughly square).
    img_width = int(math.ceil(math.sqrt(num_gaussians)))
    img_height = int(math.ceil(num_gaussians / img_width))

    # All statistics below (value ranges, codebooks) are computed from the real
    # Gaussians only. Padding up to width * height pixels is applied to the
    # already-encoded bytes in create_image, so padding cannot skew them.

    # Normalize quaternions.
    quats = quats / (np.linalg.norm(quats, axis=1, keepdims=True) + 1e-8)

    # === 1. Encode positions (16-bit per axis with symmetric log transform) ===
    xyz_log = np.sign(xyz) * np.log1p(np.abs(xyz))
    mins = xyz_log.min(axis=0)
    maxs = xyz_log.max(axis=0)

    # Avoid division by zero on degenerate axes.
    ranges = maxs - mins
    ranges = np.where(ranges < 1e-8, 1.0, ranges)

    # Quantize to 16-bit, rounding to the nearest level, then split each value
    # into a low-byte and a high-byte image.
    xyz_norm = (xyz_log - mins) / ranges
    xyz_q16 = np.rint(xyz_norm * 65535).clip(0, 65535).astype(np.uint16)
    means_l = (xyz_q16 & 0xFF).astype(np.uint8)
    means_u = (xyz_q16 >> 8).astype(np.uint8)

    # === 2. Encode quaternions (smallest-three, 26-bit) ===
    def encode_quaternion(q: np.ndarray) -> np.ndarray:
        """Encode unit quaternions using the smallest-three method."""
        # Index of the largest-magnitude component; it is dropped and only its
        # index is stored.
        mode = np.argmax(np.abs(q), axis=1)

        # Flip signs so that the dropped component is non-negative.
        signs = np.sign(q[np.arange(len(q)), mode])
        q = q * signs[:, None]

        # Row m lists the three component indices kept when component m is
        # dropped, in ascending order.
        kept_lut = np.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
        kept = np.take_along_axis(q, kept_lut[mode], axis=1)

        # Quantize from [-sqrt2/2, sqrt2/2] to [0, 255].
        sqrt2_inv = 1.0 / math.sqrt(2)
        result = np.empty((len(q), 4), dtype=np.uint8)
        result[:, :3] = np.rint((kept * sqrt2_inv + 0.5) * 255).clip(0, 255).astype(np.uint8)
        result[:, 3] = (252 + mode).astype(np.uint8)  # Mode in alpha channel
        return result

    quats_encoded = encode_quaternion(quats)

    # Quantize values to the nearest entry of a sorted codebook.
    def quantize_to_codebook(values: np.ndarray, codebook: np.ndarray) -> np.ndarray:
        upper = np.clip(np.searchsorted(codebook, values), 0, len(codebook) - 1)
        lower = np.clip(upper - 1, 0, len(codebook) - 1)
        # Prefer the previous entry when it is strictly closer.
        dist_upper = np.abs(values - codebook[upper])
        dist_lower = np.abs(values - codebook[lower])
        use_lower = (dist_lower < dist_upper) & (upper > 0)
        return np.where(use_lower, lower, upper).astype(np.uint8)

    # === 3. Build scale codebook (256 entries) ===
    # SOG stores scales in LOG space - the renderer does exp(codebook[idx])
    scales_log = np.log(np.maximum(scales, 1e-10))

    # Use percentiles for codebook (in log space).
    percentiles = np.linspace(0, 100, 256)
    scale_codebook = np.percentile(scales_log.ravel(), percentiles).astype(np.float32)
    scales_q = quantize_to_codebook(scales_log, scale_codebook)

    # === 4. Build SH0 codebook and encode colors ===
    SH_C0 = 0.28209479177387814
    sh0_coeffs = (colors_linear - 0.5) / SH_C0
    sh0_codebook = np.percentile(sh0_coeffs.ravel(), percentiles).astype(np.float32)

    sh0_rgb = quantize_to_codebook(sh0_coeffs, sh0_codebook)
    sh0_a = np.rint(opacities * 255).clip(0, 255).astype(np.uint8)
    sh0_data = np.concatenate([sh0_rgb, sh0_a.reshape(-1, 1)], axis=1)

    # === 5. Create images ===
    def create_image(data: np.ndarray, width: int, height: int) -> Image.Image:
        """Lay out per-Gaussian uint8 rows as a width x height image."""
        channels = data.shape[1]
        if channels not in (3, 4):
            raise ValueError(f"Unexpected channel count: {channels}")
        # Zero-fill the pixels past the last Gaussian.
        pixels = np.zeros((height * width, channels), dtype=np.uint8)
        pixels[: len(data)] = data
        # For uint8 arrays, 3 channels are read as RGB and 4 as RGBA.
        return Image.fromarray(pixels.reshape(height, width, channels))

    means_l_img = create_image(means_l, img_width, img_height)
    means_u_img = create_image(means_u, img_width, img_height)
    quats_img = create_image(quats_encoded, img_width, img_height)
    scales_img = create_image(scales_q, img_width, img_height)
    sh0_img = create_image(sh0_data, img_width, img_height)

    # === 6. Create meta.json ===
    meta = {
        "version": 2,
        "count": num_gaussians,
        "antialias": False,
        "means": {
            "mins": mins.tolist(),
            "maxs": maxs.tolist(),
            "files": ["means_l.webp", "means_u.webp"],
        },
        "scales": {"codebook": scale_codebook.tolist(), "files": ["scales.webp"]},
        "quats": {"files": ["quats.webp"]},
        "sh0": {"codebook": sh0_codebook.tolist(), "files": ["sh0.webp"]},
    }

    # === 7. Save as ZIP archive ===
    path = Path(path)
    if path.suffix.lower() != ".sog":
        path = path.with_suffix(".sog")

    with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
        # Save images as lossless WebP. exact=True keeps the RGB bytes of
        # pixels whose alpha is 0; those bytes are data here (codebook
        # indices), not invisible color the encoder may discard.
        for name, img in [
            ("means_l.webp", means_l_img),
            ("means_u.webp", means_u_img),
            ("quats.webp", quats_img),
            ("scales.webp", scales_img),
            ("sh0.webp", sh0_img),
        ]:
            buf = io.BytesIO()
            img.save(buf, format="WEBP", lossless=True, exact=True)
            zf.writestr(name, buf.getvalue())

        # Save meta.json
        zf.writestr("meta.json", json.dumps(meta, indent=2))
0 commit comments