11import argparse
22import importlib
33import os
4- from typing import Tuple
54
65import cv2
76import numpy as np
109import rich .progress
1110import torch
1211
13- from roboreg .io import find_files
14- from roboreg .losses import soft_dice_loss
15- from roboreg .util import mask_exponential_distance_transform , overlay_mask
12+ from roboreg .io import find_files , parse_mono_data
13+ from roboreg .util import mask_distance_transform , overlay_mask
1614from roboreg .util .factories import create_robot_scene , create_virtual_camera
1715
1816
@@ -50,12 +48,6 @@ def args_factory() -> argparse.Namespace:
5048 default = 1.0 ,
5149 help = "Gamma for the learning rate scheduler." ,
5250 )
53- parser .add_argument (
54- "--sigma" ,
55- type = float ,
56- default = 2.0 ,
57- help = "Sigma for the exponential distance transform on target masks." ,
58- )
5951 parser .add_argument (
6052 "--display-progress" ,
6153 action = "store_true" ,
@@ -86,9 +78,9 @@ def args_factory() -> argparse.Namespace:
8678 help = "End link name. If unspecified, the last link with mesh will be used, which may cause errors." ,
8779 )
8880 parser .add_argument (
89- "--visual -meshes" ,
81+ "--collision -meshes" ,
9082 action = "store_true" ,
91- help = "If set, visual meshes will be used instead of collision meshes." ,
83+ help = "If set, collision meshes will be used instead of visual meshes." ,
9284 )
9385 parser .add_argument (
9486 "--camera-info-file" ,
@@ -136,59 +128,30 @@ def args_factory() -> argparse.Namespace:
136128 return parser .parse_args ()
137129
138130
139- def parse_data (
140- path : str ,
141- image_pattern : str ,
142- joint_states_pattern : str ,
143- mask_pattern : str ,
144- sigma : float = 2.0 ,
145- device : str = "cuda" ,
146- ) -> Tuple [np .ndarray , torch .FloatTensor , torch .FloatTensor ]:
147- image_files = find_files (path , image_pattern )
148- joint_states_files = find_files (path , joint_states_pattern )
149- left_mask_files = find_files (path , mask_pattern )
150-
151- rich .print ("Found the following files:" )
152- rich .print (f"Images: { image_files } " )
153- rich .print (f"Joint states: { joint_states_files } " )
154- rich .print (f"Masks: { left_mask_files } " )
155-
156- if len (image_files ) != len (joint_states_files ) or len (image_files ) != len (
157- left_mask_files
158- ):
159- raise ValueError ("Number of images, joint states, masks do not match." )
160-
161- images = [cv2 .imread (os .path .join (path , file )) for file in image_files ]
162- joint_states = [np .load (os .path .join (path , file )) for file in joint_states_files ]
163- masks = [
164- mask_exponential_distance_transform (
165- cv2 .imread (os .path .join (path , file ), cv2 .IMREAD_GRAYSCALE ), sigma = sigma
166- )
167- for file in left_mask_files
168- ]
169-
170- images = np .array (images )
171- joint_states = torch .tensor (
172- np .array (joint_states ), dtype = torch .float32 , device = device
173- )
174- masks = torch .tensor (np .array (masks ), dtype = torch .float32 , device = device ).unsqueeze (
175- - 1
176- )
177- return images , joint_states , masks
178-
179-
180131def main () -> None :
181132 args = args_factory ()
182133 device = "cuda" if torch .cuda .is_available () else "cpu"
183134 os .environ ["MAX_JOBS" ] = str (args .max_jobs ) # limit number of concurrent jobs
184- images , joint_states , masks = parse_data (
135+
136+ # load data
137+ image_files = find_files (args .path , args .image_pattern )
138+ joint_states_files = find_files (args .path , args .joint_states_pattern )
139+ mask_files = find_files (args .path , args .mask_pattern )
140+ images , joint_states , masks = parse_mono_data (
185141 path = args .path ,
186- image_pattern = args .image_pattern ,
187- joint_states_pattern = args .joint_states_pattern ,
188- mask_pattern = args .mask_pattern ,
189- sigma = args .sigma ,
190- device = device ,
142+ image_files = image_files ,
143+ joint_states_files = joint_states_files ,
144+ mask_files = mask_files ,
145+ )
146+
147+ # pre-process data
148+ joint_states = torch .tensor (
149+ np .array (joint_states ), dtype = torch .float32 , device = device
191150 )
151+ distance_maps = [mask_distance_transform (mask ) for mask in masks ]
152+ distance_maps = torch .tensor (
153+ np .array (distance_maps ), dtype = torch .float32 , device = device
154+ ).unsqueeze (- 1 )
192155
193156 # instantiate camera with default identity extrinsics because we optimize for robot pose instead
194157 camera = {
@@ -205,7 +168,7 @@ def main() -> None:
205168 xacro_path = args .xacro_path ,
206169 root_link_name = args .root_link_name ,
207170 end_link_name = args .end_link_name ,
208- visual = args .visual_meshes ,
171+ collision = args .collision_meshes ,
209172 cameras = camera ,
210173 device = device ,
211174 )
@@ -241,7 +204,7 @@ def main() -> None:
241204 renders = {
242205 "camera" : scene .observe_from ("camera" ),
243206 }
244- loss = soft_dice_loss ( renders ["camera" ], masks ). mean ( )
207+ loss = torch . nn . functional . mse_loss ( distance_maps , renders ["camera" ])
245208 optimizer .zero_grad ()
246209 loss .backward ()
247210 optimizer .step ()
@@ -268,15 +231,15 @@ def main() -> None:
269232 # absolute difference between render and target mask
269232 difference = (
270233 cv2 .cvtColor (
271- np .abs (render - masks [0 ].squeeze (). cpu (). numpy () ),
234+ np .abs (render - masks [0 ].astype ( np . float32 ) / 255.0 ),
272235 cv2 .COLOR_GRAY2BGR ,
273236 )
274237 * 255.0
275238 ).astype (np .uint8 )
276239 # overlay segmentation mask
277240 segmentation_overlay = overlay_mask (
278241 image ,
279- ( masks [0 ]. squeeze (). cpu (). numpy () * 255.0 ). astype ( np . uint8 ) ,
242+ masks [0 ],
280243 mode = "b" ,
281244 scale = 1.0 ,
282245 )
@@ -305,7 +268,7 @@ def main() -> None:
305268 for i , render in enumerate (renders ):
306269 render = render .squeeze ().cpu ().numpy ()
307270 overlay = overlay_mask (images [i ], (render * 255.0 ).astype (np .uint8 ), scale = 1.0 )
308- difference = np .abs (render - masks [i ].squeeze (). cpu (). numpy () )
271+ difference = np .abs (render - masks [i ].astype ( np . float32 ) / 255.0 )
309272
310273 cv2 .imwrite (os .path .join (args .path , f"dr_overlay_{ i } .png" ), overlay )
311274 cv2 .imwrite (
0 commit comments