|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +from typing import Optional, Sequence |
| 3 | + |
| 4 | +import torch |
| 5 | +from mmengine.registry import MODELS |
| 6 | +from torch import Tensor, nn |
| 7 | + |
| 8 | +from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE |
| 9 | +from mmdet3d.utils import OptMultiConfig |
| 10 | +from .minkunet_backbone import MinkUNetBackbone |
| 11 | + |
| 12 | +if IS_TORCHSPARSE_AVAILABLE: |
| 13 | + import torchsparse |
| 14 | + import torchsparse.nn.functional as F |
| 15 | + from torchsparse.nn.utils import get_kernel_offsets |
| 16 | + from torchsparse.tensor import PointTensor, SparseTensor |
| 17 | +else: |
| 18 | + PointTensor = SparseTensor = None |
| 19 | + |
| 20 | + |
| 21 | +@MODELS.register_module() |
| 22 | +class SPVCNNBackbone(MinkUNetBackbone): |
| 23 | + """SPVCNN backbone with torchsparse backend. |
| 24 | +
|
| 25 | + More details can be found in `paper <https://arxiv.org/abs/2007.16100>`_ . |
| 26 | +
|
| 27 | + Args: |
| 28 | + in_channels (int): Number of input voxel feature channels. |
| 29 | + Defaults to 4. |
| 30 | + base_channels (int): The input channels for first encoder layer. |
| 31 | + Defaults to 32. |
| 32 | + encoder_channels (List[int]): Convolutional channels of each encode |
| 33 | + layer. Defaults to [32, 64, 128, 256]. |
| 34 | + decoder_channels (List[int]): Convolutional channels of each decode |
| 35 | + layer. Defaults to [256, 128, 96, 96]. |
| 36 | + num_stages (int): Number of stages in encoder and decoder. |
| 37 | + Defaults to 4. |
| 38 | + drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3. |
| 39 | + init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`] |
| 40 | + , optional): Initialization config dict. Defaults to None. |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__(self, |
| 44 | + in_channels: int = 4, |
| 45 | + base_channels: int = 32, |
| 46 | + encoder_channels: Sequence[int] = [32, 64, 128, 256], |
| 47 | + decoder_channels: Sequence[int] = [256, 128, 96, 96], |
| 48 | + num_stages: int = 4, |
| 49 | + drop_ratio: float = 0.3, |
| 50 | + init_cfg: OptMultiConfig = None) -> None: |
| 51 | + super().__init__( |
| 52 | + in_channels=in_channels, |
| 53 | + base_channels=base_channels, |
| 54 | + encoder_channels=encoder_channels, |
| 55 | + decoder_channels=decoder_channels, |
| 56 | + num_stages=num_stages, |
| 57 | + init_cfg=init_cfg) |
| 58 | + |
| 59 | + self.point_transforms = nn.ModuleList([ |
| 60 | + nn.Sequential( |
| 61 | + nn.Linear(base_channels, encoder_channels[-1]), |
| 62 | + nn.BatchNorm1d(encoder_channels[-1]), nn.ReLU(True)), |
| 63 | + nn.Sequential( |
| 64 | + nn.Linear(encoder_channels[-1], decoder_channels[2]), |
| 65 | + nn.BatchNorm1d(decoder_channels[2]), nn.ReLU(True)), |
| 66 | + nn.Sequential( |
| 67 | + nn.Linear(decoder_channels[2], decoder_channels[4]), |
| 68 | + nn.BatchNorm1d(decoder_channels[4]), nn.ReLU(True)) |
| 69 | + ]) |
| 70 | + self.dropout = nn.Dropout(drop_ratio, True) |
| 71 | + |
| 72 | + def forward(self, voxel_features: Tensor, coors: Tensor) -> PointTensor: |
| 73 | + """Forward function. |
| 74 | +
|
| 75 | + Args: |
| 76 | + voxel_features (Tensor): Voxel features in shape (N, C). |
| 77 | + coors (Tensor): Coordinates in shape (N, 4), |
| 78 | + the columns in the order of (x_idx, y_idx, z_idx, batch_idx). |
| 79 | +
|
| 80 | + Returns: |
| 81 | + PointTensor: Backbone features. |
| 82 | + """ |
| 83 | + voxels = SparseTensor(voxel_features, coors) |
| 84 | + points = PointTensor(voxels.F, voxels.C.float()) |
| 85 | + voxels = self.initial_voxelize(points) |
| 86 | + |
| 87 | + voxels = self.conv_input(voxels) |
| 88 | + points = self.voxel_to_point(voxels, points) |
| 89 | + voxels = self.point_to_voxel(voxels, points) |
| 90 | + laterals = [voxels] |
| 91 | + for encoder in self.encoder: |
| 92 | + voxels = encoder(voxels) |
| 93 | + laterals.append(voxels) |
| 94 | + laterals = laterals[:-1][::-1] |
| 95 | + |
| 96 | + points = self.voxel_to_point(voxels, points, self.point_transforms[0]) |
| 97 | + voxels = self.point_to_voxel(voxels, points) |
| 98 | + voxels.F = self.dropout(voxels.F) |
| 99 | + |
| 100 | + decoder_outs = [] |
| 101 | + for i, decoder in enumerate(self.decoder): |
| 102 | + voxels = decoder[0](voxels) |
| 103 | + voxels = torchsparse.cat((voxels, laterals[i])) |
| 104 | + voxels = decoder[1](voxels) |
| 105 | + decoder_outs.append(voxels) |
| 106 | + if i == 1: |
| 107 | + points = self.voxel_to_point(voxels, points, |
| 108 | + self.point_transforms[1]) |
| 109 | + voxels = self.point_to_voxel(voxels, points) |
| 110 | + voxels.F = self.dropout(voxels.F) |
| 111 | + |
| 112 | + points = self.voxel_to_point(voxels, points, self.point_transforms[2]) |
| 113 | + return points |
| 114 | + |
| 115 | + def initial_voxelize(self, points: PointTensor) -> SparseTensor: |
| 116 | + """Voxelization again based on input PointTensor. |
| 117 | +
|
| 118 | + Args: |
| 119 | + points (PointTensor): Input points after voxelization. |
| 120 | +
|
| 121 | + Returns: |
| 122 | + SparseTensor: New voxels. |
| 123 | + """ |
| 124 | + pc_hash = F.sphash(torch.floor(points.C).int()) |
| 125 | + sparse_hash = torch.unique(pc_hash) |
| 126 | + idx_query = F.sphashquery(pc_hash, sparse_hash) |
| 127 | + counts = F.spcount(idx_query.int(), len(sparse_hash)) |
| 128 | + |
| 129 | + inserted_coords = F.spvoxelize( |
| 130 | + torch.floor(points.C), idx_query, counts) |
| 131 | + inserted_coords = torch.round(inserted_coords).int() |
| 132 | + inserted_feat = F.spvoxelize(points.F, idx_query, counts) |
| 133 | + |
| 134 | + new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) |
| 135 | + new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) |
| 136 | + points.additional_features['idx_query'][1] = idx_query |
| 137 | + points.additional_features['counts'][1] = counts |
| 138 | + return new_tensor |
| 139 | + |
| 140 | + def voxel_to_point(self, |
| 141 | + voxels: SparseTensor, |
| 142 | + points: PointTensor, |
| 143 | + point_transform: Optional[nn.Module] = None, |
| 144 | + nearest: bool = False) -> PointTensor: |
| 145 | + """Feed voxel features to points. |
| 146 | +
|
| 147 | + Args: |
| 148 | + voxels (SparseTensor): Input voxels. |
| 149 | + points (PointTensor): Input points. |
| 150 | + point_transform (nn.Module, optional): Point transform module |
| 151 | + for input point features. Defaults to None. |
| 152 | + nearest (bool): Whether to use nearest neighbor interpolation. |
| 153 | + Defaults to False. |
| 154 | +
|
| 155 | + Returns: |
| 156 | + PointTensor: Points with new features. |
| 157 | + """ |
| 158 | + if points.idx_query is None or points.weights is None or \ |
| 159 | + points.idx_query.get(voxels.s) is None or \ |
| 160 | + points.weights.get(voxels.s) is None: |
| 161 | + offsets = get_kernel_offsets( |
| 162 | + 2, voxels.s, 1, device=points.F.device) |
| 163 | + old_hash = F.sphash( |
| 164 | + torch.cat([ |
| 165 | + torch.floor(points.C[:, :3] / voxels.s[0]).int() * |
| 166 | + voxels.s[0], points.C[:, -1].int().view(-1, 1) |
| 167 | + ], 1), offsets) |
| 168 | + pc_hash = F.sphash(voxels.C.to(points.F.device)) |
| 169 | + idx_query = F.sphashquery(old_hash, pc_hash) |
| 170 | + weights = F.calc_ti_weights( |
| 171 | + points.C, idx_query, |
| 172 | + scale=voxels.s[0]).transpose(0, 1).contiguous() |
| 173 | + idx_query = idx_query.transpose(0, 1).contiguous() |
| 174 | + if nearest: |
| 175 | + weights[:, 1:] = 0. |
| 176 | + idx_query[:, 1:] = -1 |
| 177 | + new_features = F.spdevoxelize(voxels.F, idx_query, weights) |
| 178 | + new_tensor = PointTensor( |
| 179 | + new_features, |
| 180 | + points.C, |
| 181 | + idx_query=points.idx_query, |
| 182 | + weights=points.weights) |
| 183 | + new_tensor.additional_features = points.additional_features |
| 184 | + new_tensor.idx_query[voxels.s] = idx_query |
| 185 | + new_tensor.weights[voxels.s] = weights |
| 186 | + points.idx_query[voxels.s] = idx_query |
| 187 | + points.weights[voxels.s] = weights |
| 188 | + else: |
| 189 | + new_features = F.spdevoxelize(voxels.F, |
| 190 | + points.idx_query.get(voxels.s), |
| 191 | + points.weights.get(voxels.s)) |
| 192 | + new_tensor = PointTensor( |
| 193 | + new_features, |
| 194 | + points.C, |
| 195 | + idx_query=points.idx_query, |
| 196 | + weights=points.weights) |
| 197 | + new_tensor.additional_features = points.additional_features |
| 198 | + |
| 199 | + if point_transform is not None: |
| 200 | + new_tensor.F = new_tensor.F + point_transform(points.F) |
| 201 | + |
| 202 | + return new_tensor |
| 203 | + |
| 204 | + def point_to_voxel(self, voxels: SparseTensor, |
| 205 | + points: PointTensor) -> SparseTensor: |
| 206 | + """Feed point features to voxels. |
| 207 | +
|
| 208 | + Args: |
| 209 | + voxels (SparseTensor): Input voxels. |
| 210 | + points (PointTensor): Input points. |
| 211 | +
|
| 212 | + Returns: |
| 213 | + SparseTensor: Voxels with new features. |
| 214 | + """ |
| 215 | + if points.additional_features is None or \ |
| 216 | + points.additional_features.get('idx_query') is None or \ |
| 217 | + points.additional_features['idx_query'].get(voxels.s) is None: |
| 218 | + pc_hash = F.sphash( |
| 219 | + torch.cat([ |
| 220 | + torch.floor(points.C[:, :3] / voxels.s[0]).int() * |
| 221 | + voxels.s[0], points.C[:, -1].int().view(-1, 1) |
| 222 | + ], 1)) |
| 223 | + sparse_hash = F.sphash(voxels.C) |
| 224 | + idx_query = F.sphashquery(pc_hash, sparse_hash) |
| 225 | + counts = F.spcount(idx_query.int(), voxels.C.shape[0]) |
| 226 | + points.additional_features['idx_query'][voxels.s] = idx_query |
| 227 | + points.additional_features['counts'][voxels.s] = counts |
| 228 | + else: |
| 229 | + idx_query = points.additional_features['idx_query'][voxels.s] |
| 230 | + counts = points.additional_features['counts'][voxels.s] |
| 231 | + |
| 232 | + inserted_features = F.spvoxelize(points.F, idx_query, counts) |
| 233 | + new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s) |
| 234 | + new_tensor.cmaps = voxels.cmaps |
| 235 | + new_tensor.kmaps = voxels.kmaps |
| 236 | + |
| 237 | + return new_tensor |
0 commit comments