|
1 | 1 | import argparse |
2 | 2 | import os |
3 | | -from typing import Tuple |
4 | 3 |
|
5 | 4 | import cv2 |
6 | 5 | import numpy as np |
7 | | -import rich |
8 | | -import rich.progress |
9 | 6 | import torch |
10 | 7 |
|
11 | 8 | from roboreg import differentiable as rrd |
12 | | -from roboreg.io import find_files, parse_camera_info |
| 9 | +from roboreg.io import find_files, parse_camera_info, parse_mono_data |
13 | 10 | from roboreg.losses import soft_dice_loss |
14 | 11 | from roboreg.optim import LinearParticleSwarm, ParticleSwarmOptimizer |
15 | 12 | from roboreg.util import ( |
16 | 13 | look_at_from_angle, |
17 | | - mask_exponential_distance_transform, |
| 14 | + mask_exponential_decay, |
18 | 15 | overlay_mask, |
19 | 16 | random_fov_eye_space_coordinates, |
20 | 17 | ) |
@@ -173,58 +170,6 @@ def args_factory() -> argparse.Namespace: |
173 | 170 | return parser.parse_args() |
174 | 171 |
|
175 | 172 |
|
176 | | -def parse_data( |
177 | | - path: str, |
178 | | - image_pattern: str, |
179 | | - mask_pattern: str, |
180 | | - joint_states_pattern: str, |
181 | | - n_samples: int = 5, |
182 | | - device: str = "cuda", |
183 | | -) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]: |
184 | | - image_files = find_files(path, image_pattern) |
185 | | - mask_files = find_files(path, mask_pattern) |
186 | | - joint_states_files = find_files(path, joint_states_pattern) |
187 | | - |
188 | | - rich.print("Found the following files:") |
189 | | - rich.print(f"Images: {image_files}") |
190 | | - rich.print(f"Masks: {mask_files}") |
191 | | - rich.print(f"Joint states: {joint_states_files}") |
192 | | - |
193 | | - # randomly sample n_samples |
194 | | - if n_samples > len(image_files): |
195 | | - n_samples = len(image_files) |
196 | | - random_indices = np.random.choice(len(image_files), n_samples, replace=False) |
197 | | - image_files = np.array(image_files)[random_indices].tolist() |
198 | | - mask_files = np.array(mask_files)[random_indices].tolist() |
199 | | - joint_states_files = np.array(joint_states_files)[random_indices].tolist() |
200 | | - |
201 | | - rich.print(f"Randomly sampled the following {n_samples} files:") |
202 | | - rich.print(f"Images: {image_files}") |
203 | | - rich.print(f"Masks: {mask_files}") |
204 | | - rich.print(f"Joint states: {joint_states_files}") |
205 | | - |
206 | | - if len(mask_files) != len(joint_states_files): |
207 | | - raise ValueError("Number of masks and joint states do not match.") |
208 | | - |
209 | | - images = [ |
210 | | - cv2.imread(os.path.join(path, file), cv2.IMREAD_COLOR) for file in image_files |
211 | | - ] |
212 | | - masks = [ |
213 | | - mask_exponential_distance_transform( |
214 | | - cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE) |
215 | | - ) |
216 | | - for file in mask_files |
217 | | - ] |
218 | | - joint_states = [np.load(os.path.join(path, file)) for file in joint_states_files] |
219 | | - |
220 | | - masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device) |
221 | | - joint_states = torch.tensor( |
222 | | - np.array(joint_states), dtype=torch.float32, device=device |
223 | | - ) |
224 | | - |
225 | | - return images, joint_states, masks |
226 | | - |
227 | | - |
228 | 173 | def instantiate_particles( |
229 | 174 | n_particles: int, |
230 | 175 | height: int, |
@@ -293,15 +238,30 @@ def main() -> None: |
293 | 238 | height, width, intrinsics = parse_camera_info( |
294 | 239 | camera_info_file=args.camera_info_file |
295 | 240 | ) |
296 | | - images, joint_states, masks = parse_data( |
| 241 | + image_files = find_files(args.path, args.image_pattern) |
| 242 | + mask_files = find_files(args.path, args.mask_pattern) |
| 243 | + joint_states_files = find_files(args.path, args.joint_states_pattern) |
| 244 | + n_samples = args.n_samples |
| 245 | + if n_samples > len(image_files): # randomly sample n_samples |
| 246 | + n_samples = len(image_files) |
| 247 | + random_indices = np.random.choice(len(image_files), n_samples, replace=False) |
| 248 | + image_files = np.array(image_files)[random_indices].tolist() |
| 249 | + mask_files = np.array(mask_files)[random_indices].tolist() |
| 250 | + joint_states_files = np.array(joint_states_files)[random_indices].tolist() |
| 251 | + images, joint_states, masks = parse_mono_data( |
297 | 252 | path=args.path, |
298 | | - image_pattern=args.image_pattern, |
299 | | - mask_pattern=args.mask_pattern, |
300 | | - joint_states_pattern=args.joint_states_pattern, |
301 | | - n_samples=args.n_samples, |
302 | | - device=device, |
| 253 | + image_files=image_files, |
| 254 | + mask_files=mask_files, |
| 255 | + joint_states_files=joint_states_files, |
| 256 | + ) |
| 257 | + |
| 258 | + # pre-process data |
| 259 | + joint_states = torch.tensor( |
| 260 | + np.array(joint_states), dtype=torch.float32, device=device |
303 | 261 | ) |
304 | 262 | n_joint_states = joint_states.shape[0] |
| 263 | + masks = [mask_exponential_decay(mask) for mask in masks] |
| 264 | + masks = torch.tensor(np.array(masks), dtype=torch.float32, device=device) |
305 | 265 |
|
306 | 266 | # scale image data (memory reduction) |
307 | 267 | height = int(height * args.scale) |
|
0 commit comments