Skip to content

Commit 8d3c54b

Browse files
committed
Merge branch 'viewer_for_dyn_scene' into main
2 parents b19780c + 25e2314 commit 8d3c54b

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

examples/simple_viewer_dyn.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
"""A simple example to render a (large-scale) Gaussian Splats
2+
3+
```bash
4+
python examples/simple_viewer.py --scene_grid 13
5+
```
6+
"""
7+
8+
import argparse
9+
import math
10+
import os
11+
import time
12+
from typing import Tuple
13+
14+
import imageio
15+
import nerfview
16+
import numpy as np
17+
import torch
18+
import torch.nn.functional as F
19+
import tqdm
20+
import viser
21+
22+
from gsplat._helper import load_test_data
23+
from gsplat.distributed import cli
24+
from gsplat.rendering import rasterization
25+
26+
def trbfunction(x):
27+
return torch.exp(-1*x.pow(2))
28+
29+
def qvec2rotmat(qvec):
30+
return np.array(
31+
[
32+
[
33+
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
34+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
35+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
36+
],
37+
[
38+
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
39+
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
40+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
41+
],
42+
[
43+
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
44+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
45+
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
46+
],
47+
]
48+
)
49+
50+
def get_c2w(camera):
51+
c2w = np.eye(4, dtype=np.float32)
52+
c2w[:3, :3] = qvec2rotmat(camera.wxyz)
53+
c2w[:3, 3] = camera.position
54+
return c2w
55+
56+
def get_w2c(camera):
57+
c2w = get_c2w(camera)
58+
w2c = np.linalg.inv(c2w)
59+
return w2c
60+
61+
62+
class DynGSRenderer:
63+
def __init__(self, args):
64+
65+
splats = torch.load(args.ckpt[0], map_location="cuda")["splats"]
66+
67+
self.means = splats["means"] # [N, 3]
68+
self.quats = splats["quats"] # [N, 4]
69+
self.scales = torch.exp(splats["scales"]) # [N, 3]
70+
self.opacities = torch.sigmoid(splats["opacities"]) # [N,]
71+
72+
self.trbfcenter = splats["trbf_center"] # [N, 1]
73+
self.trbfscale = torch.exp(splats["trbf_scale"]) # [N, 1]
74+
75+
self.motion = splats["motion"] # [N, 9]
76+
self.omega = splats["omega"] # [N, 4]
77+
self.feature_color = splats["colors"] # [N, 3]
78+
self.feature_dir = splats["features_dir"] # [N, 3]
79+
self.feature_time = splats["features_time"] # [N, 3]
80+
81+
self.device = self.means.device
82+
83+
if args.backend == "gsplat":
84+
self.rasterization_fn = rasterization
85+
86+
def slice_dyngs_to_3dgs(self, timestamp):
87+
pointtimes = torch.ones((self.means.shape[0],1), dtype=self.means.dtype, requires_grad=False, device="cuda") + 0 #
88+
timestamp = timestamp
89+
90+
trbfdistanceoffset = timestamp * pointtimes - self.trbfcenter
91+
trbfdistance = trbfdistanceoffset / (math.sqrt(2) * self.trbfscale)
92+
trbfoutput = trbfunction(trbfdistance)
93+
94+
# opacity decay
95+
opacity = self.opacities * trbfoutput.squeeze()
96+
97+
tforpoly = trbfdistanceoffset.detach()
98+
# Calculate Polynomial Motion Trajectory
99+
means_motion = self.means + self.motion[:, 0:3] * tforpoly + self.motion[:, 3:6] * tforpoly * tforpoly + self.motion[:, 6:9] * tforpoly *tforpoly * tforpoly
100+
# Calculate rotations
101+
rotations = torch.nn.functional.normalize(self.quats + tforpoly * self.omega)
102+
103+
# Calculate feature
104+
# colors_precomp = torch.cat((feature_color, feature_dir, tforpoly * feature_time), dim=1)
105+
colors_precomp = self.feature_color
106+
107+
return means_motion, rotations, self.scales, opacity, colors_precomp
108+
109+
def render(self, cameraHandle, timestamp=0, ):
110+
means_t, quats_t, scales_t, opa_t, colors_t = self.slice_dyngs_to_3dgs(timestamp)
111+
112+
c2w = get_c2w(cameraHandle)
113+
c2w = torch.from_numpy(c2w).float().to(self.device)
114+
viewmat = c2w.inverse()
115+
116+
W, H = 1920, 1080
117+
focal_length = H / 2.0 / np.tan(cameraHandle.fov / 2.0)
118+
K = np.array(
119+
[
120+
[focal_length, 0.0, W / 2.0],
121+
[0.0, focal_length, H / 2.0],
122+
[0.0, 0.0, 1.0],
123+
]
124+
)
125+
K = torch.from_numpy(K).float().to(self.device)
126+
127+
render_colors, render_alphas, meta = self.rasterization_fn(
128+
means_t, # [N, 3]
129+
quats_t, # [N, 4]
130+
scales_t, # [N, 3]
131+
opa_t, # [N]
132+
colors_t, # [N, S, 3]
133+
viewmat[None], # [1, 4, 4]
134+
K[None], # [1, 3, 3]
135+
W,
136+
H,
137+
# sh_degree=sh_degree,
138+
render_mode="RGB",
139+
# this is to speedup large-scale rendering by skipping far-away Gaussians.
140+
# radius_clip=3,
141+
)
142+
143+
render_rgbs = render_colors[0, ..., 0:3].cpu().numpy()
144+
return render_rgbs
145+
146+
class ViserViewer:
147+
def __init__(self, port):
148+
self.port = port
149+
self.server = viser.ViserServer(port=port)
150+
151+
self.need_update = False
152+
153+
with self.server.gui.add_folder("Playback"):
154+
self.gui_playing = self.server.gui.add_checkbox("Playing", True)
155+
self.timestamp = self.server.add_slider(
156+
"Timestamp", min=0, max=49, step=1, initial_value=0
157+
)
158+
self.gui_next_frame = self.server.gui.add_button("Next Frame", disabled=True)
159+
self.gui_prev_frame = self.server.gui.add_button("Prev Frame", disabled=True)
160+
161+
@self.gui_playing.on_update
162+
def _(_) -> None:
163+
self.timestamp.disabled = self.gui_playing.value
164+
self.gui_next_frame.disabled = self.gui_playing.value
165+
self.gui_prev_frame.disabled = self.gui_playing.value
166+
167+
@self.timestamp.on_update
168+
def _(_):
169+
self.need_update = True
170+
171+
# Frame step buttons.
172+
@self.gui_next_frame.on_click
173+
def _(_) -> None:
174+
self.timestamp.value = (self.timestamp.value + 1) % 50
175+
176+
@self.gui_prev_frame.on_click
177+
def _(_) -> None:
178+
self.timestamp.value = (self.timestamp.value - 1) % 50
179+
180+
@self.server.on_client_connect
181+
def _(client: viser.ClientHandle):
182+
@client.camera.on_update
183+
def _(_):
184+
self.need_update = True
185+
186+
# self.scene_rep = DynGSRenderer(args)
187+
188+
def set_scene_rep(self, scene_rep):
189+
self.scene_rep = scene_rep
190+
191+
def render(self, camera, timestamp):
192+
return self.scene_rep.render(camera, timestamp)
193+
194+
@torch.no_grad()
195+
def update(self):
196+
if self.need_update:
197+
start = time.time()
198+
for client in self.server.get_clients().values():
199+
camera = client.camera
200+
timestamp = self.timestamp.value / 50
201+
# w2c = get_w2c(camera)
202+
try:
203+
# W = 1920
204+
# H = int(W/camera.aspect)
205+
# focal_x = W/2/np.tan(camera.fov/2)
206+
# focal_y = H/2/np.tan(camera.fov/2)
207+
208+
start_cuda = torch.cuda.Event(enable_timing=True)
209+
end_cuda = torch.cuda.Event(enable_timing=True)
210+
start_cuda.record()
211+
212+
out = self.render(camera, timestamp)
213+
214+
end_cuda.record()
215+
torch.cuda.synchronize()
216+
interval = start_cuda.elapsed_time(end_cuda)/1000.
217+
218+
except RuntimeError as e:
219+
print(e)
220+
interval = 1
221+
continue
222+
223+
client.set_background_image(out, format="jpeg")
224+
225+
# self.need_update = False
226+
227+
def main(local_rank: int, world_rank, world_size: int, args):
228+
torch.manual_seed(42)
229+
device = torch.device("cuda", local_rank)
230+
231+
dyn_gs = DynGSRenderer(args)
232+
233+
gui = ViserViewer(port=8080)
234+
235+
gui.set_scene_rep(dyn_gs)
236+
237+
while(True):
238+
if gui.gui_playing.value:
239+
gui.timestamp.value = (gui.timestamp.value + 1) % 50
240+
gui.update()
241+
242+
243+
if __name__ == "__main__":
244+
"""
245+
# Use single GPU to view the scene
246+
CUDA_VISIBLE_DEVICES=0 python simple_viewer.py \
247+
--ckpt results/garden/ckpts/ckpt_3499_rank0.pt results/garden/ckpts/ckpt_3499_rank1.pt \
248+
--port 8081
249+
"""
250+
parser = argparse.ArgumentParser()
251+
parser.add_argument(
252+
"--output_dir", type=str, default="results/", help="where to dump outputs"
253+
)
254+
parser.add_argument(
255+
"--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
256+
)
257+
parser.add_argument(
258+
"--ckpt", type=str, nargs="+", default=None, help="path to the .pt file"
259+
)
260+
parser.add_argument(
261+
"--port", type=int, default=8080, help="port for the viewer server"
262+
)
263+
parser.add_argument("--backend", type=str, default="gsplat", help="gsplat, inria")
264+
args = parser.parse_args()
265+
assert args.scene_grid % 2 == 1, "scene_grid must be odd"
266+
267+
cli(main, args, verbose=True)

0 commit comments

Comments
 (0)