Skip to content

Commit e7a0d4e

Browse files
viewer now uses dummy camera frustums, new logic to detect mocking (needs reviewing), 3 TODOs
1 parent e05fa3c commit e7a0d4e

File tree

3 files changed

+79
-37
lines changed

3 files changed

+79
-37
lines changed

nerfstudio/scripts/viewer/run_viewer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,18 @@ def _start_viewer(config: TrainerConfig, pipeline: Pipeline, step: int):
8787
pipeline: Pipeline instance of which to load weights
8888
step: Step at which the pipeline was saved
8989
"""
90-
base_dir = config.get_base_dir()
91-
viewer_log_path = base_dir / config.viewer.relative_log_filename
90+
# Check if we're using a shared checkpoint (load_dir is outside the standard experiment structure)
91+
try:
92+
base_dir = config.get_base_dir()
93+
# If get_base_dir() would create a path that doesn't exist, we're likely using a shared checkpoint
94+
if not base_dir.parent.exists():
95+
# Use the checkpoint directory as the base for log files
96+
viewer_log_path = config.load_dir / config.viewer.relative_log_filename
97+
else:
98+
viewer_log_path = base_dir / config.viewer.relative_log_filename
99+
except (FileNotFoundError, OSError):
100+
# Fallback to using the checkpoint directory for shared checkpoints
101+
viewer_log_path = config.load_dir / config.viewer.relative_log_filename
92102
banner_messages = None
93103
viewer_state = None
94104
viewer_callback_lock = Lock()

nerfstudio/utils/eval_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def patch_config_for_mock_data(config: TrainerConfig) -> TrainerConfig:
4949
datamanager_data = getattr(config.pipeline.datamanager, 'data', None)
5050

5151
# The dataparser will use its own data field if it's meaningful, otherwise it inherits from datamanager
52+
# TODO: this is a hack, but I really need to change this
5253
if dataparser_data and str(dataparser_data) != "." and dataparser_data.name != "":
5354
actual_data_path = dataparser_data
5455
else:

nerfstudio/viewer/viewer.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -444,45 +444,76 @@ def init_scene(
444444
# draw the training cameras and images
445445
self.camera_handles: Dict[int, viser.CameraFrustumHandle] = {}
446446
self.original_c2w: Dict[int, np.ndarray] = {}
447-
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
448-
for idx in image_indices:
449-
image = train_dataset[idx]["image"]
450-
camera = train_dataset.cameras[idx]
451-
image_uint8 = (image * 255).detach().type(torch.uint8)
452-
image_uint8 = image_uint8.permute(2, 0, 1)
453-
454-
# torchvision can be slow to import, so we do it lazily.
455-
import torchvision
456-
457-
image_uint8 = torchvision.transforms.functional.resize(image_uint8, 100, antialias=None) # type: ignore
458-
image_uint8 = image_uint8.permute(1, 2, 0)
459-
image_uint8 = image_uint8.cpu().numpy()
460-
c2w = camera.camera_to_worlds.cpu().numpy()
461-
R = vtf.SO3.from_matrix(c2w[:3, :3])
462-
R = R @ vtf.SO3.from_x_radians(np.pi)
463-
camera_handle = self.viser_server.scene.add_camera_frustum(
464-
name=f"/cameras/camera_{idx:05d}",
465-
fov=float(2 * np.arctan((camera.cx / camera.fx[0]).cpu())),
466-
scale=self.config.camera_frustum_scale,
467-
aspect=float((camera.cx[0] / camera.cy[0]).cpu()),
468-
image=image_uint8,
469-
wxyz=R.wxyz,
470-
position=c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO,
471-
)
447+
448+
# Check if we're using mock data - if so, skip image loading to avoid file errors
449+
is_mock_data = (
450+
len(train_dataset) > 0 and
451+
hasattr(train_dataset, '_dataparser_outputs') and
452+
len(train_dataset._dataparser_outputs.image_filenames) > 0 and
453+
str(train_dataset._dataparser_outputs.image_filenames[0]).startswith("mock_image")
454+
)
455+
456+
if is_mock_data:
457+
# For mock data, just draw camera frustums without images
458+
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
459+
for idx in image_indices:
460+
camera = train_dataset.cameras[idx]
461+
c2w = camera.camera_to_worlds.cpu().numpy()
462+
R = vtf.SO3.from_matrix(c2w[:3, :3])
463+
R = R @ vtf.SO3.from_x_radians(np.pi)
464+
camera_handle = self.viser_server.scene.add_camera_frustum(
465+
name=f"/cameras/camera_{idx:05d}",
466+
fov=2 * np.arctan(camera.height / (2 * camera.fy)).item(),
467+
aspect=camera.width / camera.height,
468+
scale=0.1,
469+
color=(255, 255, 255),
470+
wxyz=R.wxyz,
471+
position=c2w[:3, 3],
472+
visible=False,
473+
)
474+
self.camera_handles[idx] = camera_handle
475+
self.original_c2w[idx] = c2w
476+
else:
477+
# Normal image loading for real data
478+
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
479+
for idx in image_indices:
480+
image = train_dataset[idx]["image"]
481+
camera = train_dataset.cameras[idx]
482+
image_uint8 = (image * 255).detach().type(torch.uint8)
483+
image_uint8 = image_uint8.permute(2, 0, 1)
484+
485+
# torchvision can be slow to import, so we do it lazily.
486+
import torchvision
487+
488+
image_uint8 = torchvision.transforms.functional.resize(image_uint8, 100, antialias=None) # type: ignore
489+
image_uint8 = image_uint8.permute(1, 2, 0)
490+
image_uint8 = image_uint8.cpu().numpy()
491+
c2w = camera.camera_to_worlds.cpu().numpy()
492+
R = vtf.SO3.from_matrix(c2w[:3, :3])
493+
R = R @ vtf.SO3.from_x_radians(np.pi)
494+
camera_handle = self.viser_server.scene.add_camera_frustum(
495+
name=f"/cameras/camera_{idx:05d}",
496+
fov=float(2 * np.arctan((camera.cx / camera.fx[0]).cpu())),
497+
scale=self.config.camera_frustum_scale,
498+
aspect=float((camera.cx[0] / camera.cy[0]).cpu()),
499+
image=image_uint8,
500+
wxyz=R.wxyz,
501+
position=c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO,
502+
)
472503

473-
def create_on_click_callback(capture_idx):
474-
def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None:
475-
with event.client.atomic():
476-
event.client.camera.position = event.target.position
477-
event.client.camera.wxyz = event.target.wxyz
478-
self.current_camera_idx = capture_idx
504+
def create_on_click_callback(capture_idx):
505+
def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None:
506+
with event.client.atomic():
507+
event.client.camera.position = event.target.position
508+
event.client.camera.wxyz = event.target.wxyz
509+
self.current_camera_idx = capture_idx
479510

480-
return on_click_callback
511+
return on_click_callback
481512

482-
camera_handle.on_click(create_on_click_callback(idx))
513+
camera_handle.on_click(create_on_click_callback(idx))
483514

484-
self.camera_handles[idx] = camera_handle
485-
self.original_c2w[idx] = c2w
515+
self.camera_handles[idx] = camera_handle
516+
self.original_c2w[idx] = c2w
486517

487518
self.train_state = train_state
488519
self.train_util = 0.9

0 commit comments

Comments
 (0)