11import argparse
22import importlib
33import os
4+ from enum import Enum
45
56import cv2
67import numpy as np
1011import torch
1112
1213from roboreg .io import find_files , parse_stereo_data
13- from roboreg .util import mask_distance_transform , overlay_mask
14+ from roboreg .losses import soft_dice_loss
15+ from roboreg .util import mask_distance_transform , mask_exponential_decay , overlay_mask
1416from roboreg .util .factories import create_robot_scene , create_virtual_camera
1517
1618
class REGISTRATION_MODE(Enum):
    """Registration objective selectable through the ``--mode`` CLI option.

    The member values are the literal strings accepted on the command line,
    so a parsed argument converts directly with ``REGISTRATION_MODE(args.mode)``;
    an unknown string raises ``ValueError``.
    """

    # Targets come from ``mask_distance_transform``; renders are compared
    # against them with an MSE loss.
    DISTANCE_FUNCTION = "distance-function"
    # Targets come from ``mask_exponential_decay``; renders are compared
    # against them with a soft Dice loss.
    SEGMENTATION = "segmentation"
22+
23+
1724def args_factory () -> argparse .Namespace :
1825 parser = argparse .ArgumentParser (
1926 formatter_class = argparse .ArgumentDefaultsHelpFormatter
@@ -48,6 +55,13 @@ def args_factory() -> argparse.Namespace:
4855 default = 1.0 ,
4956 help = "Gamma for the learning rate scheduler." ,
5057 )
58+ parser .add_argument (
59+ "--mode" ,
60+ type = str ,
61+ choices = [mode .value for mode in REGISTRATION_MODE ],
62+ default = REGISTRATION_MODE .DISTANCE_FUNCTION .value ,
63+ help = "Registration mode." ,
64+ )
5165 parser .add_argument (
5266 "--display-progress" ,
5367 action = "store_true" ,
@@ -162,6 +176,7 @@ def main() -> None:
162176 args = args_factory ()
163177 device = "cuda" if torch .cuda .is_available () else "cpu"
164178 os .environ ["MAX_JOBS" ] = str (args .max_jobs ) # limit number of concurrent jobs
179+ mode = REGISTRATION_MODE (args .mode )
165180
166181 # load data
167182 left_image_files = find_files (args .path , args .left_image_pattern )
@@ -184,13 +199,19 @@ def main() -> None:
184199 joint_states = torch .tensor (
185200 np .array (joint_states ), dtype = torch .float32 , device = device
186201 )
187- left_distance_maps = [mask_distance_transform (mask ) for mask in left_masks ]
188- right_distance_maps = [mask_distance_transform (mask ) for mask in right_masks ]
189- left_distance_maps = torch .tensor (
190- np .array (left_distance_maps ), dtype = torch .float32 , device = device
202+ if mode == REGISTRATION_MODE .DISTANCE_FUNCTION :
203+ left_targets = [mask_distance_transform (mask ) for mask in left_masks ]
204+ right_targets = [mask_distance_transform (mask ) for mask in right_masks ]
205+ elif mode == REGISTRATION_MODE .SEGMENTATION :
206+ left_targets = [mask_exponential_decay (mask ) for mask in left_masks ]
207+ right_targets = [mask_exponential_decay (mask ) for mask in right_masks ]
208+ else :
209+ raise ValueError ("Invalid registration mode." )
210+ left_targets = torch .tensor (
211+ np .array (left_targets ), dtype = torch .float32 , device = device
191212 ).unsqueeze (- 1 )
192- right_distance_maps = torch .tensor (
193- np .array (right_distance_maps ), dtype = torch .float32 , device = device
213+ right_targets = torch .tensor (
214+ np .array (right_targets ), dtype = torch .float32 , device = device
194215 ).unsqueeze (- 1 )
195216
196217 # instantiate:
@@ -252,9 +273,17 @@ def main() -> None:
252273 "left" : scene .observe_from ("left" ),
253274 "right" : scene .observe_from ("right" ),
254275 }
255- loss = torch .nn .functional .mse_loss (
256- left_distance_maps , renders ["left" ]
257- ) + torch .nn .functional .mse_loss (right_distance_maps , renders ["right" ])
276+ if mode == REGISTRATION_MODE .DISTANCE_FUNCTION :
277+ loss = torch .nn .functional .mse_loss (
278+ left_targets , renders ["left" ]
279+ ) + torch .nn .functional .mse_loss (right_targets , renders ["right" ])
280+ elif mode == REGISTRATION_MODE .SEGMENTATION :
281+ loss = (
282+ soft_dice_loss (left_targets , renders ["left" ]).mean ()
283+ + soft_dice_loss (right_targets , renders ["right" ]).mean ()
284+ )
285+ else :
286+ raise ValueError ("Invalid registration mode." )
258287 optimizer .zero_grad ()
259288 loss .backward ()
260289 optimizer .step ()
0 commit comments