diff --git a/pcdet/models/backbones_3d/__init__.py b/pcdet/models/backbones_3d/__init__.py index 0a25c626a..dceb5c3b4 100644 --- a/pcdet/models/backbones_3d/__init__.py +++ b/pcdet/models/backbones_3d/__init__.py @@ -6,7 +6,7 @@ from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D from .spconv_unet import UNetV2 from .dsvt import DSVT - +from .voxel_mamba import Voxel_Mamba __all__ = { 'VoxelBackBone8x': VoxelBackBone8x, 'UNetV2': UNetV2, @@ -19,4 +19,5 @@ 'PillarBackBone8x': PillarBackBone8x, 'PillarRes18BackBone8x': PillarRes18BackBone8x, 'DSVT': DSVT, + 'Voxel_Mamba': Voxel_Mamba, } diff --git a/pcdet/models/backbones_3d/voxel_mamba.py b/pcdet/models/backbones_3d/voxel_mamba.py new file mode 100644 index 000000000..96a7c90fc --- /dev/null +++ b/pcdet/models/backbones_3d/voxel_mamba.py @@ -0,0 +1,449 @@ +import torch +import torch.nn as nn + +import math +from functools import partial +from mamba_ssm.models.mixer_seq_simple import create_block +from ..model_utils.voxel_mamba_utils import get_hilbert_index_3d_mamba_lite + +# try: +# from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +# except ImportError: +# RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + +from ...utils.spconv_utils import replace_feature, spconv +from .spconv_backbone import post_act_block + + +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. + rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class Voxel_Mamba(nn.Module): + '''Group-free Voxel Mamba Backbone. 
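+    Serializes the sparse voxels of the whole scene along Hilbert space-filling
+    curves (no window/group partitioning) and processes the resulting sequences
+    with a stack of dual-scale state-space (DSB) blocks defined below.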
+    '''
+    def __init__(self, model_cfg, grid_size, **kwargs):
+        super().__init__()
+
+        self.model_cfg = model_cfg
+        # self.hilbert_input_layer = HilbertCurveInputLayer(self.model_cfg.INPUT_LAYER)
+
+        num_stage = self.model_cfg.num_stage
+        self.num_stage = num_stage
+        self.d_model = self.model_cfg.d_model
+        self.rms_norm = self.model_cfg.rms_norm
+        self.norm_epsilon = self.model_cfg.norm_epsilon
+        self.fused_add_norm = self.model_cfg.fused_add_norm
+        self.device = self.model_cfg.device
+        self.residual_in_fp32 = self.model_cfg.residual_in_fp32
+        self.extra_down = self.model_cfg.extra_down
+        self.dtype = torch.float32
+        initializer_cfg = None
+
+        # for downsampling
+        self.down_kernel_size = self.model_cfg.down_kernel_size
+        self.down_stride = self.model_cfg.down_stride
+        self.num_down = self.model_cfg.num_down
+        self.down_resolution = self.model_cfg.down_resolution
+        self.downsample_lvl = self.model_cfg.downsample_lvl
+        self.norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
+        self.sparse_shape = grid_size[::-1] + [1, 0, 0]
+
+        # Build Hilbert templates
+        self.curve_template = {}
+        self.hilbert_spatial_size = {}
+        self.load_template(self.model_cfg.INPUT_LAYER.curve_template_path_rank9, 9)
+        self.load_template(self.model_cfg.INPUT_LAYER.curve_template_path_rank8, 8)
+        self.load_template(self.model_cfg.INPUT_LAYER.curve_template_path_rank7, 7)
+
+        factory_kwargs = {"device": self.device, "dtype": self.dtype}
+
+        block_list = []
+        for i, num_s in enumerate(num_stage):
+            for ns in range(num_s):
+                block_list.append(
+                    DSB(self.d_model, ssm_cfg=None, norm_epsilon=self.norm_epsilon, rms_norm=self.rms_norm,
+                        down_kernel_size=self.down_kernel_size[i], down_stride=self.down_stride[i], num_down=self.num_down[i],
+                        norm_fn=self.norm_fn, indice_key=f'stem{i}_layer{ns}', sparse_shape=self.sparse_shape, hilbert_config=self.model_cfg.INPUT_LAYER,
+                        downsample_lvl=self.downsample_lvl[i],
+                        down_resolution=self.down_resolution[i], residual_in_fp32=True, fused_add_norm=self.fused_add_norm,
+                        device=self.device, dtype=self.dtype)
+                )
+        self.block_list = nn.ModuleList(block_list)
+
+        downZ_list = []
+        for i in range(len(num_stage)):
+            downZ_list.append(
+                spconv.SparseSequential(
+                    spconv.SparseConv3d(self.d_model, self.d_model, (3, 1, 1), stride=(2, 1, 1), padding=0, bias=False, indice_key=f'downz_{i}'),
+                    self.norm_fn(self.d_model),
+                    nn.ReLU(),)
+            )
+        self.downZ_list = nn.ModuleList(downZ_list)
+
+        self.conv_out = spconv.SparseSequential(
+            spconv.SparseConv3d(self.d_model, self.d_model, (3, 1, 1), stride=(2, 1, 1), padding=0, bias=False, indice_key='final_conv_out'),
+            self.norm_fn(self.d_model),
+            nn.ReLU(),)
+
+        self.pos_embed = nn.Sequential(
+            nn.Linear(9, self.d_model),
+            nn.BatchNorm1d(self.d_model),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.d_model, self.d_model),
+        )
+
+        self._reset_parameters()
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=sum(num_stage),
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+
+        self.output_shape = self.model_cfg.output_shape
+        self.num_point_features = self.model_cfg.conv_out_channel
+
+    def load_template(self, path, rank):
+        template = torch.load(path)
+        if isinstance(template, dict):
+            self.curve_template[f'curve_template_rank{rank}'] = template['data'].reshape(-1)
+            self.hilbert_spatial_size[f'curve_template_rank{rank}'] = template['size']
+        else:
+            self.curve_template[f'curve_template_rank{rank}'] = template.reshape(-1)
+            spatial_size = 2 ** rank
+            self.hilbert_spatial_size[f'curve_template_rank{rank}'] = (1, spatial_size, spatial_size)  # [z, y, x]
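+    # Note on templates: each curve template is expected to be a flat LongTensor
+    # of length z * y * x (see hilbert_spatial_size) that maps a raster-order
+    # cell id (z * H_y * H_x + y * H_x + x) to its position along the Hilbert
+    # curve; see get_hilbert_index_3d_mamba_lite and
+    # tools/process_tools/create_hilbert_curve.py.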
+
+    def forward(self, batch_dict):
+        '''
+        Args:
+            batch_dict (dict):
+                The dict contains the following keys
+                - voxel_features (Tensor[float]): Voxel features after VFE. Shape of (N, d_model[0]),
+                    where N is the number of input voxels.
+                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxel.
+                    Each row is (batch_id, z, y, x).
+                - ...
+
+        Returns:
+            batch_dict (dict):
+                The dict contains the following keys
+                - pillar_features (Tensor[float]):
+                - voxel_coords (Tensor[int]):
+                - ...
+        '''
+
+        # with self.timer.timing('3d_backbone'):
+        debug = False
+        batch_size = batch_dict['voxel_coords'][:, 0].max().item() + 1
+        feat_3d = batch_dict['voxel_features']
+        voxel_coords = batch_dict['voxel_coords']
+        with torch.no_grad():
+            for name, _ in self.curve_template.items():
+                self.curve_template[name] = self.curve_template[name].to(voxel_coords.device)
+
+        down_sparse_shape = self.sparse_shape
+        for i, block in enumerate(self.block_list):
+
+            feat_3d, voxel_coords = block(feat_3d, voxel_coords, batch_size, down_sparse_shape, self.curve_template, self.hilbert_spatial_size, self.pos_embed, i, debug)
+
+            if (i > 0) and (i % 2 == 1):
+                xd = spconv.SparseConvTensor(
+                    features=feat_3d,
+                    indices=voxel_coords.int(),
+                    spatial_shape=down_sparse_shape,
+                    batch_size=batch_size
+                )
+
+                if i == self.extra_down:
+                    xd = self.conv_out(xd)
+
+                xd = self.downZ_list[i // 2](xd)
+
+                feat_3d = xd.features
+                voxel_coords = xd.indices
+                down_sparse_shape = xd.spatial_shape
+
+        if self.training and torch.isnan(feat_3d).any().item():
+            replacement_value = 0.0
+            feat_3d = torch.where(torch.isnan(feat_3d), replacement_value, feat_3d)
+
+        batch_dict['voxel_coords'] = voxel_coords
+        batch_dict['pillar_features'] = batch_dict['voxel_features'] = feat_3d
+
+        return batch_dict
+
+    def _reset_parameters(self):
+        for name, p in self.named_parameters():
+            if p.dim() > 1 and 'scaler' not in name:
+                nn.init.xavier_uniform_(p)
+
+
+class DSB(nn.Module):
+    ''' Dual-scale State Space Models Block
+    '''
+
+    def __init__(self,
+                 d_model,
+                 ssm_cfg,
+                 norm_epsilon,
+                 rms_norm,
+                 down_kernel_size,
+                 down_stride,
+                 num_down,
+                 norm_fn,
+                 indice_key,
+                 sparse_shape,
+                 hilbert_config,
+                 downsample_lvl,
+                 down_resolution=True,
+                 residual_in_fp32=True,
+                 fused_add_norm=True,
+                 device=None,
+                 dtype=None):
+        super().__init__()
+
+        # ssm_cfg = {}
+        factory_kwargs = {'device': device, 'dtype': dtype}
+
+        # mamba layer
+        mamba_encoder_1 = create_block(
+            d_model=d_model,
+            ssm_cfg=ssm_cfg,
+            norm_epsilon=norm_epsilon,
+            rms_norm=rms_norm,
+            residual_in_fp32=residual_in_fp32,
+            fused_add_norm=fused_add_norm,
+            layer_idx=0,
+            **factory_kwargs,
+        )
+
+        mamba_encoder_2 = create_block(
+            d_model=d_model,
+            ssm_cfg=ssm_cfg,
+            norm_epsilon=norm_epsilon,
+            rms_norm=rms_norm,
+            residual_in_fp32=residual_in_fp32,
+            fused_add_norm=fused_add_norm,
+            layer_idx=1,
+            **factory_kwargs,
+        )
+
+        self.mamba_encoder_list = nn.ModuleList([mamba_encoder_1, mamba_encoder_2])
+
+        # downsampling operation #
+        self.conv_encoder = nn.ModuleList()
+        for idx in range(len(down_stride)):
+            self.conv_encoder.append(
+                DownSp(d_model, down_kernel_size[idx], down_stride[idx], num_down[idx], norm_fn, f"{indice_key}_{idx}"))
+
+        # upsampling operation #
+        downsample_times = len(down_stride[1:])
+        self.conv_decoder = nn.ModuleList()
+        self.conv_decoder_norm = nn.ModuleList()
+        for idx, kernel_size in enumerate(down_kernel_size[1:]):
+            if down_resolution:
+                self.conv_decoder.append(
+                    post_act_block(
+                        d_model, d_model, kernel_size, norm_fn=norm_fn, conv_type='inverseconv',
+                        indice_key=f'spconv_{indice_key}_{downsample_times - idx}'))
+                self.conv_decoder_norm.append(norm_fn(d_model))
+            else:
+                self.conv_decoder.append(
+                    post_act_block(
+                        d_model, d_model, kernel_size, norm_fn=norm_fn, conv_type='subm',
+                        indice_key=f'{indice_key}_{downsample_times - idx}'))
+                self.conv_decoder_norm.append(norm_fn(d_model))
+
+        self.sparse_shape = sparse_shape
+        self.downsample_lvl = downsample_lvl
+
+        norm_cls = partial(
+            nn.LayerNorm, eps=norm_epsilon, **factory_kwargs
+        )
+        self.norm = norm_cls(d_model)
+        self.norm_back = norm_cls(d_model)
+
+    def forward(
+        self,
+        voxel_features,
+        voxel_coords,
+        batch_size,
+        curt_spatial_shape,
+        curve_template,
+        hilbert_spatial_size,
+        pos_embed,
+        num_stage,
+        debug=False,
+    ):
+
+        mamba_layer1 = self.mamba_encoder_list[0]
+        mamba_layer2 = self.mamba_encoder_list[1]
+
+        x = spconv.SparseConvTensor(
+            features=voxel_features,
+            indices=voxel_coords.int(),
+            spatial_shape=curt_spatial_shape,
+            batch_size=batch_size
+        )
+
+        features = []
+        for conv in self.conv_encoder:
+            x = conv(x)
+            features.append(x)
+
+        x_s1 = features[0]
+        x_s2 = features[1]
+        feats_s2 = features[1].features
+        coords_s2 = features[1].indices
+        feats_s1 = features[0].features
+        coords_s1 = features[0].indices
+
+        clvl_curve_template_s1 = curve_template['curve_template_rank9']
+        clvl_hilbert_spatial_size_s1 = hilbert_spatial_size['curve_template_rank9']
+        index_info_s1 = get_hilbert_index_3d_mamba_lite(clvl_curve_template_s1, coords_s1, batch_size, x_s1.spatial_shape[0],
+                                                        clvl_hilbert_spatial_size_s1, shift=(num_stage, num_stage, num_stage))
+        inds_curt_to_next_s1 = index_info_s1['inds_curt_to_next']
+        inds_next_to_curt_s1 = index_info_s1['inds_next_to_curt']
+
+        clvl_curve_template_s2 = curve_template[self.downsample_lvl]
+        clvl_hilbert_spatial_size_s2 = hilbert_spatial_size[self.downsample_lvl]
+        index_info_s2 = get_hilbert_index_3d_mamba_lite(clvl_curve_template_s2, coords_s2, batch_size, x_s2.spatial_shape[0],
+                                                        clvl_hilbert_spatial_size_s2, shift=(num_stage, num_stage, num_stage))
+        inds_curt_to_next_s2 = index_info_s2['inds_curt_to_next']
+        inds_next_to_curt_s2 = index_info_s2['inds_next_to_curt']
+
+        new_features = []
+        # Low Resolution
+        out_feat_3d_s2 = torch.zeros_like(feats_s2)
+        out_feat_3d_s1 = torch.zeros_like(feats_s1)
+
+        # Pos Embedding
+        pos_embed_coords_s2 = torch.zeros([coords_s2.shape[0], 9], device=coords_s2.device, dtype=torch.float32)
+        pos_embed_coords_s2[:, 0] = coords_s2[:, 1] / x_s2.spatial_shape[0]
+        pos_embed_coords_s2[:, 1:3] = (coords_s2[:, 2:] // 12) / (x_s2.spatial_shape[1] // 12 + 1)
+        pos_embed_coords_s2[:, 3:5] = (coords_s2[:, 2:] % 12) / 12.0
+        pos_embed_coords_s2[:, 5:7] = ((coords_s2[:, 2:] + 6) // 12) / (x_s2.spatial_shape[1] // 12 + 1)
+        pos_embed_coords_s2[:, 7:9] = ((coords_s2[:, 2:] + 6) % 12) / 12.0
+        pos_embed_s2 = pos_embed(pos_embed_coords_s2.float())
+
+        feats_s2 = feats_s2 + pos_embed_s2
+
+        # Forward SSM over the low-resolution (s2) branch
+        for i in range(batch_size):
+            b_mask_m2 = coords_s2[:, 0] == i
+            feat_m2 = feats_s2[b_mask_m2][inds_curt_to_next_s2[i]][None]
+            out_feat_m2 = mamba_layer1(feat_m2, None)
+            out_feat_3d_s2[b_mask_m2] = (out_feat_m2[0]).squeeze(0)[inds_next_to_curt_s2[i]]
+
+        x_s2 = replace_feature(x_s2, self.norm(out_feat_3d_s2))
+
+        # Backward SSM over the high-resolution (s1) branch: the Hilbert-ordered
+        # sequence is flipped before the SSM and flipped back afterwards
+        pos_embed_coords_s1 = torch.zeros([coords_s1.shape[0], 9], device=coords_s1.device, dtype=torch.float32)
+        pos_embed_coords_s1[:, 0] = coords_s1[:, 1] / x_s1.spatial_shape[0]
+        pos_embed_coords_s1[:, 1:3] = (coords_s1[:, 2:] // 12) / (x_s1.spatial_shape[1] // 12 + 1)
+        pos_embed_coords_s1[:, 3:5] = (coords_s1[:, 2:] % 12) / 12.0
+        pos_embed_coords_s1[:, 5:7] = ((coords_s1[:, 2:] + 6) // 12) / (x_s1.spatial_shape[1] // 12 + 1)
+        pos_embed_coords_s1[:, 7:9] = ((coords_s1[:, 2:] + 6) % 12) / 12.0
+        pos_embed_s1 = pos_embed(pos_embed_coords_s1.float())
+
+        feats_s1 = feats_s1 + pos_embed_s1
+        for i in range(batch_size):
+            b_mask_m1 = coords_s1[:, 0] == i
+            feat_m1 = feats_s1[b_mask_m1][inds_curt_to_next_s1[i]][None]
+            feat_back = feat_m1.flip(1)
+            out_feat_back = mamba_layer2(feat_back, None)
+            out_feat_3d_s1[b_mask_m1] = (out_feat_back[0]).squeeze(0).flip(0)[inds_next_to_curt_s1[i]]
+
+        x_s1 = replace_feature(x_s1, self.norm_back(out_feat_3d_s1))
+
+        # new_features.append(features[0])
+        new_features.append(x_s1)
+        new_features.append(x_s2)
+
+        x = x_s2
+
+        for deconv, norm, up_x in zip(self.conv_decoder, self.conv_decoder_norm, new_features[:-1][::-1]):
+            x = deconv(x)
+            x = replace_feature(x, x.features + up_x.features + features[0].features)
+            x = replace_feature(x, norm(x.features))
+
+        return x.features, x.indices
+
+
+##### downsampling operation #####
+
+class Sparse1ConvBlock(spconv.SparseModule):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, bias=None, norm_fn=None, downsample=None, indice_key=None):
+        super(Sparse1ConvBlock, self).__init__()
+
+        assert norm_fn is not None
+        if bias is None:
+            bias = norm_fn is not None
+        self.conv1 = spconv.SubMConv3d(
+            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
+        )
+        self.bn1 = norm_fn(planes)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = replace_feature(out, self.bn1(out.features))
+        out = replace_feature(out, out.features + identity.features)
+        out = replace_feature(out, self.relu(out.features))
+
+        return out
+
+
+class DownSp(spconv.SparseModule):
+
+    def __init__(self, dim, kernel_size, stride, num_down, norm_fn, indice_key):
+        super(DownSp, self).__init__()
+
+        first_block = post_act_block(
+            dim, dim, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2,
+            norm_fn=norm_fn, indice_key=f'spconv_{indice_key}', conv_type='spconv')
+
+        block_list = [first_block if stride > 1 else nn.Identity()]
+        for _ in range(num_down):
+            block_list.append(
+                Sparse1ConvBlock(dim, dim, norm_fn=norm_fn, indice_key=indice_key))
+
+        self.blocks = spconv.SparseSequential(*block_list)
+
+    def forward(self, x):
+        return self.blocks(x)
diff --git a/pcdet/models/model_utils/voxel_mamba_utils.py b/pcdet/models/model_utils/voxel_mamba_utils.py
new file mode 100644
index 000000000..1d1cd18f3
--- /dev/null
+++ b/pcdet/models/model_utils/voxel_mamba_utils.py
@@ -0,0 +1,64 @@
+import torch
+
+
+def get_hilbert_index_3d_mamba_lite(template, coors, batch_size, z_dim, hilbert_spatial_size, shift=(0, 0, 0), debug=True):
+    '''
+    coors: (b, z, y, x)
+    shift: (shift_z, shift_y, shift_x)
+    hilbert_spatial_size: [z, y, x]
+    '''
+    hil_size_z, hil_size_y, hil_size_x = hilbert_spatial_size
+
+    x = coors[:, 3] + shift[2]
+    y = coors[:, 2] + shift[1]
+    z = coors[:, 1] + shift[0]
+
+    flat_coors = (z * hil_size_y * hil_size_x + y * hil_size_x + x).long()
+    hil_inds = template[flat_coors].long()
+
+    inds_curt_to_next = {}
+    inds_next_to_curt = {}
+    for i in range(batch_size):
+        batch_mask = coors[:, 0] == i
+        inds_curt_to_next[i] = torch.argsort(hil_inds[batch_mask])
+        inds_next_to_curt[i] = torch.argsort(inds_curt_to_next[i])
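+        # argsort of the sorting permutation is its inverse: inds_curt_to_next
+        # gathers features into Hilbert order, inds_next_to_curt scatters the
+        # SSM outputs back to the original voxel order.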
+
+    index_info = {}
+    index_info['inds_curt_to_next'] = inds_curt_to_next
+    index_info['inds_next_to_curt'] = inds_next_to_curt
+
+    return index_info
+
+
+def get_hilbert_index_2d_mamba_lite(template, coors, batch_size, hilbert_spatial_size, shift=(0, 0), debug=True):
+    '''
+    coors: (b, z, y, x)
+    shift: (shift_y, shift_x)
+    hilbert_spatial_size: [z, y, x] (the z entry is ignored in the 2D case)
+    '''
+    _, hil_size_y, hil_size_x = hilbert_spatial_size
+
+    x = coors[:, 3] + shift[1]
+    y = coors[:, 2] + shift[0]
+
+    flat_coors = (y * hil_size_x + x).long()
+    hil_inds = template[flat_coors].long()
+
+    inds_curt_to_next = {}
+    inds_next_to_curt = {}
+    for i in range(batch_size):
+        batch_mask = coors[:, 0] == i
+        inds_curt_to_next[i] = torch.argsort(hil_inds[batch_mask])
+        inds_next_to_curt[i] = torch.argsort(inds_curt_to_next[i])
+
+    index_info = {}
+    index_info['inds_curt_to_next'] = inds_curt_to_next
+    index_info['inds_next_to_curt'] = inds_next_to_curt
+
+    return index_info
diff --git a/tools/cfgs/waymo_models/voxel_mamba.yaml b/tools/cfgs/waymo_models/voxel_mamba.yaml
new file mode 100644
index 000000000..93aefcd11
--- /dev/null
+++ b/tools/cfgs/waymo_models/voxel_mamba.yaml
@@ -0,0 +1,195 @@
+CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
+
+DATA_CONFIG:
+    _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
+    SAMPLED_INTERVAL: {'train': 1, 'test': 1}
+    DATA_SPLIT: {'train': train, 'test': val}
+    POINT_CLOUD_RANGE: [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
+    POINTS_TANH_DIM: [3, 4]
+    DATA_AUGMENTOR:
+        DISABLE_AUG_LIST: ['placeholder']
+        AUG_CONFIG_LIST:
+            - NAME: gt_sampling
+              USE_ROAD_PLANE: False
+              DB_INFO_PATH:
+                  - waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl
+
+              USE_SHARED_MEMORY: False  # set it to True to speed up (it costs about 15GB shared memory)
+              DB_DATA_PATH:
+                  - waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global.npy
+
+              BACKUP_DB_INFO:
+                  # if the above DB_INFO cannot be found, will use this backup one
+                  DB_INFO_PATH: waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0.pkl
+                  DB_DATA_PATH: waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_global.npy
+                  NUM_POINT_FEATURES: 6
+
+              PREPARE: {
+                  filter_by_min_points: ['Vehicle:5', 'Pedestrian:10', 'Cyclist:10'],
+                  filter_by_difficulty: [-1],
+              }
+
+              SAMPLE_GROUPS: ['Vehicle:15', 'Pedestrian:10', 'Cyclist:10']
+              NUM_POINT_FEATURES: 5
+              REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
+              LIMIT_WHOLE_SCENE: True
+
+            - NAME: random_world_flip
+              ALONG_AXIS_LIST: ['x', 'y']
+
+            - NAME: random_world_rotation
+              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
+
+            - NAME: random_world_scaling
+              WORLD_SCALE_RANGE: [0.95, 1.05]
+
+            - NAME: random_world_translation
+              NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
+
+    DATA_PROCESSOR:
+        - NAME: mask_points_and_boxes_outside_range
+          REMOVE_OUTSIDE_BOXES: True
+
+        - NAME: shuffle_points
+          SHUFFLE_ENABLED: {
+              'train': True,
+              'test': True
+          }
+
+        - NAME: transform_points_to_voxels_placeholder
+          VOXEL_SIZE: [0.32, 0.32, 0.1875]
+
+MODEL:
+    NAME: CenterPoint
+
+    VFE:
+        NAME: DynamicVoxelVFE
+        WITH_DISTANCE: False
+        USE_ABSLOTE_XYZ: True
+        USE_NORM: True
+        NUM_FILTERS: [128, 128]
+
+    BACKBONE_3D:
+        NAME: Voxel_Mamba
+        INPUT_LAYER:
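+            # Hilbert templates are generated offline by
+            # tools/process_tools/create_hilbert_curve.py (paths resolve relative to tools/)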
+            curve_template_path_rank9: '../data/hilbert/curve_template_3d_rank_9.pth'
+            curve_template_path_rank8: '../data/hilbert/curve_template_3d_rank_8.pth'
+            curve_template_path_rank7: '../data/hilbert/curve_template_3d_rank_7.pth'
+
+        # for mamba
+        d_model: 128
+        fused_add_norm: True
+        rms_norm: True
+        norm_epsilon: 0.00001
+        residual_in_fp32: True
+        device: 'cuda'
+        dtype: torch.float32
+        output_shape: 468
+        conv_out_channel: 128
+
+        # for backward branch
+        extra_down: 5
+        num_stage: [2, 2, 2]
+        num_down: [[0, 1], [0, 1], [0, 1]]
+        down_stride: [[1, 1], [1, 2], [1, 4]]
+        down_kernel_size: [[3, 3], [3, 3], [3, 5]]
+        down_resolution: [False, True, True]
+        downsample_lvl: ['curve_template_rank9', 'curve_template_rank8', 'curve_template_rank7']
+
+    MAP_TO_BEV:
+        NAME: PointPillarScatter3d
+        INPUT_SHAPE: [468, 468, 1]
+        NUM_BEV_FEATURES: 128
+
+    BACKBONE_2D:
+        NAME: BaseBEVResBackbone
+        LAYER_NUMS: [1, 2, 2]
+        LAYER_STRIDES: [1, 2, 2]
+        NUM_FILTERS: [128, 128, 256]
+        UPSAMPLE_STRIDES: [1, 2, 4]
+        NUM_UPSAMPLE_FILTERS: [128, 128, 128]
+
+    DENSE_HEAD:
+        NAME: CenterHead
+        CLASS_AGNOSTIC: False
+
+        CLASS_NAMES_EACH_HEAD: [
+            ['Vehicle', 'Pedestrian', 'Cyclist']
+        ]
+
+        SHARED_CONV_CHANNEL: 64
+        USE_BIAS_BEFORE_NORM: False
+        NUM_HM_CONV: 2
+
+        BN_EPS: 0.001
+        BN_MOM: 0.01
+        SEPARATE_HEAD_CFG:
+            HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
+            HEAD_DICT: {
+                'center': {'out_channels': 2, 'num_conv': 2},
+                'center_z': {'out_channels': 1, 'num_conv': 2},
+                'dim': {'out_channels': 3, 'num_conv': 2},
+                'rot': {'out_channels': 2, 'num_conv': 2},
+                'iou': {'out_channels': 1, 'num_conv': 2},
+            }
+
+        TARGET_ASSIGNER_CONFIG:
+            FEATURE_MAP_STRIDE: 1
+            NUM_MAX_OBJS: 500
+            GAUSSIAN_OVERLAP: 0.1
+            MIN_RADIUS: 2
+
+        IOU_REG_LOSS: True
+
+        LOSS_CONFIG:
+            LOSS_WEIGHTS: {
+                'cls_weight': 1.0,
+                'loc_weight': 2.0,
+                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+            }
+
+        POST_PROCESSING:
+            SCORE_THRESH: 0.1
+            POST_CENTER_LIMIT_RANGE: [-80, -80, -10.0, 80, 80, 10.0]
+            MAX_OBJ_PER_SAMPLE: 500
+            USE_IOU_TO_RECTIFY_SCORE: True
+            IOU_RECTIFIER: [0.68, 0.71, 0.65]
+
+            NMS_CONFIG:
+                NMS_TYPE: class_specific_nms
+                NMS_THRESH: [0.75, 0.6, 0.55]
+                NMS_PRE_MAXSIZE: [4096, 4096, 4096]
+                NMS_POST_MAXSIZE: [500, 500, 500]
+
+    POST_PROCESSING:
+        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
+        EVAL_METRIC: waymo
+
+
+OPTIMIZATION:
+    BATCH_SIZE_PER_GPU: 3
+    NUM_EPOCHS: 24
+
+    OPTIMIZER: adam_onecycle
+    LR: 0.0025
+    WEIGHT_DECAY: 0.05
+    MOMENTUM: 0.9
+
+    MOMS: [0.95, 0.85]
+    PCT_START: 0.1
+    DIV_FACTOR: 100
+    DECAY_STEP_LIST: [35, 45]
+    LR_DECAY: 0.1
+    LR_CLIP: 0.0000001
+
+    LR_WARMUP: False
+    WARMUP_EPOCH: 1
+
+    GRAD_NORM_CLIP: 3
+    LOSS_SCALE_FP16: 32.0
+
+HOOK:
+    DisableAugmentationHook:
+        DISABLE_AUG_LIST: ['gt_sampling', 'random_world_flip', 'random_world_rotation', 'random_world_scaling', 'random_world_translation']
+        NUM_LAST_EPOCHS: 1
diff --git a/tools/process_tools/create_hilbert_curve.py b/tools/process_tools/create_hilbert_curve.py
new file mode 100644
index 000000000..7e7895947
--- /dev/null
+++ b/tools/process_tools/create_hilbert_curve.py
@@ -0,0 +1,219 @@
+"""Generate Hilbert-curve templates for the Voxel Mamba backbone.
+
+Notes:
+- the input coordinate tensor is cloned before the in-place bit manipulation,
+  so the caller's tensor is not modified.
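+- run this script from the repository root; templates are written to
+  ./data/hilbert/curve_template_3d_rank_{bit}.pth.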
+"""
+
+import torch
+
+
+def convert_2d(xy: torch.Tensor, n: int) -> torch.Tensor:
+    """xy = Tensor([(x1, y1), (x2, y2), ...]), the size of the image is 2^n,
+    the output is idx = Tensor([idx1, idx2, ...])
+    """
+    N = 2 ** n
+    assert 0 <= xy.min() and xy.max() < N and isinstance(xy.sum().item(), int), "wrong input!"
+    xy = xy.clone().detach()
+    current_order = N >> 1  # the current order to be dealt with
+    idx = torch.zeros(xy.size()[0], dtype=torch.int)
+
+    while current_order > 0:  # orders go from high to low
+
+        xy_bit = xy & current_order  # get the bit of x, y at the current order
+        y_cur_order = xy_bit[:, 1] > 0
+        x_cur_order = xy_bit[:, 0] > 0  # the coordinate in the current order, denoted by the bool type
+
+        idx += current_order * current_order * (1 * x_cur_order + 1 * y_cur_order + 2 * (x_cur_order & (~y_cur_order)))
+        # the formula in "()": transforms (0,0) to 0, ..., (1,0) to 3
+
+        judge_reverse = (~y_cur_order) & x_cur_order  # if the coordinate of the current order is (1, 0), reverse it
+        judge_rotate = ~y_cur_order  # if the coordinate y of the current order is 0, rotate (swap x and y)
+
+        xy[:, 0][judge_reverse] = current_order * 2 - 1 - xy[:, 0][judge_reverse]
+        xy[:, 1][judge_reverse] = current_order * 2 - 1 - xy[:, 1][judge_reverse]
+        xy[:, 0][judge_rotate], xy[:, 1][judge_rotate] = xy[:, 1][judge_rotate], xy[:, 0][judge_rotate]
+
+        xy[:, 0][x_cur_order] -= current_order
+        xy[:, 1][y_cur_order] -= current_order  # or: & ~(1 << current_order)
+        current_order >>= 1  # go to the next order
+
+    return idx
+
+
+def get_bit_data(point_locs: torch.Tensor, bit: int) -> torch.Tensor:
+    """point_locs.size(): (input_size, num_dims)"""
+    bit_data = (point_locs >> bit) & 1
+    return bit_data
+
+
+def convert_to_index(point_locs: torch.Tensor, num_dims: int, num_bits: int) -> torch.Tensor:
+    """Encode a tensor of locations in a cube into Hilbert integers.
+
+    Params:
+    -------
+    point_locs - Locations in a cube of num_dims dimensions, in
+                 which each dimension runs from 0 to 2**num_bits-1.
+                 The last dimension of the input has size num_dims.
+
+    num_dims - The dimensionality of the cube.
+
+    num_bits - The number of bits for each dimension.
+
+    Returns:
+    --------
+    A tensor of int64 integers with the same shape as the input,
+    excluding the last dimension, which needs to be num_dims.
+    """
+
+    # check that the locations are valid
+    if point_locs.shape[-1] != num_dims:
+        raise ValueError(
+            '''
+            The shape of locs was surprising in that the last dimension was of size
+            %d, but num_dims=%d. These need to be equal.
+            ''' % (point_locs.shape[-1], num_dims)
+        )
+
+    if num_dims * num_bits >= 64:
+        raise ValueError(
+            '''
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into an int64. Are you sure you need that many points on your Hilbert
+            curve?
+            ''' % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # follow the device on which the input is located
+    if point_locs.device.type == 'cuda':
+        device = point_locs.device
+    else:
+        device = torch.device('cpu')
+
+    # deepcopy the input
+    point_locs = point_locs.clone().detach()
+
+    # With num_dims >= 2 and num_dims * num_bits < 64, each coordinate fits in an int32.
+    point_locs = point_locs.type(torch.int32)
+
+    fig_size = 1 << num_bits
+    # Iterate forwards through the bits.
+    bit_pow = fig_size >> 1
+    while bit_pow > 1:
+
+        # Iterate forwards through the dimensions.
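+        # This invert/exchange mixing is the Gray-code formulation of the
+        # Hilbert mapping (cf. Skilling's algorithm): at each scale, a set bit
+        # in dimension `dim` reflects the lower bits of dimension 0, while a
+        # clear bit exchanges those lower bits with dimension `dim`.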
+        mask = bit_pow - 1
+        for dim in range(num_dims):
+
+            judge_invert = (point_locs[:, dim] & bit_pow) > 0
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            point_locs[:, 0][judge_invert] ^= mask
+
+            judge_exchange = ~judge_invert
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = (point_locs[:, 0] ^ point_locs[:, dim]) & mask
+            point_locs[:, 0][judge_exchange] ^= to_flip[judge_exchange]
+            point_locs[:, dim][judge_exchange] ^= to_flip[judge_exchange]
+
+        bit_pow >>= 1
+
+    # Combine dims into one Gray code
+    gray_code = torch.zeros(point_locs.size(0), dtype=torch.int64).to(device)
+    for bit_current in range(num_bits):
+
+        bit_data = get_bit_data(point_locs, bit_current)
+
+        for dim in range(num_dims):
+            # send bit_data to the correct position
+            dim_shift = num_dims - 1 - dim  # lower dim more significant
+            gray_code += bit_data[:, dim] << (bit_current * num_dims + dim_shift)
+
+    # Convert Gray code back to binary form of the index.
+    shift = 2 ** (int(torch.ceil(torch.log2(torch.tensor(num_dims * num_bits)))) - 1)
+    while shift > 0:
+        gray_code ^= (gray_code >> shift)
+        shift >>= 1
+
+    point_indices = gray_code
+    return point_indices
+
+
+if __name__ == '__main__':
+
+    # Generate Hilbert curve templates at three ranks:
+    # (bit=9, z_max=41) / (bit=8, z_max=17) / (bit=7, z_max=9)
+    dim = 3  # voxel-based dim is 3, pillar-based dim is 2
+    device = torch.device('cuda')
+
+    # z_max must be at least the Z-axis resolution; the curve is truncated to
+    # N * N * z_max entries. Reference values from the original comments:
+    # downstride = 1/2/4 -> z_max = 41/17/9 (as used below);
+    # alternatively 33/17/9 for Waymo and 41/10/5 for nuScenes.
+    for bit, z_max in [(9, 41), (8, 17), (7, 9)]:
+        N = 2 ** bit  # N must be no smaller than the BEV resolution
+        use_size = N * N * z_max  # truncate the curve
+
+        if dim == 3:
+            point_locs = torch.tensor([(z, y, x) for z in range(N) for y in range(N) for x in range(N)]).to(device)
+        elif dim == 2:
+            point_locs = torch.tensor([(y, x) for y in range(N) for x in range(N)]).to(device)
+        else:
+            raise ValueError(
+                '''
+                The space dimension can only be 2 or 3 in the real world!
+                '''
+            )
+
+        curve_index = convert_to_index(point_locs, dim, bit).cpu()
+        curve_index_used = curve_index[:use_size]
+        torch.save(curve_index_used, f'./data/hilbert/curve_template_3d_rank_{bit}.pth')
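+
+        # Optional sanity check: the full mapping is a bijection on the cube,
+        # so the truncated template must contain pairwise-distinct indices,
+        # otherwise two voxels would collide at the same serialization slot.
+        assert torch.unique(curve_index_used).numel() == use_size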