Skip to content

Commit 9943e10

Browse files
authored
Merge pull request #312 from JdeRobot/refactor_dataset
Refactoring dataset and model classes
2 parents 9c02164 + 70412ad commit 9943e10

30 files changed

+2074
-219
lines changed

.DS_Store

6 KB
Binary file not shown.

detectionmetrics/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from detectionmetrics.datasets.rugd import RUGDImageSegmentationDataset
1818
from detectionmetrics.datasets.wildscenes import WildscenesImageSegmentationDataset
19-
19+
from detectionmetrics.datasets.coco import CocoDataset
2020

2121
REGISTRY = {
2222
"gaia_image_segmentation": GaiaImageSegmentationDataset,
@@ -29,4 +29,5 @@
2929
"rellis3d_lidar_segmentation": Rellis3DLiDARSegmentationDataset,
3030
"rugd_image_segmentation": RUGDImageSegmentationDataset,
3131
"wildscenes_image_segmentation": WildscenesImageSegmentationDataset,
32+
"coco_image_detection": CocoDataset,
3233
}

detectionmetrics/datasets/coco.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from pycocotools.coco import COCO
2+
import os
3+
import pandas as pd
4+
from typing import Tuple, List, Optional
5+
6+
from detectionmetrics.datasets.detection import ImageDetectionDataset
7+
8+
9+
def build_coco_dataset(
10+
annotation_file: str,
11+
image_dir: str,
12+
coco_obj: Optional[COCO] = None,
13+
split: str = "train",
14+
) -> Tuple[pd.DataFrame, dict]:
15+
"""Build dataset and ontology dictionaries from COCO dataset structure
16+
17+
:param annotation_file: Path to the COCO-format JSON annotation file
18+
:type annotation_file: str
19+
:param image_dir: Path to the directory containing image files
20+
:type image_dir: str
21+
:param coco_obj: Optional pre-loaded COCO object to reuse
22+
:type coco_obj: COCO
23+
:param split: Dataset split name (e.g., "train", "val", "test")
24+
:type split: str
25+
:return: Dataset DataFrame and ontology dictionary
26+
:rtype: Tuple[pd.DataFrame, dict]
27+
"""
28+
# Check that provided paths exist
29+
assert os.path.isfile(
30+
annotation_file
31+
), f"Annotation file not found: {annotation_file}"
32+
assert os.path.isdir(image_dir), f"Image directory not found: {image_dir}"
33+
34+
# Load COCO annotations (reuse if provided)
35+
if coco_obj is None:
36+
coco = COCO(annotation_file)
37+
else:
38+
coco = coco_obj
39+
40+
# Build ontology from COCO categories
41+
ontology = {}
42+
for cat in coco.loadCats(coco.getCatIds()):
43+
ontology[cat["name"]] = {
44+
"idx": cat["id"],
45+
# "name": cat["name"],
46+
"rgb": [0, 0, 0], # Placeholder; COCO doesn't define RGB colors
47+
}
48+
49+
# Build dataset DataFrame from COCO image IDs
50+
rows = []
51+
for img_id in coco.getImgIds():
52+
img_info = coco.loadImgs(img_id)[0]
53+
rows.append(
54+
{
55+
"image": img_info["file_name"],
56+
"annotation": str(img_id),
57+
"split": split, # Use provided split parameter
58+
}
59+
)
60+
61+
dataset = pd.DataFrame(rows)
62+
dataset.attrs = {"ontology": ontology}
63+
64+
return dataset, ontology
65+
66+
67+
class CocoDataset(ImageDetectionDataset):
68+
"""
69+
Specific class for COCO-styled object detection datasets.
70+
71+
:param annotation_file: Path to the COCO-format JSON annotation file
72+
:type annotation_file: str
73+
:param image_dir: Path to the directory containing image files
74+
:type image_dir: str
75+
:param split: Dataset split name (e.g., "train", "val", "test")
76+
:type split: str
77+
"""
78+
79+
def __init__(self, annotation_file: str, image_dir: str, split: str = "train"):
80+
# Load COCO object once
81+
self.coco = COCO(annotation_file)
82+
self.image_dir = image_dir
83+
self.split = split
84+
85+
# Build dataset using the same COCO object and split
86+
dataset, ontology = build_coco_dataset(
87+
annotation_file, image_dir, self.coco, split=split
88+
)
89+
90+
super().__init__(dataset=dataset, dataset_dir=image_dir, ontology=ontology)
91+
92+
def read_annotation(
93+
self, fname: str
94+
) -> Tuple[List[List[float]], List[int], List[int]]:
95+
"""Return bounding boxes, labels, and category_ids for a given image ID.
96+
97+
:param fname: str (image_id in string form)
98+
:return: Tuple of (boxes, labels, category_ids)
99+
"""
100+
# Extract image ID (fname might be a path or ID string)
101+
try:
102+
image_id = int(
103+
os.path.basename(fname)
104+
) # handles both '123' and '/path/to/123'
105+
except ValueError:
106+
raise ValueError(f"Invalid annotation ID: {fname}")
107+
108+
ann_ids = self.coco.getAnnIds(imgIds=image_id)
109+
anns = self.coco.loadAnns(ann_ids)
110+
111+
boxes = []
112+
labels = []
113+
category_ids = []
114+
115+
for ann in anns:
116+
# Convert [x, y, width, height] to [x1, y1, x2, y2]
117+
x, y, w, h = ann["bbox"]
118+
boxes.append([x, y, x + w, y + h])
119+
labels.append(ann["category_id"])
120+
category_ids.append(ann["category_id"])
121+
122+
return boxes, labels, category_ids
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from abc import ABC, abstractmethod
2+
import os
3+
import shutil
4+
from typing import List, Optional, Tuple
5+
from typing_extensions import Self
6+
7+
import cv2
8+
import numpy as np
9+
import pandas as pd
10+
from tqdm import tqdm
11+
12+
from detectionmetrics.datasets.perception import PerceptionDataset
13+
import detectionmetrics.utils.io as uio
14+
import detectionmetrics.utils.conversion as uc
15+
16+
17+
class DetectionDataset(PerceptionDataset):
18+
"""Abstract perception detection dataset class."""
19+
20+
@abstractmethod
21+
def read_annotation(self, fname: str):
22+
"""Read detection annotation from a file.
23+
24+
:param fname: Annotation file name
25+
"""
26+
raise NotImplementedError
27+
28+
def get_label_count(self, splits: Optional[List[str]] = None):
29+
"""Count detection labels per class for given splits.
30+
31+
:param splits: List of splits to consider
32+
:return: Numpy array of label counts per class
33+
"""
34+
if splits is None:
35+
splits = ["train", "val"]
36+
37+
df = self.dataset[self.dataset["split"].isin(splits)]
38+
n_classes = max(c["idx"] for c in self.ontology.values()) + 1
39+
label_count = np.zeros(n_classes, dtype=np.uint64)
40+
41+
for annotation_file in tqdm(df["annotation"], desc="Counting labels"):
42+
annots = self.read_annotation(annotation_file)
43+
for annot in annots:
44+
class_idx = annot[
45+
"category_id"
46+
] # Should override the key category_id if needed in specific dataset class
47+
label_count[class_idx] += 1
48+
49+
return label_count
50+
51+
52+
class ImageDetectionDataset(DetectionDataset):
53+
"""Image detection dataset class."""
54+
55+
def make_fname_global(self):
56+
"""Convert relative filenames in 'image' and 'annotation' columns to global paths."""
57+
if self.dataset_dir is not None:
58+
self.dataset["image"] = self.dataset["image"].apply(
59+
lambda x: os.path.join(self.dataset_dir, x) if x is not None else None
60+
)
61+
self.dataset["annotation"] = self.dataset["annotation"].apply(
62+
lambda x: os.path.join(self.dataset_dir, x) if x is not None else None
63+
)
64+
self.dataset_dir = None
65+
66+
def read_annotation(self, fname: str):
67+
"""Read detection annotation from a file.
68+
69+
Override this based on annotation format (e.g., COCO JSON, XML, TXT).
70+
71+
:param fname: Annotation filename
72+
:return: Parsed annotations (e.g., list of dicts)
73+
"""
74+
# TODO implement COCO or VOC parsing in their classes separately.
75+
raise NotImplementedError("Implement annotation reading logic")
76+
77+
78+
class LiDARDetectionDataset(DetectionDataset):
79+
"""LiDAR detection dataset class."""
80+
81+
def __init__(
82+
self,
83+
dataset: pd.DataFrame,
84+
dataset_dir: str,
85+
ontology: dict,
86+
is_kitti_format: bool = True,
87+
):
88+
super().__init__(dataset, dataset_dir, ontology)
89+
self.is_kitti_format = is_kitti_format
90+
91+
def make_fname_global(self):
92+
if self.dataset_dir is not None:
93+
self.dataset["points"] = self.dataset["points"].apply(
94+
lambda x: os.path.join(self.dataset_dir, x) if x is not None else None
95+
)
96+
self.dataset["annotation"] = self.dataset["annotation"].apply(
97+
lambda x: os.path.join(self.dataset_dir, x) if x is not None else None
98+
)
99+
self.dataset_dir = None
100+
101+
def read_annotation(self, fname: str):
102+
"""Read LiDAR detection annotation.
103+
104+
For example, read KITTI format label files or custom format.
105+
106+
:param fname: Annotation file path
107+
:return: Parsed annotations (e.g., list of dicts)
108+
"""
109+
# TODO Implement format specific parsing
110+
raise NotImplementedError("Implement LiDAR detection annotation reading")

