@@ -44,6 +44,8 @@
 from gsplat.distributed import cli
 from gsplat.rendering import rasterization
 from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from gsplat.optimizers import SelectiveAdam
+
 from gsplat.compression_simulation import CompressionSimulation
 from gsplat.compression_simulation.entropy_model import Entropy_factorized_optimized_refactor, Entropy_gaussian
 
@@ -196,6 +198,8 @@ class Config: |
     packed: bool = False
     # Use sparse gradients for optimization. (experimental)
     sparse_grad: bool = False
+    # Use visible adam from Taming 3DGS. (experimental)
+    visible_adam: bool = False
     # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
     antialiased: bool = False
 
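The new flag gates the SelectiveAdam path added below: when enabled, Adam updates are applied only to Gaussians that were actually visible in the current batch, following Taming 3DGS. A minimal usage sketch, assuming the `Config` dataclass above can be constructed directly with defaults (upstream gsplat normally populates it from the CLI):

```python
# Hypothetical direct construction; the field name comes from the Config above.
cfg = Config(visible_adam=True)

# Note: the selection logic added in create_splats_with_optimizers gives
# sparse_grad priority, so visible_adam is silently ignored if both are set.
assert not (cfg.sparse_grad and cfg.visible_adam)
```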
@@ -272,6 +276,7 @@ def create_splats_with_optimizers( |
     scene_scale: float = 1.0,
     sh_degree: int = 3,
     sparse_grad: bool = False,
+    visible_adam: bool = False,
     batch_size: int = 1,
     feature_dim: Optional[int] = None,
     device: str = "cuda",
@@ -328,8 +333,15 @@ def create_splats_with_optimizers( |
     # Note that this would not make the training exactly equivalent, see
     # https://arxiv.org/pdf/2402.18824v1
     BS = batch_size * world_size
+    optimizer_class = None
+    if sparse_grad:
+        optimizer_class = torch.optim.SparseAdam
+    elif visible_adam:
+        optimizer_class = SelectiveAdam
+    else:
+        optimizer_class = torch.optim.Adam
     optimizers = {
-        name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+        name: optimizer_class(
             [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
             eps=1e-15 / math.sqrt(BS),
             # TODO: check betas logic; when BS is larger than 10, betas[0] will be zero.
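For reference, the `math.sqrt(BS)` factors above implement the batch-size scaling discussed in the linked note (arXiv 2402.18824): learning rate grows and eps shrinks by the square root of the effective batch size, not linearly. A worked example, with illustrative values rather than gsplat's actual per-parameter learning-rate table:

```python
import math

batch_size, world_size = 4, 2        # assumed values for illustration
BS = batch_size * world_size         # effective batch size across ranks
base_lr = 1.6e-4                     # illustrative base lr, not gsplat's table
scaled_lr = base_lr * math.sqrt(BS)  # ~4.53e-4: sqrt scaling of the step size
scaled_eps = 1e-15 / math.sqrt(BS)   # ~3.54e-16: eps shrinks by the same factor
```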
@@ -397,6 +409,7 @@ def __init__( |
             scene_scale=self.scene_scale,
             sh_degree=cfg.sh_degree,
             sparse_grad=cfg.sparse_grad,
+            visible_adam=cfg.visible_adam,
             batch_size=cfg.batch_size,
             feature_dim=feature_dim,
             device=self.device,
@@ -926,7 +939,21 @@ def train(self):
                         is_coalesced=len(Ks) == 1,
                     )
 
+            if cfg.visible_adam:
+                # Build a per-Gaussian visibility mask for SelectiveAdam.
+                gaussian_cnt = self.splats["means"].shape[0]
+                if cfg.packed:
+                    visibility_mask = torch.zeros_like(
+                        self.splats["opacities"], dtype=torch.bool
+                    )
+                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
+                else:
+                    visibility_mask = (info["radii"] > 0).any(0)
+
             # optimize
             for optimizer in self.optimizers.values():
-                optimizer.step()
+                if cfg.visible_adam:
+                    optimizer.step(visibility_mask)
+                else:
+                    optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
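In packed mode the rasterizer returns only the visible subset, so `info["gaussian_ids"]` is scattered into a dense boolean mask; in dense mode a Gaussian counts as visible if any camera rasterized it to a positive radius. `SelectiveAdam.step(visibility_mask)` then skips both the weight update and the Adam moment update for masked-out rows. Below is a minimal pure-PyTorch sketch of that semantics, assuming per-Gaussian parameters whose leading dimension is the Gaussian count; it is an illustration of the idea from Taming 3DGS, not gsplat's fused implementation, and bias correction is omitted for brevity:

```python
import torch

class SelectiveAdamSketch(torch.optim.Adam):
    """Illustrative stand-in: Adam that freezes the weights *and* the moments
    of Gaussians whose visibility flag is False."""

    @torch.no_grad()
    def step(self, visibility: torch.Tensor):
        for group in self.param_groups:
            lr, (beta1, beta2), eps = group["lr"], group["betas"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:  # lazy state init, as in torch.optim.Adam
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                m, v = state["exp_avg"], state["exp_avg_sq"]
                # Broadcast the [N] mask over any trailing parameter dims.
                mask = visibility.view(-1, *([1] * (p.dim() - 1)))
                g = p.grad
                m.copy_(torch.where(mask, beta1 * m + (1 - beta1) * g, m))
                v.copy_(torch.where(mask, beta2 * v + (1 - beta2) * g * g, v))
                update = lr * m / (v.sqrt() + eps)  # no bias correction here
                p.sub_(torch.where(mask, update, torch.zeros_like(update)))
```

With this reading, `optimizer.step(visibility_mask)` in the hunk leaves occluded or out-of-frustum Gaussians untouched, which both saves compute and avoids decaying their moments with zero gradients.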