Skip to content

Commit 323d88e

Browse files
AIBluefisherChenyu
andauthored
Minor fixes to PNGCompression (#817)
* Add small value to avoid numerical issues * Using predefined sort keys - When 3DGS is extended with additional properties, e.g. appearance embeddings, the sorting can be erroneous. In this case, we sort only predefined attributes for 3DGS (means, quats, scales, opacities, and sh0 if appearance embedding is not used in the model). --------- Co-authored-by: Chenyu <[email protected]>
1 parent be97354 commit 323d88e

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

gsplat/compression/png_compression.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def _compress_kmeans(
325325
params: Tensor,
326326
n_clusters: int = 65536,
327327
quantization: int = 6,
328+
eps: float = 1e-6,
328329
verbose: bool = True,
329330
**kwargs,
330331
) -> Dict[str, Any]:
@@ -339,6 +340,7 @@ def _compress_kmeans(
339340
params (Tensor): parameters to compress
340341
n_clusters (int): number of K-means clusters
341342
quantization (int): number of bits in quantization
343+
eps (float, optional): small value to avoid numerical issues. Default to 1e-6.
342344
verbose (bool, optional): Whether to print verbose information. Default to True.
343345
344346
Returns:
@@ -364,7 +366,7 @@ def _compress_kmeans(
364366
labels = labels.detach().cpu().numpy()
365367
centroids = kmeans.centroids.permute(1, 0)
366368

367-
mins = torch.min(centroids)
369+
mins = torch.min(centroids) + eps
368370
maxs = torch.max(centroids)
369371
centroids_norm = (centroids - mins) / (maxs - mins)
370372
centroids_norm = centroids_norm.detach().cpu().numpy()

gsplat/compression/sort.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def sort_splats(splats: Dict[str, Tensor], verbose: bool = True) -> Dict[str, Te
2929
n_sidelen = int(n_gs**0.5)
3030
assert n_sidelen**2 == n_gs, "Must be a perfect square"
3131

32-
sort_keys = [k for k in splats if k != "shN"]
32+
sort_keys = ["means", "quats", "scales", "opacities"]
33+
if "sh0" in splats:
34+
sort_keys.append("sh0")
35+
3336
params_to_sort = torch.cat([splats[k].reshape(n_gs, -1) for k in sort_keys], dim=-1)
3437
shuffled_indices = torch.randperm(
3538
params_to_sort.shape[0], device=params_to_sort.device

0 commit comments

Comments
 (0)