detectionmetrics/datasets/gaia.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas as pd
55

6-
from detectionmetrics.datasets import dataset as dm_dataset
6+
from detectionmetrics.datasets import segmentation as dm_segmentation_dataset
77
import detectionmetrics.utils.io as uio
88

99

@@ -34,7 +34,7 @@ def build_dataset(dataset_fname: str) -> Tuple[pd.DataFrame, str, dict]:
3434
return dataset, dataset_dir, ontology
3535

3636

37-
class GaiaImageSegmentationDataset(dm_dataset.ImageSegmentationDataset):
37+
class GaiaImageSegmentationDataset(dm_segmentation_dataset.ImageSegmentationDataset):
3838
"""Specific class for GAIA-styled image segmentation datasets
3939
4040
:param dataset_fname: Parquet dataset filename
@@ -46,7 +46,7 @@ def __init__(self, dataset_fname: str):
4646
super().__init__(dataset, dataset_dir, ontology)
4747

4848

49-
class GaiaLiDARSegmentationDataset(dm_dataset.LiDARSegmentationDataset):
49+
class GaiaLiDARSegmentationDataset(dm_segmentation_dataset.LiDARSegmentationDataset):
5050
"""Specific class for GAIA-styled LiDAR segmentation datasets
5151
5252
:param dataset_fname: Parquet dataset filename

