Skip to content

Commit a6dc9f5

Browse files
committed
using mse loss and distance transform (#73)
1 parent 4ef627c commit a6dc9f5

File tree

8 files changed

+122
-147
lines changed

8 files changed

+122
-147
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ The camera swarm optimization can serve for finding an initial guess to [Monocul
157157

158158
```shell
159159
rr-cam-swarm \
160+
--collision-meshes \
160161
--n-cameras 1000 \
161162
--min-distance 0.5 \
162163
--max-distance 3.0 \

roboreg/cli/rr_cam_swarm.py

Lines changed: 23 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import argparse
22
import os
3-
from typing import Tuple
43

54
import cv2
65
import numpy as np
7-
import rich
8-
import rich.progress
96
import torch
107

118
from roboreg import differentiable as rrd
12-
from roboreg.io import find_files, parse_camera_info
9+
from roboreg.io import find_files, parse_camera_info, parse_mono_data
1310
from roboreg.losses import soft_dice_loss
1411
from roboreg.optim import LinearParticleSwarm, ParticleSwarmOptimizer
1512
from roboreg.util import (
1613
look_at_from_angle,
17-
mask_exponential_distance_transform,
14+
mask_exponential_decay,
1815
overlay_mask,
1916
random_fov_eye_space_coordinates,
2017
)
@@ -173,58 +170,6 @@ def args_factory() -> argparse.Namespace:
173170
return parser.parse_args()
174171

175172

176-
def parse_data(
177-
path: str,
178-
image_pattern: str,
179-
mask_pattern: str,
180-
joint_states_pattern: str,
181-
n_samples: int = 5,
182-
device: str = "cuda",
183-
) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
184-
image_files = find_files(path, image_pattern)
185-
mask_files = find_files(path, mask_pattern)
186-
joint_states_files = find_files(path, joint_states_pattern)
187-
188-
rich.print("Found the following files:")
189-
rich.print(f"Images: {image_files}")
190-
rich.print(f"Masks: {mask_files}")
191-
rich.print(f"Joint states: {joint_states_files}")
192-
193-
# randomly sample n_samples
194-
if n_samples > len(image_files):
195-
n_samples = len(image_files)
196-
random_indices = np.random.choice(len(image_files), n_samples, replace=False)
197-
image_files = np.array(image_files)[random_indices].tolist()
198-
mask_files = np.array(mask_files)[random_indices].tolist()
199-
joint_states_files = np.array(joint_states_files)[random_indices].tolist()
200-
201-
rich.print(f"Randomly sampled the following {n_samples} files:")
202-
rich.print(f"Images: {image_files}")
203-
rich.print(f"Masks: {mask_files}")
204-
rich.print(f"Joint states: {joint_states_files}")
205-
206-
if len(mask_files) != len(joint_states_files):
207-
raise ValueError("Number of masks and joint states do not match.")
208-
209-
images = [
210-
cv2.imread(os.path.join(path, file), cv2.IMREAD_COLOR) for file in image_files
211-
]
212-
masks = [
213-
mask_exponential_distance_transform(
214-
cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
215-
)
216-
for file in mask_files
217-
]
218-
joint_states = [np.load(os.path.join(path, file)) for file in joint_states_files]
219-
220-
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device)
221-
joint_states = torch.tensor(
222-
np.array(joint_states), dtype=torch.float32, device=device
223-
)
224-
225-
return images, joint_states, masks
226-
227-
228173
def instantiate_particles(
229174
n_particles: int,
230175
height: int,
@@ -293,15 +238,30 @@ def main() -> None:
293238
height, width, intrinsics = parse_camera_info(
294239
camera_info_file=args.camera_info_file
295240
)
296-
images, joint_states, masks = parse_data(
241+
image_files = find_files(args.path, args.image_pattern)
242+
mask_files = find_files(args.path, args.mask_pattern)
243+
joint_states_files = find_files(args.path, args.joint_states_pattern)
244+
n_samples = args.n_samples
245+
if n_samples > len(image_files): # randomly sample n_samples
246+
n_samples = len(image_files)
247+
random_indices = np.random.choice(len(image_files), n_samples, replace=False)
248+
image_files = np.array(image_files)[random_indices].tolist()
249+
mask_files = np.array(mask_files)[random_indices].tolist()
250+
joint_states_files = np.array(joint_states_files)[random_indices].tolist()
251+
images, joint_states, masks = parse_mono_data(
297252
path=args.path,
298-
image_pattern=args.image_pattern,
299-
mask_pattern=args.mask_pattern,
300-
joint_states_pattern=args.joint_states_pattern,
301-
n_samples=args.n_samples,
302-
device=device,
253+
image_files=image_files,
254+
mask_files=mask_files,
255+
joint_states_files=joint_states_files,
256+
)
257+
258+
# pre-process data
259+
joint_states = torch.tensor(
260+
np.array(joint_states), dtype=torch.float32, device=device
303261
)
304262
n_joint_states = joint_states.shape[0]
263+
masks = [mask_exponential_decay(mask) for mask in masks]
264+
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device)
305265

306266
# scale image data (memory reduction)
307267
height = int(height * args.scale)

roboreg/cli/rr_mono_dr.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
import rich.progress
1010
import torch
1111

12-
from roboreg.io import find_files, parse_mono_dr_data
13-
from roboreg.losses import soft_dice_loss
14-
from roboreg.util import mask_exponential_distance_transform, overlay_mask
12+
from roboreg.io import find_files, parse_mono_data
13+
from roboreg.util import mask_distance_transform, overlay_mask
1514
from roboreg.util.factories import create_robot_scene, create_virtual_camera
1615

1716

@@ -49,12 +48,6 @@ def args_factory() -> argparse.Namespace:
4948
default=1.0,
5049
help="Gamma for the learning rate scheduler.",
5150
)
52-
parser.add_argument(
53-
"--sigma",
54-
type=float,
55-
default=2.0,
56-
help="Sigma for the exponential distance transform on target masks.",
57-
)
5851
parser.add_argument(
5952
"--display-progress",
6053
action="store_true",
@@ -144,7 +137,7 @@ def main() -> None:
144137
image_files = find_files(args.path, args.image_pattern)
145138
joint_states_files = find_files(args.path, args.joint_states_pattern)
146139
mask_files = find_files(args.path, args.mask_pattern)
147-
images, joint_states, masks = parse_mono_dr_data(
140+
images, joint_states, masks = parse_mono_data(
148141
path=args.path,
149142
image_files=image_files,
150143
joint_states_files=joint_states_files,
@@ -155,12 +148,10 @@ def main() -> None:
155148
joint_states = torch.tensor(
156149
np.array(joint_states), dtype=torch.float32, device=device
157150
)
158-
masks = [
159-
mask_exponential_distance_transform(mask, sigma=args.sigma) for mask in masks
160-
]
161-
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device).unsqueeze(
162-
-1
163-
)
151+
distance_maps = [mask_distance_transform(mask) for mask in masks]
152+
distance_maps = torch.tensor(
153+
np.array(distance_maps), dtype=torch.float32, device=device
154+
).unsqueeze(-1)
164155

165156
# instantiate camera with default identity extrinsics because we optimize for robot pose instead
166157
camera = {
@@ -213,7 +204,7 @@ def main() -> None:
213204
renders = {
214205
"camera": scene.observe_from("camera"),
215206
}
216-
loss = soft_dice_loss(renders["camera"], masks).mean()
207+
loss = torch.nn.functional.mse_loss(distance_maps, renders["camera"])
217208
optimizer.zero_grad()
218209
loss.backward()
219210
optimizer.step()
@@ -240,15 +231,15 @@ def main() -> None:
240231
# difference left / right render / mask
241232
difference = (
242233
cv2.cvtColor(
243-
np.abs(render - masks[0].squeeze().cpu().numpy()),
234+
np.abs(render - masks[0].astype(np.float32) / 255.0),
244235
cv2.COLOR_GRAY2BGR,
245236
)
246237
* 255.0
247238
).astype(np.uint8)
248239
# overlay segmentation mask
249240
segmentation_overlay = overlay_mask(
250241
image,
251-
(masks[0].squeeze().cpu().numpy() * 255.0).astype(np.uint8),
242+
masks[0],
252243
mode="b",
253244
scale=1.0,
254245
)
@@ -277,7 +268,7 @@ def main() -> None:
277268
for i, render in enumerate(renders):
278269
render = render.squeeze().cpu().numpy()
279270
overlay = overlay_mask(images[i], (render * 255.0).astype(np.uint8), scale=1.0)
280-
difference = np.abs(render - masks[i].squeeze().cpu().numpy())
271+
difference = np.abs(render - masks[i].astype(np.float32) / 255.0)
281272

282273
cv2.imwrite(os.path.join(args.path, f"dr_overlay_{i}.png"), overlay)
283274
cv2.imwrite(

roboreg/cli/rr_stereo_dr.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
import rich.progress
1010
import torch
1111

12-
from roboreg.io import find_files, parse_stereo_dr_data
13-
from roboreg.losses import soft_dice_loss
14-
from roboreg.util import mask_exponential_distance_transform, overlay_mask
12+
from roboreg.io import find_files, parse_stereo_data
13+
from roboreg.util import mask_distance_transform, overlay_mask
1514
from roboreg.util.factories import create_robot_scene, create_virtual_camera
1615

1716

@@ -49,12 +48,6 @@ def args_factory() -> argparse.Namespace:
4948
default=1.0,
5049
help="Gamma for the learning rate scheduler.",
5150
)
52-
parser.add_argument(
53-
"--sigma",
54-
type=float,
55-
default=2.0,
56-
help="Sigma for the exponential distance transform on target masks.",
57-
)
5851
parser.add_argument(
5952
"--display-progress",
6053
action="store_true",
@@ -177,7 +170,7 @@ def main() -> None:
177170
left_mask_files = find_files(args.path, args.left_mask_pattern)
178171
right_mask_files = find_files(args.path, args.right_mask_pattern)
179172
left_images, right_images, joint_states, left_masks, right_masks = (
180-
parse_stereo_dr_data(
173+
parse_stereo_data(
181174
path=args.path,
182175
left_image_files=left_image_files,
183176
right_image_files=right_image_files,
@@ -191,19 +184,13 @@ def main() -> None:
191184
joint_states = torch.tensor(
192185
np.array(joint_states), dtype=torch.float32, device=device
193186
)
194-
left_masks = [
195-
mask_exponential_distance_transform(mask, sigma=args.sigma)
196-
for mask in left_masks
197-
]
198-
right_masks = [
199-
mask_exponential_distance_transform(mask, sigma=args.sigma)
200-
for mask in right_masks
201-
]
202-
left_masks = torch.tensor(
203-
np.array(left_masks), dtype=torch.float32, device=device
187+
left_distance_maps = [mask_distance_transform(mask) for mask in left_masks]
188+
right_distance_maps = [mask_distance_transform(mask) for mask in right_masks]
189+
left_distance_maps = torch.tensor(
190+
np.array(left_distance_maps), dtype=torch.float32, device=device
204191
).unsqueeze(-1)
205-
right_masks = torch.tensor(
206-
np.array(right_masks), dtype=torch.float32, device=device
192+
right_distance_maps = torch.tensor(
193+
np.array(right_distance_maps), dtype=torch.float32, device=device
207194
).unsqueeze(-1)
208195

209196
# instantiate:
@@ -265,10 +252,9 @@ def main() -> None:
265252
"left": scene.observe_from("left"),
266253
"right": scene.observe_from("right"),
267254
}
268-
loss = (
269-
soft_dice_loss(renders["left"], left_masks).mean()
270-
+ soft_dice_loss(renders["right"], right_masks).mean()
271-
)
255+
loss = torch.nn.functional.mse_loss(
256+
left_distance_maps, renders["left"]
257+
) + torch.nn.functional.mse_loss(right_distance_maps, renders["right"])
272258
optimizer.zero_grad()
273259
loss.backward()
274260
optimizer.step()
@@ -309,7 +295,7 @@ def main() -> None:
309295
differences.append(
310296
(
311297
cv2.cvtColor(
312-
np.abs(left_render - left_masks[0].squeeze().cpu().numpy()),
298+
np.abs(left_render - left_masks[0].astype(np.float32) / 255.0),
313299
cv2.COLOR_GRAY2BGR,
314300
)
315301
* 255.0
@@ -318,7 +304,9 @@ def main() -> None:
318304
differences.append(
319305
(
320306
cv2.cvtColor(
321-
np.abs(right_render - right_masks[0].squeeze().cpu().numpy()),
307+
np.abs(
308+
right_render - right_masks[0].astype(np.float32) / 255.0
309+
),
322310
cv2.COLOR_GRAY2BGR,
323311
)
324312
* 255.0
@@ -329,15 +317,15 @@ def main() -> None:
329317
segmentation_overlays.append(
330318
overlay_mask(
331319
left_image,
332-
(left_masks[0].squeeze().cpu().numpy() * 255.0).astype(np.uint8),
320+
left_masks[0],
333321
mode="b",
334322
scale=1.0,
335323
)
336324
)
337325
segmentation_overlays.append(
338326
overlay_mask(
339327
right_image,
340-
(right_masks[0].squeeze().cpu().numpy() * 255.0).astype(np.uint8),
328+
right_masks[0],
341329
mode="b",
342330
scale=1.0,
343331
)
@@ -378,8 +366,10 @@ def main() -> None:
378366
right_overlay = overlay_mask(
379367
right_images[i], (right_render * 255.0).astype(np.uint8), scale=1.0
380368
)
381-
left_difference = np.abs(left_render - left_masks[i].squeeze().cpu().numpy())
382-
right_difference = np.abs(right_render - right_masks[i].squeeze().cpu().numpy())
369+
left_difference = np.abs(left_render - left_masks[i].astype(np.float32) / 255.0)
370+
right_difference = np.abs(
371+
right_render - right_masks[i].astype(np.float32) / 255.0
372+
)
383373

384374
cv2.imwrite(os.path.join(args.path, f"left_dr_overlay_{i}.png"), left_overlay)
385375
cv2.imwrite(os.path.join(args.path, f"right_dr_overlay_{i}.png"), right_overlay)

roboreg/io.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,13 @@ def parse_hydra_data(
371371
return joint_states, masks, depths
372372

373373

374-
def parse_mono_dr_data(
374+
def parse_mono_data(
375375
path: str,
376376
image_files: List[str],
377377
joint_states_files: List[str],
378378
mask_files: List[str],
379379
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
380-
r"""Parse data for monocular differentiable rendering.
380+
r"""Parse monocular data.
381381
382382
Args:
383383
path (str): Path to the data.
@@ -424,7 +424,7 @@ def parse_mono_dr_data(
424424
return images, joint_states, masks
425425

426426

427-
def parse_stereo_dr_data(
427+
def parse_stereo_data(
428428
path: str,
429429
left_image_files: List[str],
430430
right_image_files: List[str],
@@ -438,7 +438,7 @@ def parse_stereo_dr_data(
438438
List[np.ndarray],
439439
List[np.ndarray],
440440
]:
441-
r"""Parse data for stereo differentiable rendering.
441+
r"""Parse stereo data.
442442
443443
Args:
444444
path (str): Path to the data.

0 commit comments

Comments
 (0)