# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData

from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_3d_inferencer import Base3DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module(name='det3d-multi_modality')
@INFERENCERS.register_module()
class MultiModalityDet3DInferencer(Base3DInferencer):
    """The inferencer of multi-modality detection.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "pointpillars_kitti-3class" or
            "configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py". # noqa: E501
            If model is not specified, the user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and model is a model name defined in metafile, the
            weights will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str): The scope of the registry. Defaults to 'mmdet3d'.
        palette (str): The palette of visualization. Defaults to 'none'.
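
    Examples:
        A minimal usage sketch, assuming the class is exported from
        ``mmdet3d.apis`` like the other inferencers; the config and
        checkpoint paths are placeholders, and some models may also
        expect calibration info in the input dict:

        >>> from mmdet3d.apis import MultiModalityDet3DInferencer
        >>> inferencer = MultiModalityDet3DInferencer(
        ...     model='path/to/multi-modality_config.py',
        ...     weights='path/to/checkpoint.pth')
        >>> results = inferencer(dict(points='frame.bin', img='frame.png'))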
    """

    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
        'img_out_dir'
    }
    postprocess_kwargs: set = {
        'print_result', 'pred_out_file', 'return_datasample'
    }

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmdet3d',
                 palette: str = 'none') -> None:
        # A global counter tracking the number of frames processed, for
        # naming the output results
        self.num_visualized_frames = 0
        super().__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            palette=palette)

    def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to their type:

        - list or tuple: return inputs as-is.
        - dict: if the value with key 'points' is a directory path,
          return all files in the directory; in other cases, return a
          list containing the dict. The string values could be paths to
          files, URLs or other types of string according to the task.

        Args:
            inputs (Union[dict, list]): Inputs for the inferencer.

        Returns:
            list: List of inputs for the :meth:`preprocess`.
        """
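        # ``modality_key=['points', 'img']`` asks the base implementation
        # to normalize both the point-cloud and the image entries of each
        # input dict, instead of the default 'points' key alone.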
        return super()._inputs_to_list(inputs, modality_key=['points', 'img'])

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        load_point_idx = self._get_transform_idx(pipeline_cfg,
                                                 'LoadPointsFromFile')
        load_mv_img_idx = self._get_transform_idx(
            pipeline_cfg, 'LoadMultiViewImageFromFiles')
        if load_mv_img_idx != -1:
            warnings.warn(
                'LoadMultiViewImageFromFiles is not supported yet in the '
                'multi-modality inferencer. Please remove it from the '
                'pipeline.')
        # Currently, we only support ``LoadImageFromFile`` as the image
        # loader in the original pipeline. ``LoadMultiViewImageFromFiles``
        # is not supported yet.
        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')

        if load_point_idx == -1 or load_img_idx == -1:
            raise ValueError(
                'Both LoadPointsFromFile and LoadImageFromFile must '
                'be specified in the pipeline, but got '
                f'LoadPointsFromFile index {load_point_idx} and '
                f'LoadImageFromFile index {load_img_idx} '
                '(-1 means the transform is missing).')

        load_cfg = pipeline_cfg[load_point_idx]
        self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
            'load_dim']
        self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
            load_cfg['use_dim'], int) else load_cfg['use_dim']

        load_point_args = pipeline_cfg[load_point_idx]
        load_point_args.pop('type')
        load_img_args = pipeline_cfg[load_img_idx]
        load_img_args.pop('type')

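        # Fuse the two load transforms into a single inferencer-side loader
        # placed at the earlier of their two positions, so the relative
        # order of the remaining transforms is preserved.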
        load_idx = min(load_point_idx, load_img_idx)
        pipeline_cfg.pop(max(load_point_idx, load_img_idx))

        pipeline_cfg[load_idx] = dict(
            type='MultiModalityDet3DInferencerLoader',
            load_point_args=load_point_args,
            load_img_args=load_img_args)

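        # For example, with illustrative transform names, a test pipeline of
        #   [LoadPointsFromFile, LoadImageFromFile, Resize, Pack3DDetInputs]
        # becomes
        #   [MultiModalityDet3DInferencerLoader, Resize, Pack3DDetInputs].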
        return Compose(pipeline_cfg)

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: float = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            preds (PredType): Predictions of the model.
            return_vis (bool): Whether to return the visualization results.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.
        """
        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

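        # Pair each input frame with its prediction. Points and image given
        # as paths are read through the file backend, while in-memory arrays
        # are named using the global frame counter.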
        for single_input, pred in zip(inputs, preds):
            points_input = single_input['points']
            if isinstance(points_input, str):
                pts_bytes = mmengine.fileio.get(points_input)
                points = np.frombuffer(pts_bytes, dtype=np.float32)
                points = points.reshape(-1, self.load_dim)
                points = points[:, self.use_dim]
                pc_name = osp.basename(points_input).split('.bin')[0]
                pc_name = f'{pc_name}.png'
            elif isinstance(points_input, np.ndarray):
                points = points_input.copy()
                pc_num = str(self.num_visualized_frames).zfill(8)
                pc_name = f'pc_{pc_num}.png'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(points_input)}')

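            # The rendered point cloud is saved through Open3D via
            # ``o3d_save_path``; the projected image view is saved via
            # ``out_file`` below. Neither is written when img_out_dir is ''.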
            o3d_save_path = osp.join(img_out_dir, pc_name) \
                if img_out_dir != '' else None

            img_input = single_input['img']
            if isinstance(img_input, str):
                img_bytes = mmengine.fileio.get(img_input)
                img = mmcv.imfrombytes(img_bytes)
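                # mmcv decodes images to BGR by default; reverse the channel
                # axis to get RGB for the visualizer.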
                img = img[:, :, ::-1]
                img_name = osp.basename(img_input)
            elif isinstance(img_input, np.ndarray):
                img = img_input.copy()
                img_num = str(self.num_visualized_frames).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(img_input)}')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            data_input = dict(points=points, img=img)
            self.visualizer.add_datasample(
                pc_name,
                data_input,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                o3d_save_path=o3d_save_path,
                out_file=out_file,
                vis_task='multi-modality_det',
            )
            results.append(points)
            self.num_visualized_frames += 1

        return results