Skip to content

Commit 5e3341e

Browse files
authored
Merge pull request #73 from lbr-stack/dev-dr-robustness-prod
Improve DR robustness
2 parents 4bbdb51 + a6dc9f5 commit 5e3341e

File tree

13 files changed

+432
-315
lines changed

13 files changed

+432
-315
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ The camera swarm optimization can serve for finding an initial guess to [Monocul
157157

158158
```shell
159159
rr-cam-swarm \
160+
--collision-meshes \
160161
--n-cameras 1000 \
161162
--min-distance 0.5 \
162163
--max-distance 3.0 \

roboreg/cli/rr_cam_swarm.py

Lines changed: 26 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import argparse
22
import os
3-
from typing import Tuple
43

54
import cv2
65
import numpy as np
7-
import rich
8-
import rich.progress
96
import torch
107

118
from roboreg import differentiable as rrd
12-
from roboreg.io import find_files, parse_camera_info
9+
from roboreg.io import find_files, parse_camera_info, parse_mono_data
1310
from roboreg.losses import soft_dice_loss
1411
from roboreg.optim import LinearParticleSwarm, ParticleSwarmOptimizer
1512
from roboreg.util import (
1613
look_at_from_angle,
17-
mask_exponential_distance_transform,
14+
mask_exponential_decay,
1815
overlay_mask,
1916
random_fov_eye_space_coordinates,
2017
)
@@ -123,9 +120,9 @@ def args_factory() -> argparse.Namespace:
123120
help="Scale the camera resolution by this factor. Reduces memory usage.",
124121
)
125122
parser.add_argument(
126-
"--visual-meshes",
123+
"--collision-meshes",
127124
action="store_true",
128-
help="If set, visual meshes will be used instead of collision meshes.",
125+
help="If set, collision meshes will be used instead of visual meshes.",
129126
)
130127
parser.add_argument(
131128
"--camera-info-file",
@@ -173,58 +170,6 @@ def args_factory() -> argparse.Namespace:
173170
return parser.parse_args()
174171

175172

176-
def parse_data(
177-
path: str,
178-
image_pattern: str,
179-
mask_pattern: str,
180-
joint_states_pattern: str,
181-
n_samples: int = 5,
182-
device: str = "cuda",
183-
) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
184-
image_files = find_files(path, image_pattern)
185-
mask_files = find_files(path, mask_pattern)
186-
joint_states_files = find_files(path, joint_states_pattern)
187-
188-
rich.print("Found the following files:")
189-
rich.print(f"Images: {image_files}")
190-
rich.print(f"Masks: {mask_files}")
191-
rich.print(f"Joint states: {joint_states_files}")
192-
193-
# randomly sample n_samples
194-
if n_samples > len(image_files):
195-
n_samples = len(image_files)
196-
random_indices = np.random.choice(len(image_files), n_samples, replace=False)
197-
image_files = np.array(image_files)[random_indices].tolist()
198-
mask_files = np.array(mask_files)[random_indices].tolist()
199-
joint_states_files = np.array(joint_states_files)[random_indices].tolist()
200-
201-
rich.print(f"Randomly sampled the following {n_samples} files:")
202-
rich.print(f"Images: {image_files}")
203-
rich.print(f"Masks: {mask_files}")
204-
rich.print(f"Joint states: {joint_states_files}")
205-
206-
if len(mask_files) != len(joint_states_files):
207-
raise ValueError("Number of masks and joint states do not match.")
208-
209-
images = [
210-
cv2.imread(os.path.join(path, file), cv2.IMREAD_COLOR) for file in image_files
211-
]
212-
masks = [
213-
mask_exponential_distance_transform(
214-
cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
215-
)
216-
for file in mask_files
217-
]
218-
joint_states = [np.load(os.path.join(path, file)) for file in joint_states_files]
219-
220-
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device)
221-
joint_states = torch.tensor(
222-
np.array(joint_states), dtype=torch.float32, device=device
223-
)
224-
225-
return images, joint_states, masks
226-
227-
228173
def instantiate_particles(
229174
n_particles: int,
230175
height: int,
@@ -293,15 +238,30 @@ def main() -> None:
293238
height, width, intrinsics = parse_camera_info(
294239
camera_info_file=args.camera_info_file
295240
)
296-
images, joint_states, masks = parse_data(
241+
image_files = find_files(args.path, args.image_pattern)
242+
mask_files = find_files(args.path, args.mask_pattern)
243+
joint_states_files = find_files(args.path, args.joint_states_pattern)
244+
n_samples = args.n_samples
245+
if n_samples > len(image_files): # randomly sample n_samples
246+
n_samples = len(image_files)
247+
random_indices = np.random.choice(len(image_files), n_samples, replace=False)
248+
image_files = np.array(image_files)[random_indices].tolist()
249+
mask_files = np.array(mask_files)[random_indices].tolist()
250+
joint_states_files = np.array(joint_states_files)[random_indices].tolist()
251+
images, joint_states, masks = parse_mono_data(
297252
path=args.path,
298-
image_pattern=args.image_pattern,
299-
mask_pattern=args.mask_pattern,
300-
joint_states_pattern=args.joint_states_pattern,
301-
n_samples=args.n_samples,
302-
device=device,
253+
image_files=image_files,
254+
mask_files=mask_files,
255+
joint_states_files=joint_states_files,
256+
)
257+
258+
# pre-process data
259+
joint_states = torch.tensor(
260+
np.array(joint_states), dtype=torch.float32, device=device
303261
)
304262
n_joint_states = joint_states.shape[0]
263+
masks = [mask_exponential_decay(mask) for mask in masks]
264+
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device)
305265

306266
# scale image data (memory reduction)
307267
height = int(height * args.scale)
@@ -348,7 +308,7 @@ def main() -> None:
348308
urdf_parser=urdf_parser,
349309
root_link_name=args.root_link_name,
350310
end_link_name=args.end_link_name,
351-
visual=args.visual_meshes,
311+
collision=args.collision_meshes,
352312
batch_size=batch_size,
353313
device=device,
354314
target_reduction=args.target_reduction, # reduce mesh vertex count for memory reduction

roboreg/cli/rr_hydra.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from roboreg.differentiable import Robot
99
from roboreg.hydra_icp import hydra_centroid_alignment, hydra_robust_icp
10-
from roboreg.io import URDFParser, parse_camera_info, parse_hydra_data
10+
from roboreg.io import URDFParser, find_files, parse_camera_info, parse_hydra_data
1111
from roboreg.util import (
1212
RegistrationVisualizer,
1313
clean_xyz,
@@ -74,9 +74,9 @@ def args_factory() -> argparse.Namespace:
7474
help="End link name. If unspecified, the last link with mesh will be used, which may cause errors.",
7575
)
7676
parser.add_argument(
77-
"--visual-meshes",
77+
"--collision-meshes",
7878
action="store_true",
79-
help="If set, visual meshes will be used instead of collision meshes.",
79+
help="If set, collision meshes will be used instead of visual meshes.",
8080
)
8181
parser.add_argument(
8282
"--depth-conversion-factor",
@@ -151,11 +151,14 @@ def main():
151151
device = "cuda" if torch.cuda.is_available() else "cpu"
152152

153153
# load data
154+
joint_states_files = find_files(args.path, args.joint_states_pattern)
155+
mask_files = find_files(args.path, args.mask_pattern)
156+
depth_files = find_files(args.path, args.depth_pattern)
154157
joint_states, masks, depths = parse_hydra_data(
155158
path=args.path,
156-
joint_states_pattern=args.joint_states_pattern,
157-
mask_pattern=args.mask_pattern,
158-
depth_pattern=args.depth_pattern,
159+
joint_states_files=joint_states_files,
160+
mask_files=mask_files,
161+
depth_files=depth_files,
159162
)
160163
height, width, intrinsics = parse_camera_info(args.camera_info_file)
161164

@@ -165,16 +168,16 @@ def main():
165168
root_link_name = args.root_link_name
166169
end_link_name = args.end_link_name
167170
if root_link_name == "":
168-
root_link_name = urdf_parser.link_names_with_meshes(visual=args.visual_meshes)[
169-
0
170-
]
171+
root_link_name = urdf_parser.link_names_with_meshes(
172+
collision=args.collision_meshes
173+
)[0]
171174
rich.print(
172175
f"Root link name not provided. Using the first link with mesh: '{root_link_name}'."
173176
)
174177
if end_link_name == "":
175-
end_link_name = urdf_parser.link_names_with_meshes(visual=args.visual_meshes)[
176-
-1
177-
]
178+
end_link_name = urdf_parser.link_names_with_meshes(
179+
collision=args.collision_meshes
180+
)[-1]
178181
rich.print(
179182
f"End link name not provided. Using the last link with mesh: '{end_link_name}'."
180183
)
@@ -185,7 +188,7 @@ def main():
185188
urdf_parser=urdf_parser,
186189
root_link_name=root_link_name,
187190
end_link_name=end_link_name,
188-
visual=args.visual_meshes,
191+
collision=args.collision_meshes,
189192
batch_size=batch_size,
190193
)
191194

roboreg/cli/rr_mono_dr.py

Lines changed: 27 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import importlib
33
import os
4-
from typing import Tuple
54

65
import cv2
76
import numpy as np
@@ -10,9 +9,8 @@
109
import rich.progress
1110
import torch
1211

13-
from roboreg.io import find_files
14-
from roboreg.losses import soft_dice_loss
15-
from roboreg.util import mask_exponential_distance_transform, overlay_mask
12+
from roboreg.io import find_files, parse_mono_data
13+
from roboreg.util import mask_distance_transform, overlay_mask
1614
from roboreg.util.factories import create_robot_scene, create_virtual_camera
1715

1816

@@ -50,12 +48,6 @@ def args_factory() -> argparse.Namespace:
5048
default=1.0,
5149
help="Gamma for the learning rate scheduler.",
5250
)
53-
parser.add_argument(
54-
"--sigma",
55-
type=float,
56-
default=2.0,
57-
help="Sigma for the exponential distance transform on target masks.",
58-
)
5951
parser.add_argument(
6052
"--display-progress",
6153
action="store_true",
@@ -86,9 +78,9 @@ def args_factory() -> argparse.Namespace:
8678
help="End link name. If unspecified, the last link with mesh will be used, which may cause errors.",
8779
)
8880
parser.add_argument(
89-
"--visual-meshes",
81+
"--collision-meshes",
9082
action="store_true",
91-
help="If set, visual meshes will be used instead of collision meshes.",
83+
help="If set, collision meshes will be used instead of visual meshes.",
9284
)
9385
parser.add_argument(
9486
"--camera-info-file",
@@ -136,59 +128,30 @@ def args_factory() -> argparse.Namespace:
136128
return parser.parse_args()
137129

138130

139-
def parse_data(
140-
path: str,
141-
image_pattern: str,
142-
joint_states_pattern: str,
143-
mask_pattern: str,
144-
sigma: float = 2.0,
145-
device: str = "cuda",
146-
) -> Tuple[np.ndarray, torch.FloatTensor, torch.FloatTensor]:
147-
image_files = find_files(path, image_pattern)
148-
joint_states_files = find_files(path, joint_states_pattern)
149-
left_mask_files = find_files(path, mask_pattern)
150-
151-
rich.print("Found the following files:")
152-
rich.print(f"Images: {image_files}")
153-
rich.print(f"Joint states: {joint_states_files}")
154-
rich.print(f"Masks: {left_mask_files}")
155-
156-
if len(image_files) != len(joint_states_files) or len(image_files) != len(
157-
left_mask_files
158-
):
159-
raise ValueError("Number of images, joint states, masks do not match.")
160-
161-
images = [cv2.imread(os.path.join(path, file)) for file in image_files]
162-
joint_states = [np.load(os.path.join(path, file)) for file in joint_states_files]
163-
masks = [
164-
mask_exponential_distance_transform(
165-
cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE), sigma=sigma
166-
)
167-
for file in left_mask_files
168-
]
169-
170-
images = np.array(images)
171-
joint_states = torch.tensor(
172-
np.array(joint_states), dtype=torch.float32, device=device
173-
)
174-
masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device).unsqueeze(
175-
-1
176-
)
177-
return images, joint_states, masks
178-
179-
180131
def main() -> None:
181132
args = args_factory()
182133
device = "cuda" if torch.cuda.is_available() else "cpu"
183134
os.environ["MAX_JOBS"] = str(args.max_jobs) # limit number of concurrent jobs
184-
images, joint_states, masks = parse_data(
135+
136+
# load data
137+
image_files = find_files(args.path, args.image_pattern)
138+
joint_states_files = find_files(args.path, args.joint_states_pattern)
139+
mask_files = find_files(args.path, args.mask_pattern)
140+
images, joint_states, masks = parse_mono_data(
185141
path=args.path,
186-
image_pattern=args.image_pattern,
187-
joint_states_pattern=args.joint_states_pattern,
188-
mask_pattern=args.mask_pattern,
189-
sigma=args.sigma,
190-
device=device,
142+
image_files=image_files,
143+
joint_states_files=joint_states_files,
144+
mask_files=mask_files,
145+
)
146+
147+
# pre-process data
148+
joint_states = torch.tensor(
149+
np.array(joint_states), dtype=torch.float32, device=device
191150
)
151+
distance_maps = [mask_distance_transform(mask) for mask in masks]
152+
distance_maps = torch.tensor(
153+
np.array(distance_maps), dtype=torch.float32, device=device
154+
).unsqueeze(-1)
192155

193156
# instantiate camera with default identity extrinsics because we optimize for robot pose instead
194157
camera = {
@@ -205,7 +168,7 @@ def main() -> None:
205168
xacro_path=args.xacro_path,
206169
root_link_name=args.root_link_name,
207170
end_link_name=args.end_link_name,
208-
visual=args.visual_meshes,
171+
collision=args.collision_meshes,
209172
cameras=camera,
210173
device=device,
211174
)
@@ -241,7 +204,7 @@ def main() -> None:
241204
renders = {
242205
"camera": scene.observe_from("camera"),
243206
}
244-
loss = soft_dice_loss(renders["camera"], masks).mean()
207+
loss = torch.nn.functional.mse_loss(distance_maps, renders["camera"])
245208
optimizer.zero_grad()
246209
loss.backward()
247210
optimizer.step()
@@ -268,15 +231,15 @@ def main() -> None:
268231
# difference left / right render / mask
269232
difference = (
270233
cv2.cvtColor(
271-
np.abs(render - masks[0].squeeze().cpu().numpy()),
234+
np.abs(render - masks[0].astype(np.float32) / 255.0),
272235
cv2.COLOR_GRAY2BGR,
273236
)
274237
* 255.0
275238
).astype(np.uint8)
276239
# overlay segmentation mask
277240
segmentation_overlay = overlay_mask(
278241
image,
279-
(masks[0].squeeze().cpu().numpy() * 255.0).astype(np.uint8),
242+
masks[0],
280243
mode="b",
281244
scale=1.0,
282245
)
@@ -305,7 +268,7 @@ def main() -> None:
305268
for i, render in enumerate(renders):
306269
render = render.squeeze().cpu().numpy()
307270
overlay = overlay_mask(images[i], (render * 255.0).astype(np.uint8), scale=1.0)
308-
difference = np.abs(render - masks[i].squeeze().cpu().numpy())
271+
difference = np.abs(render - masks[i].astype(np.float32) / 255.0)
309272

310273
cv2.imwrite(os.path.join(args.path, f"dr_overlay_{i}.png"), overlay)
311274
cv2.imwrite(

0 commit comments

Comments (0)