detectionmetrics/datasets/generic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pandas as pd
88

9-
from detectionmetrics.datasets import dataset as dm_dataset
9+
from detectionmetrics.datasets import segmentation as dm_segmentation_dataset
1010
import detectionmetrics.utils.io as uio
1111

1212

@@ -111,7 +111,7 @@ def build_dataset(
111111
return dataset, ontology
112112

113113

114-
class GenericImageSegmentationDataset(dm_dataset.ImageSegmentationDataset):
114+
class GenericImageSegmentationDataset(dm_segmentation_dataset.ImageSegmentationDataset):
115115
"""Generic class for image segmentation datasets.
116116
117117
:param data_suffix: File suffix to be used to filter data
@@ -160,7 +160,7 @@ def __init__(
160160
super().__init__(dataset, dataset_dir, ontology)
161161

162162

163-
class GenericLiDARSegmentationDataset(dm_dataset.LiDARSegmentationDataset):
163+
class GenericLiDARSegmentationDataset(dm_segmentation_dataset.LiDARSegmentationDataset):
164164
"""Generic class for LiDAR segmentation datasets.
165165
166166
:param data_suffix: File suffix to be used to filter data

detectionmetrics/datasets/goose.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pandas as pd
77

8-
from detectionmetrics.datasets import dataset as dm_dataset
8+
from detectionmetrics.datasets import segmentation as dm_segmentation_dataset
99
import detectionmetrics.utils.conversion as uc
1010

1111

@@ -84,7 +84,7 @@ def build_dataset(
8484
return dataset, ontology
8585

8686

87-
class GOOSEImageSegmentationDataset(dm_dataset.ImageSegmentationDataset):
87+
class GOOSEImageSegmentationDataset(dm_segmentation_dataset.ImageSegmentationDataset):
8888
"""Specific class for GOOSE-styled image segmentation datasets. All data can be
8989
downloaded from the official webpage (https://goose-dataset.de):
9090
train -> https://goose-dataset.de/storage/goose_2d_train.zip
@@ -128,7 +128,7 @@ def __init__(
128128
super().__init__(dataset, dataset_dir, ontology)
129129

130130

131-
class GOOSELiDARSegmentationDataset(dm_dataset.LiDARSegmentationDataset):
131+
class GOOSELiDARSegmentationDataset(dm_segmentation_dataset.LiDARSegmentationDataset):
132132
"""Specific class for GOOSE-styled LiDAR segmentation datasets. All data can be
133133
downloaded from the official webpage (https://goose-dataset.de):
134134
train -> https://goose-dataset.de/storage/goose_3d_train.zip

0 commit comments

Comments
 (0)