Skip to content

Commit 78fc92b

Browse files
authored
Merge pull request #107 from lbr-stack/feature/roboreg-37/io_refactor
IO refactor 37
2 parents 773d3e1 + f9b9464 commit 78fc92b

File tree

13 files changed

+304
-296
lines changed

13 files changed

+304
-296
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ rr-cam-swarm \
183183
--xacro-path urdf/med7/med7.xacro \
184184
--root-link-name lbr_link_0 \
185185
--end-link-name lbr_link_7 \
186-
--target-reduction 0.95 \
186+
--target-reduction 0.8 \
187187
--scale 0.1 \
188188
--n-samples 1 \
189189
--camera-info-file test/assets/lbr_med7/zed2i/left_camera_info.yaml \

cli/rr_cam_swarm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def main() -> None:
249249
mask_files = np.array(mask_files)[random_indices].tolist()
250250
joint_states_files = np.array(joint_states_files)[random_indices].tolist()
251251
images, joint_states, masks = parse_mono_data(
252-
path=args.path,
253252
image_files=image_files,
254253
mask_files=mask_files,
255254
joint_states_files=joint_states_files,

cli/rr_hydra.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def main():
155155
mask_files = find_files(args.path, args.mask_pattern)
156156
depth_files = find_files(args.path, args.depth_pattern)
157157
joint_states, masks, depths = parse_hydra_data(
158-
path=args.path,
159158
joint_states_files=joint_states_files,
160159
mask_files=mask_files,
161160
depth_files=depth_files,

cli/rr_mono_dr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def main() -> None:
153153
joint_states_files = find_files(args.path, args.joint_states_pattern)
154154
mask_files = find_files(args.path, args.mask_pattern)
155155
images, joint_states, masks = parse_mono_data(
156-
path=args.path,
157156
image_files=image_files,
158157
joint_states_files=joint_states_files,
159158
mask_files=mask_files,

cli/rr_render.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22
import os
3-
import pathlib
3+
from pathlib import Path
44

55
import cv2
66
import numpy as np
@@ -142,7 +142,7 @@ def main():
142142
num_workers=args.num_workers,
143143
)
144144

145-
output_path = pathlib.Path(args.output_path)
145+
output_path = Path(args.output_path)
146146
if not output_path.exists():
147147
output_path.mkdir(parents=True)
148148

@@ -162,13 +162,12 @@ def main():
162162
images = images.numpy()
163163
renders = (renders * 255.0).squeeze(-1).cpu().numpy().astype(np.uint8)
164164
for render, image, image_file in zip(renders, images, image_files):
165-
image_stem = pathlib.Path(image_file).stem
166-
image_suffix = pathlib.Path(image_file).suffix
165+
image_file = Path(image_file)
166+
output_file = (
167+
output_path / f"overlay_render_{image_file.stem + image_file.suffix}"
168+
)
167169
cv2.imwrite(
168-
os.path.join(
169-
str(output_path.absolute()),
170-
f"overlay_render_{image_stem + image_suffix}",
171-
),
170+
output_file,
172171
overlay_mask(image, render, args.color, scale=1.0),
173172
)
174173

cli/rr_sam2.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import argparse
2-
import os
3-
import pathlib
42

53
import cv2
64
import numpy as np
@@ -48,8 +46,7 @@ def args_factory() -> argparse.Namespace:
4846

4947
def main():
5048
args = args_factory()
51-
path = pathlib.Path(args.path)
52-
image_names = find_files(path.absolute(), args.pattern)
49+
image_files = find_files(args.path, args.pattern)
5350

5451
# detect
5552
detector = OpenCVDetector(
@@ -60,23 +57,21 @@ def main():
6057
# segment
6158
segmentor = Sam2Segmentor(model_id=args.model_id, device=args.device)
6259

63-
for image_name in progress.track(image_names, description="Generating masks..."):
64-
image_stem = pathlib.Path(image_name).stem
65-
image_suffix = pathlib.Path(image_name).suffix
66-
img = cv2.imread(os.path.join(path.absolute(), image_name))
60+
for image_file in progress.track(image_files, description="Generating masks..."):
61+
img = cv2.imread(image_file)
6762
annotations = False
6863
if args.pre_annotated:
6964
try:
7065
samples, labels = detector.read(
71-
path=os.path.join(path.absolute(), f"{image_stem}_samples.csv")
66+
path=image_file.parent / f"{image_file.stem}_samples.csv"
7267
)
7368
annotations = True
7469
except FileNotFoundError:
7570
pass
7671
if not annotations:
7772
samples, labels = detector.detect(img)
7873
detector.write(
79-
path=os.path.join(path.absolute(), f"{image_stem}_samples.csv"),
74+
path=image_file.parent / f"{image_file.stem}_samples.csv",
8075
samples=samples,
8176
labels=labels,
8277
)
@@ -86,15 +81,9 @@ def main():
8681
overlay = overlay_mask(img, mask, mode="g", scale=1.0)
8782

8883
# write probability and mask
89-
probability_path = os.path.join(
90-
path.absolute(), f"probability_sam2_{image_stem + image_suffix}"
91-
)
92-
mask_path = os.path.join(
93-
path.absolute(), f"mask_sam2_{image_stem + image_suffix}"
94-
)
95-
overlay_path = os.path.join(
96-
path.absolute(), f"overlay_sam2_{image_stem + image_suffix}"
97-
)
84+
probability_path = image_file.parent / f"probability_sam2_{image_file.name}"
85+
mask_path = image_file.parent / f"mask_sam2_{image_file.name}"
86+
overlay_path = image_file.parent / f"overlay_sam2_{image_file.name}"
9887
cv2.imwrite(probability_path, (probability * 255.0).astype(np.uint8))
9988
cv2.imwrite(mask_path, mask)
10089
cv2.imwrite(overlay_path, overlay)

cli/rr_stereo_dr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def main() -> None:
186186
right_mask_files = find_files(args.path, args.right_mask_pattern)
187187
left_images, right_images, joint_states, left_masks, right_masks = (
188188
parse_stereo_data(
189-
path=args.path,
190189
left_image_files=left_image_files,
191190
right_image_files=right_image_files,
192191
joint_states_files=joint_states_files,

roboreg/io/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .datasets import *
2+
from .filesystem import *
3+
from .parsers import *

roboreg/io/datasets.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from pathlib import Path
2+
from typing import Tuple, Union
3+
4+
import cv2
5+
import numpy as np
6+
import rich
7+
from torch.utils.data import Dataset
8+
9+
from .filesystem import find_files
10+
11+
12+
class MonocularDataset(Dataset):
13+
def __init__(
14+
self,
15+
images_path: Union[Path, str],
16+
image_pattern: str,
17+
joint_states_path: Union[Path, str],
18+
joint_states_pattern: str,
19+
):
20+
self._image_files = find_files(images_path, image_pattern)
21+
self._joint_states_files = find_files(joint_states_path, joint_states_pattern)
22+
23+
rich.print("Found the following files:")
24+
rich.print(f"Images: {[f.name for f in self._image_files]}")
25+
rich.print(f"Joint states: {[f.name for f in self._joint_states_files]}")
26+
27+
if len(self._image_files) != len(self._joint_states_files):
28+
raise ValueError(
29+
f"Number of images '{len(self._image_files)}' and joint states '{len(self._joint_states_files)}' do not match."
30+
)
31+
32+
if len(self._image_files) == 0:
33+
raise ValueError("No images found.")
34+
35+
if len(self._joint_states_files) == 0:
36+
raise ValueError("No joint states found.")
37+
38+
for image_file, joint_states_file in zip(
39+
self._image_files, self._joint_states_files
40+
):
41+
image_index = image_file.stem.split("_")[-1]
42+
joint_states_index = joint_states_file.stem.split("_")[-1]
43+
if image_index != joint_states_index:
44+
raise ValueError(
45+
f"Image file index '{image_file.name}' and joint states file index '{joint_states_file.name}' do not match."
46+
)
47+
48+
def __len__(self):
49+
return len(self._image_files)
50+
51+
def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, str]:
52+
image_file = self._image_files[idx]
53+
joint_states_file = self._joint_states_files[idx]
54+
image = cv2.imread(image_file)
55+
joint_states = np.load(joint_states_file)
56+
return image, joint_states, image_file.name

roboreg/io/filesystem.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pathlib import Path
2+
from typing import List, Union
3+
4+
5+
def find_files(path: Union[Path, str], pattern: str = "image_*.png") -> List[Path]:
6+
"""Find files in a directory.
7+
8+
Args:
9+
path (Union[Path, str]): Path to the directory.
10+
pattern (str): Pattern to match. Must include '_{number}.ext' format.
11+
12+
Returns:
13+
List[Path]: Sorted file paths.
14+
"""
15+
path = Path(path)
16+
file_paths = list(path.glob(pattern))
17+
return sorted(file_paths, key=lambda x: int(x.stem.split("_")[-1]))

0 commit comments

Comments
 (0)