
Commit 3af9d0c: Support MMPOSE RTMO

Parent: e81a8c7

File tree: 12 files changed, +638 −27 lines

README.MD

Lines changed: 6 additions & 3 deletions

```diff
@@ -16,11 +16,13 @@ It currently focuses on two types of pipeline:
 
 ## Features
 
-- [Detectron2](https://detectron2.readthedocs.io/en/latest/index.html) is used for object detection (Faster-RCNN) and instance segmentation (Mask-RCNN)
+- [Detectron2](https://detectron2.readthedocs.io/en/latest/index.html) for Object Detection (Faster-RCNN) and Instance Segmentation (Mask-RCNN)
 
-- [JDE](https://github.com/Zhongdao/Towards-Realtime-MOT) is used for Object Tracking
+- [JDE](https://github.com/Zhongdao/Towards-Realtime-MOT) for Object Tracking
 
-- [YOLOX-Darknet53](https://github.com/Megvii-BaseDetection/YOLOX) is used for object detection
+- [YOLOX-Darknet53](https://github.com/Megvii-BaseDetection/YOLOX) for Object Detection
+
+- [MMPOSE RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) for Pose Estimation (Bottom Up)
 
 ## Documentation
 
@@ -139,3 +141,4 @@ Fabien Racapé, Hyomin Choi, Eimran Eimon, Sampsa Riikonen, Jacky Yat-Hong Lam
 * [Detectron2](https://detectron2.readthedocs.io/en/latest/index.html)
 * [JDE](https://github.com/Zhongdao/Towards-Realtime-MOT.git)
 * [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
+* [MMPOSE RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo)
```

cfgs/evaluator/default.yaml

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,4 +1,4 @@
 type: "COCO-EVAL"
-output: "${pipeline.evaluation.evaluation_dir}"
+output_dir: "${pipeline.evaluation.evaluation_dir}"
 overwrite_results: False
-eval_criteria: ""
+eval_criteria: ""
```
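This rename matters in combination with the config.py change below: every key in this file is now forwarded verbatim as keyword arguments to the evaluator constructor, so the YAML key has to match the parameter name (`output_dir`). A minimal sketch of that flow using OmegaConf; the literal values here are placeholders:

```python
from omegaconf import OmegaConf

# Placeholder values standing in for the resolved evaluator config.
conf = OmegaConf.create(
    {
        "type": "COCO-EVAL",
        "output_dir": "/tmp/eval",
        "overwrite_results": False,
        "eval_criteria": "",
    }
)

# dict(conf) turns the DictConfig into plain keyword arguments, so each
# YAML key must match a constructor parameter name (hence output -> output_dir).
print(dict(conf))
# {'type': 'COCO-EVAL', 'output_dir': '/tmp/eval', 'overwrite_results': False, 'eval_criteria': ''}
```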

cfgs/vision_model/default.yaml

Lines changed: 7 additions & 1 deletion

```diff
@@ -47,4 +47,10 @@ yolox_darknet53:
   nms_thres: 0.65
   weights: "weights/yolox/darknet53/yolox_darknet.pth"
   splits: "l13" #"l37"
-  squeeze_at_split: False
+  squeeze_at_split: False
+
+rtmo_multi_person_pose_estimation:
+  model_path_prefix: ${..model_root_path}
+  cfg: "models/mmpose/configs/body_2d_keypoint/rtmo/coco/rtmo-l_16xb16-600e_coco-640x640.py"
+  weights: "weights/mmpose/rtmo_coco/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth"
+  splits: "backbone"
```
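A minimal sketch of how the new entry resolves, assuming the file also defines a `model_root_path` key at the parent level, which is what the relative interpolation `${..model_root_path}` points at (that key is not visible in this hunk):

```python
from omegaconf import OmegaConf

conf = OmegaConf.load("cfgs/vision_model/default.yaml")
rtmo = conf.rtmo_multi_person_pose_estimation

# ${..model_root_path} is an OmegaConf relative interpolation: it resolves
# against the parent node, i.e. a model_root_path key assumed to live
# alongside the per-model entries.
print(rtmo.model_path_prefix)
print(rtmo.splits)  # "backbone": the RTMO model is split at its backbone output
```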

compressai_vision/config/config.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -122,9 +122,7 @@ def create_evaluator(
     if conf.type is None:
         return None
 
-    return EVALUATORS[conf.type](
-        catalog, datasetname, dataset, conf.output, conf.eval_criteria
-    )
+    return EVALUATORS[conf.type](catalog, datasetname, dataset, **dict(conf))
 
 
 def create_pipline(conf: DictConfig, device: DictConfig):
```
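With `**dict(conf)`, the evaluator constructor now receives every key from the evaluator YAML, including `type`, as a keyword argument. A hedged sketch of a compatible signature follows; the parameter names track the YAML keys, but the actual evaluator classes registered in EVALUATORS may differ:

```python
class CocoEvalSketch:  # hypothetical stand-in for EVALUATORS["COCO-EVAL"]
    def __init__(
        self,
        datacatalog,
        dataset_name,
        dataset,
        type=None,  # forwarded too, since **dict(conf) passes every key
        output_dir=".",
        overwrite_results=False,
        eval_criteria="",
        **kwargs,  # tolerate config keys added later
    ):
        self.output_dir = output_dir
        self.overwrite_results = overwrite_results
        self.eval_criteria = eval_criteria
```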

compressai_vision/datasets/image.py

Lines changed: 43 additions & 1 deletion

```diff
@@ -47,7 +47,7 @@
 
 from compressai_vision.registry import register_datacatalog, register_dataset
 
-from .utils import JDECustomMapper, LinearMapper, YOLOXCustomMapper
+from .utils import JDECustomMapper, LinearMapper, MMPOSECustomMapper, YOLOXCustomMapper
 
 
 def manual_load_data(path, ext):
@@ -337,6 +337,48 @@ def __len__(self):
         return len(self.mapDataset)
 
 
+@register_dataset("MMPOSEDataset")
+class MMPOSEDataset(BaseDataset):
+    def __init__(self, root, dataset_name, imgs_folder, **kwargs):
+        super().__init__(root, dataset_name, imgs_folder, **kwargs)
+
+        self.dataset = kwargs["dataset"].dataset
+
+        self.sampler = InferenceSampler(len(kwargs["dataset"]))
+        self.collate_fn = bypass_collator
+
+        _dataset = DatasetFromList(self.dataset, copy=False)
+
+        if kwargs["linear_mapper"] is True:
+            mapper = LinearMapper()
+        else:
+            mapper = MMPOSECustomMapper(kwargs["patch_size"])
+
+        self.input_size = kwargs["patch_size"]
+        self.mapDataset = MapDataset(_dataset, mapper)
+        self._org_mapper_func = PicklableWrapper(
+            MMPOSECustomMapper(kwargs["patch_size"])
+        )
+
+        metaData = MetadataCatalog.get(dataset_name)
+        try:
+            self.thing_classes = metaData.thing_classes
+            self.thing_dataset_id_to_contiguous_id = (
+                metaData.thing_dataset_id_to_contiguous_id
+            )
+        except AttributeError:
+            self.logger.warning("No attribute: thing_classes")
+
+    def get_org_mapper_func(self):
+        return self._org_mapper_func
+
+    def __getitem__(self, idx):
+        return self.mapDataset[idx]
+
+    def __len__(self):
+        return len(self.mapDataset)
+
+
 class DataCatalog:
     def __init__(
         self,
```
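The class mirrors the existing YOLOX/JDE dataset wrappers: detectron2's `DatasetFromList` and `MapDataset` apply the mapper lazily, per item. A sketch of direct use with hypothetical names; note that the `dataset` kwarg must expose a `.dataset` list of per-image dicts, since the constructor above reads `kwargs["dataset"].dataset`:

```python
dataset = MMPOSEDataset(
    root="/data/coco",             # hypothetical paths and names
    dataset_name="coco-val-rtmo",
    imgs_folder="images",
    dataset=base_dataset,          # object exposing .dataset (list of dicts)
    linear_mapper=False,           # use MMPOSECustomMapper, not LinearMapper
    patch_size=[640, 640],         # becomes the mapper's img_size
)
sample = dataset[0]  # mapper runs here; returns a dict with an "image" tensor
```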

compressai_vision/datasets/utils.py

Lines changed: 108 additions & 1 deletion

```diff
@@ -35,9 +35,10 @@
 import numpy as np
 import torch
 from jde.utils.datasets import letterbox
+from mmpose.structures.bbox import get_warp_matrix
 from torchvision import transforms
 
-__all__ = ["YOLOXCustomMapper", "JDECustomMapper", "LinearMapper"]
+__all__ = ["MMPOSECustomMapper", "YOLOXCustomMapper", "JDECustomMapper", "LinearMapper"]
 
 
 def yolox_style_scaling(img, input_size, padding=False):
@@ -58,6 +59,112 @@ def yolox_style_scaling(img, input_size, padding=False):
     return resized_img
 
 
+class MMPOSECustomMapper:
+    """
+    A callable which takes a dataset dict in the CompressAI-Vision generic dataset format, for MMPOSE (in particular, RTMO) evaluation,
+    and maps it into the format used by the model.
+
+    This is the default callable used to map a dataset dict into inference data.
+
+    It follows the preproc function at
+    <https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/datasets/transforms/bottomup_transforms.py>
+
+    The full license statement can be found at
+    <https://github.com/open-mmlab/mmpose?tab=Apache-2.0-1-ov-file#readme>
+    """
+
+    def __init__(
+        self,
+        img_size=[640, 640],
+        size_factor=32,
+        pad_val=[114, 114, 114],
+        aug_transforms=None,
+    ):
+        """
+        Args:
+            img_size: expected input size (height, width)
+        """
+
+        self.input_img_size = img_size
+        self.pad_val = pad_val
+        assert img_size[0] % size_factor == 0 and img_size[1] % size_factor == 0
+
+        if aug_transforms is not None:
+            self.aug_transforms = aug_transforms
+        else:
+            self.aug_transforms = transforms.Compose([transforms.ToTensor()])
+
+    def compute_scale_and_center(self, src_img_width, src_img_height):
+        _input_h, _input_w = self.input_img_size
+        _ratio = src_img_width / src_img_height
+        _scaled_input_w = min(_input_w, _input_h * _ratio)
+        _scaled_input_h = min(_input_h, _input_w / _ratio)
+
+        center = np.array([src_img_width / 2, src_img_height / 2], dtype=np.float32)
+        scale = np.array(
+            [
+                src_img_width * _input_w / _scaled_input_w,
+                src_img_height * _input_h / _scaled_input_h,
+            ],
+            dtype=np.float32,
+        )
+
+        return scale, center
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): metadata of one image
+
+        Returns:
+            dict: a format that CompressAI-Vision pipelines accept
+        """
+
+        dataset_dict = copy.deepcopy(dataset_dict)
+        # the copied dictionary will be modified by the code below
+
+        dataset_dict.pop("annotations", None)
+
+        # replicate the preprocessing of the original implementation
+        # read image
+        org_img = cv2.imread(dataset_dict["file_name"])  # returns BGR by default
+
+        assert (
+            len(org_img.shape) == 3
+        ), f"detected an input image with 2 channels: {dataset_dict['file_name']}"
+
+        img_h, img_w, _ = org_img.shape
+
+        dataset_dict["height"] = img_h
+        dataset_dict["width"] = img_w
+
+        _input_h, _input_w = self.input_img_size
+        # mmpose-style scaling
+        scale, center = self.compute_scale_and_center(img_w, img_h)
+
+        warp_mat = get_warp_matrix(
+            center=center, scale=scale, rot=0, output_size=(_input_w, _input_h)
+        )
+
+        resized_img = cv2.warpAffine(
+            org_img,
+            warp_mat,
+            (_input_w, _input_h),
+            flags=cv2.INTER_LINEAR,
+            borderValue=self.pad_val,
+        )
+
+        tensor_image = self.aug_transforms(
+            np.ascontiguousarray(resized_img, dtype=np.float32)
+        )
+
+        dataset_dict["image"] = tensor_image
+
+        return dataset_dict
+
+
 class YOLOXCustomMapper:
     """
     A callable which takes a dataset dict in CompressAI-Vision generic dataset format, but for YOLOX evaluation,
```
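To make the scaling concrete, here is a worked example for a 1920×1080 source at the default 640×640 input; the numbers follow directly from `compute_scale_and_center` above:

```python
mapper = MMPOSECustomMapper(img_size=[640, 640])
scale, center = mapper.compute_scale_and_center(1920, 1080)

# _ratio = 1920 / 1080 ≈ 1.778
# _scaled_input_w = min(640, 640 * 1.778) = 640
# _scaled_input_h = min(640, 640 / 1.778) = 360
# center == [960., 540.]
# scale  == [1920 * 640 / 640, 1080 * 640 / 360] == [1920., 1920.]
#
# The warp therefore maps a 1920x1920 square centered on the image onto
# 640x640: the aspect ratio is preserved, and the rows above and below the
# 1080 source rows come out as pad_val letterboxing.
```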
