Skip to content

Commit 15b6a80

Browse files
authored
Merge pull request #13 from nerfstudio-project/v0.1
bump to v0.1
2 parents eb6c3e6 + 2490299 commit 15b6a80

File tree

11 files changed

+1638
-116
lines changed

11 files changed

+1638
-116
lines changed

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,19 @@ import nerfview
5454

5555

5656
def render_fn(
57-
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
57+
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
5858
) -> np.ndarray:
5959
# Parse camera state for camera-to-world matrix (c2w) and intrinsic (K) as
6060
# float64 numpy arrays.
61+
if render_tab_state.preview_render:
62+
width = render_tab_state.render_width
63+
height = render_tab_state.render_height
64+
else:
65+
width = render_tab_state.viewer_width
66+
height = render_tab_state.viewer_height
67+
6168
c2w = camera_state.c2w
62-
K = camera_state.get_K(img_wh)
69+
K = camera_state.get_K([width, height])
6370
# Do your things and get an image as a uint8 numpy array.
6471
img = your_rendering_logic(...)
6572
return img
@@ -139,7 +146,7 @@ which we include here to be self-contained.
139146
# Only need to run once the first time.
140147
bash examples/assets/download_gsplat_ckpt.sh
141148
CUDA_VISIBLE_DEVICES=0 python examples/03_gsplat_rendering.py \
142-
--ckpt results/garden/ckpts/ckpt_6999_crop.pt
149+
--ckpt examples/assets/ckpt_6999_crop.pt
143150
```
144151

145152
</details>

examples/00_dummy_rendering.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,26 @@ def main(port: int = 8080, rendering_latency: float = 0.0):
2424
"""
2525

2626
def render_fn(
27-
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
27+
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
2828
) -> UInt8[np.ndarray, "H W 3"]:
2929
# Get camera parameters.
30-
W, H = img_wh
30+
if render_tab_state.preview_render:
31+
width = render_tab_state.render_width
32+
height = render_tab_state.render_height
33+
else:
34+
width = render_tab_state.viewer_width
35+
height = render_tab_state.viewer_height
3136
c2w = camera_state.c2w
32-
K = camera_state.get_K(img_wh)
37+
K = camera_state.get_K([width, height])
3338

3439
# Render a dummy image as a function of camera direction.
3540
camera_dirs = np.einsum(
3641
"ij,hwj->hwi",
3742
np.linalg.inv(K),
3843
np.pad(
39-
np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
44+
np.stack(
45+
np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
46+
)
4047
+ 0.5,
4148
((0, 0), (0, 0), (0, 1)),
4249
constant_values=1.0,

examples/01_dummy_training.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,26 @@ def main(port: int = 8080, max_steps: int = 50, rendering_latency: float = 0.0):
2929
step: int = 0
3030

3131
def render_fn(
32-
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
32+
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
3333
) -> UInt8[np.ndarray, "H W 3"]:
3434
# Get camera parameters.
35-
W, H = img_wh
35+
if render_tab_state.preview_render:
36+
width = render_tab_state.render_width
37+
height = render_tab_state.render_height
38+
else:
39+
width = render_tab_state.viewer_width
40+
height = render_tab_state.viewer_height
3641
c2w = camera_state.c2w
37-
K = camera_state.get_K(img_wh)
42+
K = camera_state.get_K([width, height])
3843

3944
# Render a dummy image as a function of camera direction.
4045
camera_dirs = np.einsum(
4146
"ij,hwj->hwi",
4247
np.linalg.inv(K),
4348
np.pad(
44-
np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), -1)
49+
np.stack(
50+
np.meshgrid(np.arange(width), np.arange(height), indexing="xy"), -1
51+
)
4552
+ 0.5,
4653
((0, 0), (0, 0), (0, 1)),
4754
constant_values=1.0,
@@ -83,11 +90,11 @@ def training_step():
8390
# Optionally make the training utility lower such that we update the scene
8491
# more frequently in this example. You dont need to do this in your own
8592
# code.
86-
viewer._train_util_slider.value = 0.5
93+
viewer._training_tab_handles["train_util_slider"].value = 0.5
8794

8895
for step in tqdm(range(max_steps)):
8996
# Allow user to pause the training process.
90-
while viewer.state.status == "paused":
97+
while viewer.state == "paused":
9198
time.sleep(0.01)
9299
# Do the training step and compute the number of training rays per second.
93100
tic = time.time()
@@ -96,7 +103,7 @@ def training_step():
96103
num_train_steps_per_sec = 1.0 / (time.time() - tic)
97104
num_train_rays_per_sec = num_train_rays_per_step * num_train_steps_per_sec
98105
# Update the viewer state.
99-
viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
106+
viewer.render_tab_state.num_train_rays_per_sec = num_train_rays_per_sec
100107
# Update the scene.
101108
viewer.update(step, num_train_rays_per_step)
102109
viewer.complete()

examples/02_mesh_rendering.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,21 @@ def main(port: int = 8080):
172172
)
173173

174174
def render_fn(
175-
camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
175+
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
176176
) -> UInt8[np.ndarray, "H W 3"]:
177-
# nvdiffrast requires the image size to be multiples of 8.
178-
img_wh = (img_wh[0] // 8 * 8, img_wh[1] // 8 * 8)
179-
180177
# Get camera parameters.
178+
if render_tab_state.preview_render:
179+
width = render_tab_state.render_width
180+
height = render_tab_state.render_height
181+
else:
182+
width = render_tab_state.viewer_width
183+
height = render_tab_state.viewer_height
184+
185+
# nvdiffrast requires the image size to be multiples of 8.
186+
width = width // 8 * 8
187+
height = height // 8 * 8
181188
c2w = camera_state.c2w
189+
img_wh = [width, height]
182190
K = camera_state.get_K(img_wh)
183191

184192
# Compute the normal map.

examples/03_gsplat_rendering.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,17 @@
135135

136136
# register and open viewer
137137
@torch.no_grad()
138-
def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]):
139-
width, height = img_wh
138+
def viewer_render_fn(
139+
camera_state: nerfview.CameraState, render_tab_state: nerfview.RenderTabState
140+
):
141+
if render_tab_state.preview_render:
142+
width = render_tab_state.render_width
143+
height = render_tab_state.render_height
144+
else:
145+
width = render_tab_state.viewer_width
146+
height = render_tab_state.viewer_height
140147
c2w = camera_state.c2w
141-
K = camera_state.get_K(img_wh)
148+
K = camera_state.get_K([width, height])
142149
c2w = torch.from_numpy(c2w).float().to(device)
143150
K = torch.from_numpy(K).float().to(device)
144151
viewmat = c2w.inverse()

examples/04_gsplat_training.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -916,20 +916,31 @@ def render_traj(self, step: int):
916916

917917
@torch.no_grad()
918918
def _viewer_render_fn(
919-
self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
919+
self,
920+
camera_state: nerfview.CameraState,
921+
render_tab_state: nerfview.RenderTabState,
920922
):
921923
"""Callable function for the viewer."""
922-
W, H = img_wh
924+
if render_tab_state.preview_render:
925+
width, height = (
926+
render_tab_state.render_width,
927+
render_tab_state.render_height,
928+
)
929+
else:
930+
width, height = (
931+
render_tab_state.viewer_width,
932+
render_tab_state.viewer_height,
933+
)
923934
c2w = camera_state.c2w
924-
K = camera_state.get_K(img_wh)
935+
K = camera_state.get_K([width, height])
925936
c2w = torch.from_numpy(c2w).float().to(self.device)
926937
K = torch.from_numpy(K).float().to(self.device)
927938

928939
render_colors, _, _ = self.rasterize_splats(
929940
camtoworlds=c2w[None],
930941
Ks=K[None],
931-
width=W,
932-
height=H,
942+
width=width,
943+
height=height,
933944
sh_degree=self.cfg.sh_degree, # active all SH degrees
934945
radius_clip=3.0, # skip GSs that have small image radius (in pixels)
935946
backgrounds=torch.ones(1, 3, device=self.device),

nerfview/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
from .render_panel import RenderTabState
12
from .version import __version__
23
from .viewer import VIEWER_LOCK, CameraState, Viewer, with_viewer_lock

nerfview/_renderer.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Modified from nerfview/_renderer.py
3+
"""
4+
15
import dataclasses
26
import os
37
import sys
@@ -9,7 +13,7 @@
913
import viser
1014

1115
if TYPE_CHECKING:
12-
from .viewer import CameraState, Viewer
16+
from examples.viewer import CameraState, Viewer
1317

1418
RenderState = Literal["low_move", "low_static", "high"]
1519
RenderAction = Literal["rerender", "move", "static", "update"]
@@ -53,14 +57,15 @@ def __init__(
5357
self.lock = lock
5458

5559
self.running = True
56-
self.is_prepared_fn = lambda: self.viewer.state.status != "preparing"
60+
self.is_prepared_fn = lambda: self.viewer.state != "preparing"
5761

5862
self._render_event = threading.Event()
5963
self._state: RenderState = "low_static"
6064
self._task: Optional[RenderTask] = None
6165

6266
self._target_fps = 30
6367
self._may_interrupt_render = False
68+
self._old_version = False
6469

6570
self._define_transitions()
6671

@@ -84,16 +89,17 @@ def _may_interrupt_trace(self, frame, event, arg):
8489
return self._may_interrupt_trace
8590

8691
def _get_img_wh(self, aspect: float) -> Tuple[int, int]:
87-
max_img_res = self.viewer._max_img_res_slider.value
88-
if self._state == "high":
92+
# we always trade off speed for quality
93+
max_img_res = self.viewer.render_tab_state.viewer_res
94+
if self._state in ["high"]:
8995
# if True:
9096
H = max_img_res
9197
W = int(H * aspect)
9298
if W > max_img_res:
9399
W = max_img_res
94100
H = int(W / aspect)
95101
elif self._state in ["low_move", "low_static"]:
96-
num_view_rays_per_sec = self.viewer.state.num_view_rays_per_sec
102+
num_view_rays_per_sec = self.viewer.render_tab_state.num_view_rays_per_sec
97103
target_fps = self._target_fps
98104
num_viewer_rays = num_view_rays_per_sec / target_fps
99105
H = (num_viewer_rays / aspect) ** 0.5
@@ -141,13 +147,31 @@ def run(self):
141147
with self.lock, set_trace_context(self._may_interrupt_trace):
142148
tic = time.time()
143149
W, H = img_wh = self._get_img_wh(task.camera_state.aspect)
144-
rendered = self.viewer.render_fn(task.camera_state, img_wh)
150+
self.viewer.render_tab_state.viewer_width = W
151+
self.viewer.render_tab_state.viewer_height = H
152+
153+
if not self._old_version:
154+
try:
155+
rendered = self.viewer.render_fn(
156+
task.camera_state,
157+
self.viewer.render_tab_state,
158+
)
159+
except TypeError:
160+
self._old_version = True
161+
print(
162+
"[WARNING] Your API will be deprecated in the future, please update your render_fn."
163+
)
164+
rendered = self.viewer.render_fn(task.camera_state, img_wh)
165+
else:
166+
rendered = self.viewer.render_fn(task.camera_state, img_wh)
167+
168+
self.viewer._after_render()
145169
if isinstance(rendered, tuple):
146170
img, depth = rendered
147171
else:
148172
img, depth = rendered, None
149-
self.viewer.state.num_view_rays_per_sec = (W * H) / (
150-
max(time.time() - tic, 1e-6)
173+
self.viewer.render_tab_state.num_view_rays_per_sec = (W * H) / (
174+
time.time() - tic
151175
)
152176
except InterruptRenderException:
153177
continue
@@ -160,4 +184,3 @@ def run(self):
160184
jpeg_quality=70 if task.action in ["static", "update"] else 40,
161185
depth=depth,
162186
)
163-
self.client.flush()

0 commit comments

Comments
 (0)