@@ -40,6 +40,7 @@
 from gsplat.distributed import cli
 from gsplat.rendering import rasterization
 from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from gsplat.optimizers import SelectiveAdam
 
 
 @dataclass
@@ -115,6 +116,8 @@ class Config:
     packed: bool = False
     # Use sparse gradients for optimization. (experimental)
     sparse_grad: bool = False
+    # Use visible Adam (SelectiveAdam) from Taming 3DGS. (experimental)
+    visible_adam: bool = False
     # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
     antialiased: bool = False
 
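The new flag defaults to off. Note the `if/elif` ordering added to `create_splats_with_optimizers` below: if both flags are set, `sparse_grad` silently wins. A tiny self-contained sketch of that interaction (`ConfigSketch` is a hypothetical stand-in for the `Config` dataclass above, reduced to the two flags):

```python
from dataclasses import dataclass

@dataclass
class ConfigSketch:
    # Hypothetical stand-in for Config above, keeping only the two optimizer flags.
    sparse_grad: bool = False
    visible_adam: bool = False

cfg = ConfigSketch(visible_adam=True)
# sparse_grad takes precedence in the if/elif below, so enable at most one.
assert not (cfg.sparse_grad and cfg.visible_adam)
```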
@@ -191,6 +194,7 @@ def create_splats_with_optimizers(
     scene_scale: float = 1.0,
     sh_degree: int = 3,
     sparse_grad: bool = False,
+    visible_adam: bool = False,
     batch_size: int = 1,
     feature_dim: Optional[int] = None,
     device: str = "cuda",
@@ -247,8 +251,15 @@ def create_splats_with_optimizers(
     # Note that this would not make the training exactly equivalent, see
     # https://arxiv.org/pdf/2402.18824v1
     BS = batch_size * world_size
+    optimizer_class = None
+    if sparse_grad:
+        optimizer_class = torch.optim.SparseAdam
+    elif visible_adam:
+        optimizer_class = SelectiveAdam
+    else:
+        optimizer_class = torch.optim.Adam
     optimizers = {
-        name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+        name: optimizer_class(
             [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
             eps=1e-15 / math.sqrt(BS),
             # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
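For readers unfamiliar with `SelectiveAdam` (from Taming 3DGS): its `step()` takes a per-Gaussian visibility mask and updates only the visible rows, for both the parameters and Adam's moment buffers. Below is a pure-PyTorch sketch of that behavior, not gsplat's fused CUDA implementation; bias correction is omitted for brevity and the class name is made up:

```python
import torch

class MaskedAdamSketch(torch.optim.Adam):
    """Illustrative stand-in for gsplat's SelectiveAdam: step(visibility)
    updates parameters and moment buffers only at rows where the
    [N]-shaped bool mask is True. A sketch, not the real kernel."""

    @torch.no_grad()
    def step(self, visibility: torch.Tensor) -> None:
        for group in self.param_groups:
            lr, (beta1, beta2), eps = group["lr"], group["betas"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:  # lazy state init, as in torch.optim.Adam
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                m, v = state["exp_avg"], state["exp_avg_sq"]
                g = p.grad[visibility]
                # Advance moments only for visible Gaussians; invisible rows
                # keep their state frozen instead of decaying on zero grads.
                m[visibility] = beta1 * m[visibility] + (1 - beta1) * g
                v[visibility] = beta2 * v[visibility] + (1 - beta2) * g * g
                p[visibility] -= lr * m[visibility] / (v[visibility].sqrt() + eps)
```

Freezing the moment buffers matters: invisible Gaussians receive exactly-zero gradients, and plain Adam would keep moving them for many steps while its running averages decay toward zero.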
@@ -316,6 +327,7 @@ def __init__(
             scene_scale=self.scene_scale,
             sh_degree=cfg.sh_degree,
             sparse_grad=cfg.sparse_grad,
+            visible_adam=cfg.visible_adam,
             batch_size=cfg.batch_size,
             feature_dim=feature_dim,
             device=self.device,
@@ -739,9 +751,22 @@ def train(self):
                     is_coalesced=len(Ks) == 1,
                 )
 
+            if cfg.visible_adam:
+                gaussian_cnt = self.splats["means"].shape[0]
+                if cfg.packed:
+                    visibility_mask = torch.zeros_like(
+                        self.splats["opacities"], dtype=bool
+                    )
+                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
+                else:
+                    visibility_mask = (info["radii"] > 0).any(0)
+
             # optimize
             for optimizer in self.optimizers.values():
-                optimizer.step()
+                if cfg.visible_adam:
+                    optimizer.step(visibility_mask)
+                else:
+                    optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
             for optimizer in self.pose_optimizers:
                 optimizer.step()
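To make the two mask constructions above concrete, a toy check with made-up shapes (2 cameras, 5 Gaussians). It assumes `info["radii"]` is `[C, N]` in non-packed mode and `info["gaussian_ids"]` lists the Gaussian id of every packed intersection (repeats allowed):

```python
import torch

N = 5  # Gaussians; two cameras in this toy example
# Non-packed: radii is [C, N]; a Gaussian is visible if any camera sees it.
radii = torch.tensor([[0, 3, 0, 1, 0],
                      [0, 0, 0, 2, 0]])
mask_dense = (radii > 0).any(0)  # [N] bool

# Packed: only intersecting (camera, Gaussian) pairs are listed.
gaussian_ids = torch.tensor([1, 3, 3])
mask_packed = torch.zeros(N, dtype=torch.bool)
mask_packed.scatter_(0, gaussian_ids, True)

assert torch.equal(mask_dense, mask_packed)  # [False, True, False, True, False]
```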