Skip to content

Commit cbf1a99

Browse files
committed
support multiple registration modes (#75)
1 parent f0e313e commit cbf1a99

File tree

2 files changed

+70
-15
lines changed

2 files changed

+70
-15
lines changed

roboreg/cli/rr_mono_dr.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import importlib
33
import os
4+
from enum import Enum
45

56
import cv2
67
import numpy as np
@@ -10,10 +11,16 @@
1011
import torch
1112

1213
from roboreg.io import find_files, parse_mono_data
13-
from roboreg.util import mask_distance_transform, overlay_mask
14+
from roboreg.losses import soft_dice_loss
15+
from roboreg.util import mask_distance_transform, mask_exponential_decay, overlay_mask
1416
from roboreg.util.factories import create_robot_scene, create_virtual_camera
1517

1618

19+
class REGISTRATION_MODE(Enum):
20+
DISTANCE_FUNCTION = "distance-function"
21+
SEGMENTATION = "segmentation"
22+
23+
1724
def args_factory() -> argparse.Namespace:
1825
parser = argparse.ArgumentParser(
1926
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -48,6 +55,13 @@ def args_factory() -> argparse.Namespace:
4855
default=1.0,
4956
help="Gamma for the learning rate scheduler.",
5057
)
58+
parser.add_argument(
59+
"--mode",
60+
type=str,
61+
choices=[mode.value for mode in REGISTRATION_MODE],
62+
default=REGISTRATION_MODE.DISTANCE_FUNCTION.value,
63+
help="Registration mode.",
64+
)
5165
parser.add_argument(
5266
"--display-progress",
5367
action="store_true",
@@ -132,6 +146,7 @@ def main() -> None:
132146
args = args_factory()
133147
device = "cuda" if torch.cuda.is_available() else "cpu"
134148
os.environ["MAX_JOBS"] = str(args.max_jobs) # limit number of concurrent jobs
149+
mode = REGISTRATION_MODE(args.mode)
135150

136151
# load data
137152
image_files = find_files(args.path, args.image_pattern)
@@ -148,9 +163,14 @@ def main() -> None:
148163
joint_states = torch.tensor(
149164
np.array(joint_states), dtype=torch.float32, device=device
150165
)
151-
distance_maps = [mask_distance_transform(mask) for mask in masks]
152-
distance_maps = torch.tensor(
153-
np.array(distance_maps), dtype=torch.float32, device=device
166+
if mode == REGISTRATION_MODE.DISTANCE_FUNCTION:
167+
targets = [mask_distance_transform(mask) for mask in masks]
168+
elif mode == REGISTRATION_MODE.SEGMENTATION:
169+
targets = [mask_exponential_decay(mask) for mask in masks]
170+
else:
171+
raise ValueError("Invalid registration mode.")
172+
targets = torch.tensor(
173+
np.array(targets), dtype=torch.float32, device=device
154174
).unsqueeze(-1)
155175

156176
# instantiate camera with default identity extrinsics because we optimize for robot pose instead
@@ -204,7 +224,13 @@ def main() -> None:
204224
renders = {
205225
"camera": scene.observe_from("camera"),
206226
}
207-
loss = torch.nn.functional.mse_loss(distance_maps, renders["camera"])
227+
if mode == REGISTRATION_MODE.DISTANCE_FUNCTION:
228+
loss = torch.nn.functional.mse_loss(targets, renders["camera"])
229+
elif mode == REGISTRATION_MODE.SEGMENTATION:
230+
loss = soft_dice_loss(targets, renders["camera"]).mean()
231+
else:
232+
raise ValueError("Invalid registration mode.")
233+
208234
optimizer.zero_grad()
209235
loss.backward()
210236
optimizer.step()

roboreg/cli/rr_stereo_dr.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import importlib
33
import os
4+
from enum import Enum
45

56
import cv2
67
import numpy as np
@@ -10,10 +11,16 @@
1011
import torch
1112

1213
from roboreg.io import find_files, parse_stereo_data
13-
from roboreg.util import mask_distance_transform, overlay_mask
14+
from roboreg.losses import soft_dice_loss
15+
from roboreg.util import mask_distance_transform, mask_exponential_decay, overlay_mask
1416
from roboreg.util.factories import create_robot_scene, create_virtual_camera
1517

1618

19+
class REGISTRATION_MODE(Enum):
20+
DISTANCE_FUNCTION = "distance-function"
21+
SEGMENTATION = "segmentation"
22+
23+
1724
def args_factory() -> argparse.Namespace:
1825
parser = argparse.ArgumentParser(
1926
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -48,6 +55,13 @@ def args_factory() -> argparse.Namespace:
4855
default=1.0,
4956
help="Gamma for the learning rate scheduler.",
5057
)
58+
parser.add_argument(
59+
"--mode",
60+
type=str,
61+
choices=[mode.value for mode in REGISTRATION_MODE],
62+
default=REGISTRATION_MODE.DISTANCE_FUNCTION.value,
63+
help="Registration mode.",
64+
)
5165
parser.add_argument(
5266
"--display-progress",
5367
action="store_true",
@@ -162,6 +176,7 @@ def main() -> None:
162176
args = args_factory()
163177
device = "cuda" if torch.cuda.is_available() else "cpu"
164178
os.environ["MAX_JOBS"] = str(args.max_jobs) # limit number of concurrent jobs
179+
mode = REGISTRATION_MODE(args.mode)
165180

166181
# load data
167182
left_image_files = find_files(args.path, args.left_image_pattern)
@@ -184,13 +199,19 @@ def main() -> None:
184199
joint_states = torch.tensor(
185200
np.array(joint_states), dtype=torch.float32, device=device
186201
)
187-
left_distance_maps = [mask_distance_transform(mask) for mask in left_masks]
188-
right_distance_maps = [mask_distance_transform(mask) for mask in right_masks]
189-
left_distance_maps = torch.tensor(
190-
np.array(left_distance_maps), dtype=torch.float32, device=device
202+
if mode == REGISTRATION_MODE.DISTANCE_FUNCTION:
203+
left_targets = [mask_distance_transform(mask) for mask in left_masks]
204+
right_targets = [mask_distance_transform(mask) for mask in right_masks]
205+
elif mode == REGISTRATION_MODE.SEGMENTATION:
206+
left_targets = [mask_exponential_decay(mask) for mask in left_masks]
207+
right_targets = [mask_exponential_decay(mask) for mask in right_masks]
208+
else:
209+
raise ValueError("Invalid registration mode.")
210+
left_targets = torch.tensor(
211+
np.array(left_targets), dtype=torch.float32, device=device
191212
).unsqueeze(-1)
192-
right_distance_maps = torch.tensor(
193-
np.array(right_distance_maps), dtype=torch.float32, device=device
213+
right_targets = torch.tensor(
214+
np.array(right_targets), dtype=torch.float32, device=device
194215
).unsqueeze(-1)
195216

196217
# instantiate:
@@ -252,9 +273,17 @@ def main() -> None:
252273
"left": scene.observe_from("left"),
253274
"right": scene.observe_from("right"),
254275
}
255-
loss = torch.nn.functional.mse_loss(
256-
left_distance_maps, renders["left"]
257-
) + torch.nn.functional.mse_loss(right_distance_maps, renders["right"])
276+
if mode == REGISTRATION_MODE.DISTANCE_FUNCTION:
277+
loss = torch.nn.functional.mse_loss(
278+
left_targets, renders["left"]
279+
) + torch.nn.functional.mse_loss(right_targets, renders["right"])
280+
elif mode == REGISTRATION_MODE.SEGMENTATION:
281+
loss = (
282+
soft_dice_loss(left_targets, renders["left"]).mean()
283+
+ soft_dice_loss(right_targets, renders["right"]).mean()
284+
)
285+
else:
286+
raise ValueError("Invalid registration mode.")
258287
optimizer.zero_grad()
259288
loss.backward()
260289
optimizer.step()

0 commit comments

Comments
 (0)