Skip to content

Commit ee93e67

Browse files
committed
Merge remote-tracking branch 'origin/main' into main
2 parents 8549ae8 + 2df0a95 commit ee93e67

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1167
-809
lines changed

.github/workflows/publish.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,6 @@ jobs:
106106
env:
107107
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
108108
run: |
109+
BUILD_NO_CUDA=1 python -m build
109110
twine upload --username __token__ --password $PYPI_TOKEN dist/*
110-
shell: bash
111+
shell: bash

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
recursive-include gsplat/cuda/csrc *
2+
recursive-include gsplat/cuda/include *

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,12 @@ To build gsplat from source on Windows, please check [this instruction](docs/INS
4040
This repo comes with a standalone script that reproduces the official Gaussian Splatting with exactly the same performance on PSNR, SSIM, LPIPS, and converged number of Gaussians. Powered by gsplat’s efficient CUDA implementation, the training takes up to **4x less GPU memory** with up to **15% less time** to finish than the official implementation. Full report can be found [here](https://docs.gsplat.studio/main/tests/eval.html).
4141

4242
```bash
43-
pip install -r examples/requirements.txt
43+
cd examples
44+
pip install -r requirements.txt
4445
# download mipnerf_360 benchmark data
45-
python examples/datasets/download_dataset.py
46+
python datasets/download_dataset.py
4647
# run batch evaluation
47-
bash examples/benchmarks/basic.sh
48+
bash benchmarks/basic.sh
4849
```
4950

5051
## Examples

examples/datasets/colmap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def __init__(
266266
+ params[2] * theta**6
267267
+ params[3] * theta**8
268268
)
269-
mapx = fx * x1 * r + width // 2
270-
mapy = fy * y1 * r + height // 2
269+
mapx = (fx * x1 * r + width // 2).astype(np.float32)
270+
mapy = (fy * y1 * r + height // 2).astype(np.float32)
271271

272272
# Use mask to define ROI
273273
mask = np.logical_and(

examples/simple_trainer.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from gsplat.distributed import cli
4141
from gsplat.rendering import rasterization
4242
from gsplat.strategy import DefaultStrategy, MCMCStrategy
43+
from gsplat.optimizers import SelectiveAdam
4344

4445

4546
@dataclass
@@ -115,6 +116,8 @@ class Config:
115116
packed: bool = False
116117
# Use sparse gradients for optimization. (experimental)
117118
sparse_grad: bool = False
119+
# Use visible adam from Taming 3DGS. (experimental)
120+
visible_adam: bool = False
118121
# Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
119122
antialiased: bool = False
120123

@@ -191,6 +194,7 @@ def create_splats_with_optimizers(
191194
scene_scale: float = 1.0,
192195
sh_degree: int = 3,
193196
sparse_grad: bool = False,
197+
visible_adam: bool = False,
194198
batch_size: int = 1,
195199
feature_dim: Optional[int] = None,
196200
device: str = "cuda",
@@ -247,8 +251,15 @@ def create_splats_with_optimizers(
247251
# Note that this would not make the training exactly equivalent, see
248252
# https://arxiv.org/pdf/2402.18824v1
249253
BS = batch_size * world_size
254+
optimizer_class = None
255+
if sparse_grad:
256+
optimizer_class = torch.optim.SparseAdam
257+
elif visible_adam:
258+
optimizer_class = SelectiveAdam
259+
else:
260+
optimizer_class = torch.optim.Adam
250261
optimizers = {
251-
name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
262+
name: optimizer_class(
252263
[{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
253264
eps=1e-15 / math.sqrt(BS),
254265
# TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
@@ -316,6 +327,7 @@ def __init__(
316327
scene_scale=self.scene_scale,
317328
sh_degree=cfg.sh_degree,
318329
sparse_grad=cfg.sparse_grad,
330+
visible_adam=cfg.visible_adam,
319331
batch_size=cfg.batch_size,
320332
feature_dim=feature_dim,
321333
device=self.device,
@@ -739,9 +751,22 @@ def train(self):
739751
is_coalesced=len(Ks) == 1,
740752
)
741753

754+
if cfg.visible_adam:
755+
gaussian_cnt = self.splats.means.shape[0]
756+
if cfg.packed:
757+
visibility_mask = torch.zeros_like(
758+
self.splats["opacities"], dtype=bool
759+
)
760+
visibility_mask.scatter_(0, info["gaussian_ids"], 1)
761+
else:
762+
visibility_mask = (info["radii"] > 0).any(0)
763+
742764
# optimize
743765
for optimizer in self.optimizers.values():
744-
optimizer.step()
766+
if cfg.visible_adam:
767+
optimizer.step(visibility_mask)
768+
else:
769+
optimizer.step()
745770
optimizer.zero_grad(set_to_none=True)
746771
for optimizer in self.pose_optimizers:
747772
optimizer.step()

examples/simple_trainer_2dgs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def rasterize_splats(
436436
**kwargs,
437437
)
438438
elif self.model_type == "2dgs-inria":
439-
render_colors, render_alphas, info = rasterization_2dgs_inria_wrapper(
439+
renders, info = rasterization_2dgs_inria_wrapper(
440440
means=means,
441441
quats=quats,
442442
scales=scales,
@@ -577,6 +577,10 @@ def train(self):
577577
step=step,
578578
info=info,
579579
)
580+
masks = data["mask"].to(device) if "mask" in data else None
581+
if masks is not None:
582+
pixels = pixels * masks[..., None]
583+
colors = colors * masks[..., None]
580584

581585
# loss
582586
l1loss = F.l1_loss(colors, pixels)

examples/simple_viewer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def main(local_rank: int, world_rank, world_size: int, args):
6868
quats, # [N, 4]
6969
scales, # [N, 3]
7070
opacities, # [N]
71-
colors, # [N, 3]
71+
colors, # [N, S, 3]
7272
viewmats, # [C, 4, 4]
7373
Ks, # [C, 3, 3]
7474
width,
@@ -181,7 +181,7 @@ def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
181181
quats, # [N, 4]
182182
scales, # [N, 3]
183183
opacities, # [N]
184-
colors, # [N, 3]
184+
colors, # [N, S, 3]
185185
viewmat[None], # [1, 4, 4]
186186
K[None], # [1, 3, 3]
187187
width,

gsplat/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
from .compression import PngCompression
4+
from .optimizers import SelectiveAdam
45
from .cuda._torch_impl import accumulate
56
from .cuda._torch_impl_2dgs import accumulate_2dgs
67
from .cuda._wrapper import (

gsplat/cuda/_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def cuda_toolkit_version():
8989
current_dir = os.path.dirname(os.path.abspath(__file__))
9090
glm_path = os.path.join(current_dir, "csrc", "third_party", "glm")
9191

92-
extra_include_paths = [os.path.join(PATH, "csrc/"), glm_path]
92+
extra_include_paths = [os.path.join(PATH, "include/"), glm_path]
9393
extra_cflags = ["-O3"]
9494
if NO_FAST_MATH:
9595
extra_cuda_cflags = ["-O3"]

gsplat/cuda/_wrapper.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@ def call_cuda(*args, **kwargs):
1616
return call_cuda
1717

1818

19+
def selective_adam_update(
20+
param: Tensor,
21+
param_grad: Tensor,
22+
exp_avg: Tensor,
23+
exp_avg_sq: Tensor,
24+
tiles_touched: Tensor,
25+
lr: float,
26+
b1: float,
27+
b2: float,
28+
eps: float,
29+
N: int,
30+
M: int,
31+
) -> None:
32+
_make_lazy_cuda_func("selective_adam_update")(
33+
param, param_grad, exp_avg, exp_avg_sq, tiles_touched, lr, b1, b2, eps, N, M
34+
)
35+
36+
1937
def _make_lazy_cuda_obj(name: str) -> Any:
2038
# pylint: disable=import-outside-toplevel
2139
from ._backend import _C

0 commit comments

Comments
 (0)